19 | import Control.Monad.State
20 | import Data.Primitives.Interpolation
21 | import public Compiler.Stablehlo.Dialect.StablehloEnums
23 | import Derive.Prelude
24 | import Language.Reflection
26 | import Compiler.LiteralRW
27 | import Compiler.Xla.XlaData
33 | %language ElabReflection
-- Orphan instance (neither `Interpolation` nor `List` is defined here) that lets
-- lists such as shapes and axis lists be spliced directly into interpolated
-- strings, e.g. "\{shape}" and "\{starts}" in the printers below.
-- NOTE(review): the instance body is outside this chunk; given the `Show a`
-- constraint it presumably renders via `show` — confirm.
35 | Show a => Interpolation (List a) where
-- The type of a tensor value: a `Shape` paired with an element type `dtype`.
-- `dtype` has quantity 0, so it is erased at runtime. The `Primitive dtype`
-- constraint is stored in the constructor as an auto-implicit, so code that
-- matches on `TensorType` can use it (as `Show ValueType` does via `repr`).
39 | data ValueType : Type where
40 | TensorType : Shape -> (0 dtype : Type) -> Primitive dtype => ValueType
-- Prints a tensor type as its shape followed by a space and the name of its
-- element type (as given by `repr`).
Show ValueType where
  show (TensorType dims elemTy) =
    let elemName = repr {dtype = elemTy}
     in "\{dims} \{elemName}"
||| A value in the graph: a `Nat` tag paired with the `Op` that produces it.
||| `showValue` renders it as `(<op>):<tag>`.
data Value : Type where
  V : Nat -> Op -> Value
||| An environment of bindings. It holds the next free index and the
||| `(index, op)` pairs recorded so far. The pairs are kept newest-first,
||| because `tagOp` conses each new binding onto the front of the list.
data Env : Type where
  MkEnv : Nat -> List (Nat, Op) -> Env
||| An environment with no bindings that keeps the counter of the given
||| environment.
emptyFrom : Env -> Env
emptyFrom env = case env of
  MkEnv next _ => MkEnv next Nil
-- Reads only the counter `n` of the given environment; its bindings are
-- ignored by the pattern.
-- NOTE(review): the rest of the `do` block is outside this chunk; presumably it
-- writes `n` as the counter of the current `State Env` while keeping the
-- current bindings — confirm.
67 | updateCounterFrom : Env -> State Env ()
68 | updateCounterFrom (MkEnv n _) = do
||| The bindings of an environment in the order they were recorded, oldest
||| first. They are stored newest-first, hence the reversal.
toList : Env -> List (Nat, Op)
toList env = case env of
  MkEnv _ newestFirst => reverse newestFirst
||| The counter stored in an environment, i.e. the index that `tagOp` would
||| assign to the next recorded op.
counter : Env -> Nat
counter env = case env of
  MkEnv next _ => next
-- A function in the graph, indexed by its number of parameters `arity`.
-- `showFn` matches it as `MkFn parameterSetTag paramTypes resultTypes results env`.
-- The `parameterSetTag` and `env` fields (and the binding of `resultCount`) are
-- declared on lines outside this chunk.
81 | record Fn (arity : Nat) where
-- One `ValueType` per parameter.
84 | paramTypes : Vect arity ValueType
-- The result types; same length (`resultCount`) as `results`.
85 | resultTypes : Vect resultCount ValueType
-- The result values of the function.
86 | results : Vect resultCount Value
91 | Compare ComparisonDirection | And | Or | Add | Sub | Mul | Div | Rem | Pow | Min | Max
-- `Compare` carries a `ComparisonDirection`, so the derived `Show BinaryOp`
-- needs a `Show ComparisonDirection` instance, which is derived here first.
-- NOTE(review): `ComparisonDirection` is not declared in this chunk; presumably
-- it comes from the StablehloEnums import — confirm.
94 | %runElab derive "ComparisonDirection" [Show]
95 | %runElab derive "BinaryOp" [Show]
99 | Not | Neg | Ceil | Floor | Abs | Log | Exp | Logistic | Sqrt | Sin | Cos | Tan | Tanh
100 | | Erf | ErfInv | Square | Asin | Acos | Atan | Sinh | Cosh | Asinh | Acosh | Atanh
-- `Show UnaryOp` is derived by elaborator reflection (enabled by
-- `%language ElabReflection`); `showOp` uses it to print unary ops.
102 | %runElab derive "UnaryOp" [Show]
-- An operation node. Operands are `Value`s and nested computations are `Fn`s.
-- NOTE(review): some constructor lines (including the head and trailing
-- arguments of `DotGeneral`) are outside this chunk.
105 | data Op : Type where
-- A reference to a binding by index; `tagOp` returns this after recording an op.
106 | BoundSet : Nat -> Op
-- A literal constant; its `shape` and `dtype` are implicit arguments.
108 | Lit : PrimitiveRW dtype ty => {shape : _} -> Literal shape ty -> Op
-- Holds a single-parameter function and a value; presumably the gradient of
-- that function at the value — confirm.
109 | Grad : Shape -> Fn 1 -> Value -> Op
110 | MinValue : (0 dtype : _) -> Primitive dtype => Op
111 | MaxValue : (0 dtype : _) -> Primitive dtype => Op
112 | MinFiniteFloat : Op
113 | MaxFiniteFloat : Op
-- The element type of the following conversions/reshape is an implicit `dtype`
-- (presumably the target type — confirm).
114 | Iota : Primitive dtype => (shape : Shape) -> (axis : Nat) -> Op
115 | BitCastConvert : Primitive dtype => Shape -> Value -> Op
116 | Convert : Primitive dtype => Shape -> Value -> Op
117 | Reshape : Primitive dtype => Shape -> Value -> Op
118 | Slice : (starts, stops, strides : List Nat) -> Value -> Op
119 | DynamicSlice : (starts : List Value) -> (sizes : List Nat) -> Value -> Op
-- Takes a non-empty vector of operands (`Vect (S n)`).
120 | Concat : (axis : Nat) -> Vect (S n) Value -> Op
121 | Transpose : (ordering : List Nat) -> Value -> Op
122 | Broadcast : Primitive dtype => (from, to : Shape) -> Value -> Op
-- The mapped function's arity equals the number of operand values.
123 | Map : Fn arity -> Vect arity Value -> (resultType : ValueType) -> Shape -> Op
-- With `n` inits and `n` operands the reducer takes `n + n` parameters
-- (presumably accumulators then elements — confirm).
124 | Reduce : Fn (n + n) -> (inits : Vect n Value) -> (axes : List Nat) -> Vect n Value -> Op
-- The two-parameter function is presumably the comparator — confirm.
125 | Sort : Fn 2 -> (axis : Nat) -> (isStable : Bool) -> Value -> Op
126 | Reverse : (axes : List Nat) -> Value -> Op
127 | BinaryElementwise : BinaryOp -> Value -> Value -> Op
128 | UnaryElementwise : UnaryOp -> Value -> Op
129 | Select : (predicate, onTrue, onFalse : Value) -> Op
-- Condition and body share arity `n`, the number of initial values.
130 | While : (condition, body : Fn n) -> (init : Vect n Value) -> Op
-- Both branches are parameterless functions (`Fn 0`).
131 | If : (resultType : ValueType) -> (predicate : Value) -> (onTrue, onFalse : Fn 0) -> Op
133 | (lBatch, lContract, rBatch, rContract: List Nat) ->
134 | (resultType : ValueType) ->
138 | Cholesky : Value -> Op
139 | TriangularSolve : Value -> Value -> (isLower : Bool) -> Op
-- `resultType` is a full `ValueType` (shape and dtype), although `showOp` labels
-- it `shape`.
140 | Rng : (state : Value) -> (resultType : ValueType) -> Op
-- Records `expr` in the environment under the current counter value.
-- The new binding is consed onto the front, so the list stays newest-first.
-- Then increments the counter and returns a `BoundSet` reference to the
-- recorded index instead of the op itself.
-- NOTE(review): the equation head binding `expr` is outside this chunk.
143 | tagOp : Monad m => Op -> StateT Env m Op
145 | MkEnv next env <- get
146 | put $
MkEnv (S next) ((next, expr) :: env)
147 | pure (BoundSet next)
-- Claims the current counter value by incrementing the counter without adding
-- a binding.
-- NOTE(review): the head and the end of this `do` block are outside this chunk;
-- presumably it returns the claimed index `next` — confirm.
150 | reserve : State Env Nat
152 | MkEnv next env <- get
153 | put $
MkEnv (S next) env
-- Renders an `Op` for debugging. The `Nat` is the current indentation in spaces.
-- It is passed down to nested values and functions so that printed environments
-- line up. The equations appear further down.
157 | showOp : Nat -> Op -> String
||| Renders a value as its op in parentheses, then `:` and the value's tag.
showValue : Nat -> Value -> String
showValue indent (V idx op) =
  let rendered = showOp indent op
   in "(" ++ rendered ++ "):" ++ show idx
||| Renders a list of values as a bracketed, comma-separated list, passing the
||| indentation down to each element.
showValueList : Nat -> List Value -> String
showValueList indent xs = "[" ++ joinBy ", " (map (showValue indent) xs) ++ "]"
-- Renders the bindings one per line, oldest first (the stored list is
-- newest-first, hence `reverse`). Each line is the indentation, the binding
-- index, padding, then the op.
-- `max` (which shadows Prelude's `max`) is the env counter. It is used only to
-- size the padding: 4 + digits(max) - digits(n), with `minus` truncating at 0,
-- so the op columns line up for indices no wider than `max`.
-- NOTE(review): the `where` header and `fmt`'s equation head are outside this chunk.
168 | showEnv : Nat -> Env -> String
169 | showEnv indent (MkEnv max env) = joinBy "\n" $
assert_total $
map fmt (reverse env)
173 | fmt : (Nat, Op) -> String
175 | let sep = replicate (4 + length (show max) `minus` length (show n)) ' '
176 | in "\{replicate indent ' '}\{show n}\{sep}\{showOp indent x}"
-- Renders a function as "<parameterSetTag> <paramTypes> => <results>".
-- The results are printed at indentation + 2. If the function's environment
-- has bindings, an extra "with vars { ... }" section lists them at
-- indentation + 4, closed at indentation + 2.
-- NOTE(review): the `case` head (presumably on `env'`) and the header of the
-- non-empty branch are outside this chunk — confirm.
179 | showFn : Nat -> Fn arity -> String
180 | showFn indent (MkFn parameterSetTag paramTypes resultTypes results env@(MkEnv _ env')) =
181 | let params = "\{show parameterSetTag} \{show paramTypes}"
182 | res = "\{showValueList (indent + 2) $ toList results}" in
184 | [] => "\{params} => \{res}"
186 | "\{params} => \{res} with vars {\n\{showEnv (indent + 4) env}\n\{replicate (indent + 2) ' '}}"
-- Renders a function starting at indentation 0. Totality is asserted because
-- the printers recurse through nested functions, values and ops.
export
Show (Fn arity) where
  show f = assert_total (showFn 0 f)
-- Equations of `showOp`. The first argument is the current indentation; it is
-- passed to nested values and functions so their environments line up.
-- Constructors with no nested values or functions ignore it.
showOp _ (Lit {shape, dtype} _) = "Lit \{shape} \{repr {dtype}}"
showOp _ (BoundSet k) = "Bound \{k}"
showOp indent (Grad _ op x) = "Grad {op = \{showFn indent op}} \{showValue indent x}"
showOp _ (MinValue dtype) = "MinValue \{repr {dtype}}"
showOp _ (MaxValue dtype) = "MaxValue \{repr {dtype}}"
showOp _ MinFiniteFloat = "MinFiniteFloat"
showOp _ MaxFiniteFloat = "MaxFiniteFloat"
showOp _ (Iota {dtype} shape axis) =
  "Iota {shape = \{show shape}, dtype = \{repr {dtype}}, axis = \{axis}}"
showOp indent (Convert {dtype} _ x) =
  "Convert {dtype = \{repr {dtype}}} \{showValue indent x}"
showOp indent (BitCastConvert {dtype} _ x) =
  "BitCastConvert {dtype = \{repr {dtype}}} \{showValue indent x}"
showOp indent (Reshape to x) = "Reshape {to = \{to}} \{showValue indent x}"
showOp indent (Slice starts stops strides x) =
  "Slice {starts = \{starts}, stops = \{stops}, strides = \{strides}} \{showValue indent x}"
showOp indent (DynamicSlice starts sizes x) =
  "DynamicSlice {starts = \{showValueList indent starts}, sizes = \{sizes}} \{showValue indent x}"
showOp indent (Concat axis xs) = "Concat {axis = \{axis}} \{showValueList indent $ toList xs}"
showOp indent (Transpose ordering x) = "Transpose {ordering = \{ordering}} \{showValue indent x}"
showOp indent (Broadcast from to x) =
  "Broadcast {from = \{from}, to = \{to}} \{showValue indent x}"
showOp indent (Map f xs _ _) = "Map {f = \{showFn indent f}} \{showValueList indent $ toList xs}"
showOp indent (Reduce op inits axes xs) =
  "Reduce {op = \{showFn indent op}, inits = \{showValueList indent $ toList inits}," ++
  " axes = \{axes}} \{showValueList indent $ toList xs}"
showOp indent (Sort f axis _ x) =
  "Sort {f = \{showFn indent f}, axis = \{axis}} \{showValue indent x}"
showOp indent (Reverse axes x) = "Reverse \{axes} \{showValue indent x}"
showOp indent (BinaryElementwise op x y) = "\{show op} \{showValue indent x} \{showValue indent y}"
showOp indent (UnaryElementwise op x) = "\{show op} \{showValue indent x}"
showOp indent (Select p t f) =
  "Select {predicate = \{showValue indent p}, onTrue = \{showValue indent t}," ++
  " onFalse = \{showValue indent f}}"
showOp indent (While c b is) =
  "While {condition = \{showFn indent c}, body = \{showFn indent b}," ++
  " initials = \{showValueList indent $ toList is}}"
showOp indent (If _ p ft ff) =
  "If {predicate = \{showValue indent p}, onTrue = \{showFn indent ft}," ++
  " onFalse = \{showFn indent ff}}"
showOp indent (DotGeneral lBatch lContract rBatch rContract _ x y) =
  "DotGeneral {lBatch = \{lBatch}, lContract = \{lContract}," ++
  " rBatch = \{rBatch}, rContract = \{rContract}} \{showValue indent x} \{showValue indent y}"
showOp indent (Cholesky x) = "Cholesky \{showValue indent x}"
showOp indent (TriangularSolve x y isLower) =
  "TriangularSolve {isLower = \{show isLower}} \{showValue indent x} \{showValue indent y}"
-- The second field is the full result `ValueType` (shape and dtype). It is
-- printed under the label `shape`, which is kept so the output format is unchanged.
showOp indent (Rng state resultType) =
  "Rng {state = \{showValue indent state}, shape = \{show resultType}}"
-- Renders an op starting at indentation 0. Totality is asserted because
-- `showOp` is mutually recursive with `showFn` and `showValue`.
export
Show Op where
  show op = assert_total (showOp 0 op)
-- Shows a value as `V <tag> <op>`, using `Show Nat` for the tag and `Show Op`
-- for the op. (The instance header is outside this chunk.)
242 | show (V pos x) = "V \{show pos} \{show x}"