19 | import Control.Monad.State
20 | import Data.Primitives.Interpolation
21 | import public Compiler.Stablehlo.Dialect.StablehloEnums
23 | import Derive.Prelude
24 | import Language.Reflection
26 | import Compiler.LiteralRW
27 | import Compiler.Xla.XlaData
33 | %language ElabReflection
-- Lets a list whose elements have a `Show` instance be spliced directly into an
-- interpolated string.
-- NOTE(review): the implementation body is not visible in this listing.
35 | Show a => Interpolation (List a) where
-- The type of a value: a tensor type made of a `Shape` and a `DType`.
data ValueType : Type where
  TensorType : Shape -> DType -> ValueType
-- Renders a tensor type as its shape followed by its dtype, separated by a
-- single space.
Show ValueType where
  show valueType =
    let TensorType shape dtype = valueType
     in "\{show shape} \{show dtype}"
-- A value: a `Nat` paired with an `Op`.
data Value : Type where
  V : Nat -> Op -> Value
-- An environment: a `Nat` counter together with a list of (`Nat`, `Op`)
-- bindings.
data Env : Type where
  MkEnv : Nat -> List (Nat, Op) -> Env
-- Builds an environment that keeps the counter of the given environment but
-- holds no bindings.
emptyFrom : Env -> Env
emptyFrom env = case env of
  MkEnv next _ => MkEnv next []
-- Stateful action over `Env` that takes another environment and extracts only
-- its counter `n`; that environment's bindings are ignored.
-- NOTE(review): the statements of the `do` block are not visible in this
-- listing. Judging by the name, it presumably writes `n` into the counter of
-- the state's environment — confirm against the full file.
65 | updateCounterFrom : Env -> State Env ()
66 | updateCounterFrom (MkEnv n _) = do
-- Returns the bindings of an environment in the reverse of the order in which
-- they are stored.
toList : Env -> List (Nat, Op)
toList env =
  let MkEnv _ bindings = env
   in reverse bindings
-- Returns the counter component of an environment.
counter : Env -> Nat
counter env = case env of
  MkEnv next _ => next
-- A function description indexed by its number of parameters, `arity`.
-- NOTE(review): some lines of this record are not visible in this listing.
-- The pattern `MkFn parameterSetTag paramTypes resultTypes results env` used by
-- `showFn` suggests a constructor `MkFn`, a leading `parameterSetTag` field and
-- a trailing `Env` field — confirm against the full file.
79 | record Fn (arity : Nat) where
-- One type per parameter.
82 | paramTypes : Vect arity ValueType
-- One type per result; shares the length `resultCount` with `results`.
83 | resultTypes : Vect resultCount ValueType
-- The result values, one per entry of `resultTypes`.
84 | results : Vect resultCount Value
-- NOTE(review): the `data` header for these constructors is not visible in this
-- listing; presumably they are the constructors of `BinaryOp` — confirm.
-- `Compare` carries a `ComparisonDirection`; the other constructors take no
-- arguments.
89 | Compare ComparisonDirection | And | Or | Add | Sub | Mul | Div | Rem | Pow | Min | Max
-- Derive `Show` for both types via elaborator reflection; `showOp` uses
-- `show` on a `BinaryOp` when printing `BinaryElementwise`.
92 | %runElab derive "ComparisonDirection" [Show]
93 | %runElab derive "BinaryOp" [Show]
-- NOTE(review): the `data` header for these constructors is not visible in this
-- listing; presumably they are the constructors of `UnaryOp` — confirm.
-- None of the constructors take arguments.
97 | Not | Neg | Ceil | Floor | Abs | Log | Exp | Logistic | Sqrt | Sin | Cos | Tan | Tanh
98 | | Erf | ErfInv | Square | Asin | Acos | Atan | Sinh | Cosh | Asinh | Acosh | Atanh
-- Derive `Show` via elaborator reflection; `showOp` uses `show` on a `UnaryOp`
-- when printing `UnaryElementwise`.
100 | %runElab derive "UnaryOp" [Show]
-- The operations a `Value` can carry. Operands are passed as `Value`s and
-- nested computations as `Fn`s.
-- NOTE(review): some lines of this declaration are not visible in this listing
-- (see the note before the `lBatch` line below).
103 | data Op : Type where
-- Refers to a binding by its `Nat` tag; `tagOp` returns `BoundSet next` after
-- storing an op under `next`.
104 | BoundSet : Nat -> Op
-- A literal whose element type is determined by `dtype` via `idrisType`.
106 | Lit : (shape : Shape) -> (dtype : DType) -> Literal shape (idrisType dtype) -> Op
107 | Grad : Shape -> Fn 1 -> Value -> Op
108 | MinValue : DType -> Op
109 | MaxValue : DType -> Op
110 | MinFiniteFloat : Op
111 | MaxFiniteFloat : Op
-- Argument order is shape, then dtype, then axis.
112 | Iota : (shape : Shape) -> DType -> (axis : Nat) -> Op
113 | BitCastConvert : DType -> Shape -> Value -> Op
114 | Convert : DType -> Shape -> Value -> Op
115 | Reshape : DType -> Shape -> Value -> Op
116 | Slice : (starts, stops, strides : List Nat) -> Value -> Op
117 | DynamicSlice : (starts : List Value) -> (sizes : List Nat) -> Value -> Op
-- The operand vector has length `S n`, so it is never empty.
118 | Concat : (axis : Nat) -> Vect (S n) Value -> Op
119 | Transpose : (ordering : List Nat) -> Value -> Op
120 | Broadcast : DType -> (from, to : Shape) -> Value -> Op
-- The mapped function's arity equals the number of operands.
121 | Map : Fn arity -> Vect arity Value -> (resultType : ValueType) -> Shape -> Op
-- The reducing function takes `n + n` parameters for `n` initial values and
-- `n` operands.
122 | Reduce : Fn (n + n) -> (inits : Vect n Value) -> (axes : List Nat) -> Vect n Value -> Op
123 | Sort : Fn 2 -> (axis : Nat) -> (isStable : Bool) -> Value -> Op
124 | Reverse : (axes : List Nat) -> Value -> Op
125 | BinaryElementwise : BinaryOp -> Value -> Value -> Op
126 | UnaryElementwise : UnaryOp -> Value -> Op
127 | Select : (predicate, onTrue, onFalse : Value) -> Op
-- Condition and body have the same arity `n` as the number of initial values.
128 | While : (condition, body : Fn n) -> (init : Vect n Value) -> Op
-- Both branches are functions of no parameters.
129 | If : (resultType : ValueType) -> (predicate : Value) -> (onTrue, onFalse : Fn 0) -> Op
-- NOTE(review): the constructor name line and the trailing argument lines for
-- the two lines below are not visible in this listing. `showOp` matches a
-- `DotGeneral` constructor with four `List Nat` arguments, one ignored argument
-- and two values, so these presumably belong to `DotGeneral` — confirm.
131 | (lBatch, lContract, rBatch, rContract: List Nat) ->
132 | (resultType : ValueType) ->
136 | Cholesky : Value -> Op
137 | TriangularSolve : Value -> Value -> (isLower : Bool) -> Op
138 | Rng : (state : Value) -> (resultType : ValueType) -> Op
-- Stores an op in the environment held in state and returns a reference to it.
-- NOTE(review): the clause head (which binds `expr`) is not visible in this
-- listing.
141 | tagOp : Monad m => Op -> StateT Env m Op
-- Read the current counter `next` and bindings `env`.
143 | MkEnv next env <- get
-- Bind `expr` to `next` at the head of the bindings and advance the counter.
144 | put $
MkEnv (S next) ((next, expr) :: env)
-- Hand back a reference to the tag that was just bound.
145 | pure (BoundSet next)
-- Advances the environment's counter by one without adding a binding.
-- NOTE(review): the clause head and the final statement are not visible in this
-- listing. Given the `Nat` result, it presumably returns the pre-increment
-- counter `next` — confirm against the full file.
148 | reserve : State Env Nat
150 | MkEnv next env <- get
151 | put $
MkEnv (S next) env
-- Forward declaration of the `Op` printer. Its clauses appear after
-- `showValue`, `showValueList`, `showEnv` and `showFn`, which call it. The
-- `Nat` argument is passed on to those helpers as an indentation amount.
155 | showOp : Nat -> Op -> String
-- Renders a value as its op in parentheses, followed by a colon and the
-- value's `Nat` component.
showValue : Nat -> Value -> String
showValue indent value =
  let V idx op = value
   in "(\{showOp indent op}):\{show idx}"
-- Renders a list of values as a bracketed, comma-separated list, printing each
-- element with `showValue` at the given indentation.
showValueList : Nat -> List Value -> String
showValueList indent xs = "[" ++ body ++ "]"
  where
    body : String
    body = joinBy ", " (toList (map (showValue indent) xs))
-- Renders the bindings of an environment, one per line. The stored list is
-- reversed first, so with `tagOp` prepending new bindings the output lists them
-- in insertion order.
166 | showEnv : Nat -> Env -> String
167 | showEnv indent (MkEnv max env) = joinBy "\n" $
assert_total $
map fmt (reverse env)
-- NOTE(review): the `where` line and the clause head of `fmt` (which binds `n`
-- and `x`) are not visible in this listing.
-- Formats a single binding: indentation, the tag `n`, a separator, then the op.
171 | fmt : (Nat, Op) -> String
-- The separator width depends on the printed widths of `max` and `n`;
-- presumably this pads shorter tags so the ops line up — confirm.
173 | let sep = replicate (4 + length (show max) `minus` length (show n)) ' '
174 | in "\{replicate indent ' '}\{show n}\{sep}\{showOp indent x}"
-- Renders a function as its parameter tag and parameter types, an arrow, and
-- its results; when the function's environment has bindings they are appended
-- in a braced `with vars` section.
-- NOTE(review): the `case` scrutinee line and the pattern of the second
-- alternative are not visible in this listing; the `[]` alternative presumably
-- matches on the bindings `env'` — confirm.
177 | showFn : Nat -> Fn arity -> String
178 | showFn indent (MkFn parameterSetTag paramTypes resultTypes results env@(MkEnv _ env')) =
179 | let params = "\{show parameterSetTag} \{show paramTypes}"
-- Results are printed two columns deeper than the function itself.
180 | res = "\{showValueList (indent + 2) $ toList results}" in
-- No bindings: print only the parameters and results.
182 | [] => "\{params} => \{res}"
-- Otherwise also print the environment, four columns deeper, and close the
-- brace two columns deeper than the function.
184 | "\{params} => \{res} with vars {\n\{showEnv (indent + 4) env}\n\{replicate (indent + 2) ' '}}"
-- `Show` for functions: prints via `showFn` starting at indentation zero.
-- The call is wrapped in `assert_total`, so the totality checker accepts it
-- without inspecting `showFn`.
export
Show (Fn arity) where
  show fn = assert_total (showFn 0 fn)
-- Clauses of the `Op` printer, one per constructor. The first argument is the
-- current indentation; it is forwarded to `showValue`, `showValueList` and
-- `showFn` for nested operands and functions. Attributes of an op are printed
-- in braces, followed by its operands.
showOp indent (Lit shape dtype x) = "Lit \{shape} \{show dtype}"
showOp indent (BoundSet k) = "Bound \{k}"
showOp indent (Grad _ op x) = "Grad {op = \{showFn indent op}} \{showValue indent x}"
showOp indent (MinValue dtype) = "MinValue \{show dtype}"
showOp indent (MaxValue dtype) = "MaxValue \{show dtype}"
showOp indent MinFiniteFloat = "MinFiniteFloat"
showOp indent MaxFiniteFloat = "MaxFiniteFloat"
-- The pattern variables follow the constructor's argument order (shape first,
-- then dtype) so that each label is printed next to the matching value.
showOp indent (Iota shape dtype axis) =
  "Iota {shape = \{show shape}, dtype = \{show dtype}, axis = \{axis}}"
showOp indent (Convert dtype shape x) =
  "Convert {dtype = \{show dtype}} \{showValue indent x}"
showOp indent (BitCastConvert dtype shape x) =
  "BitCastConvert {dtype = \{show dtype}} \{showValue indent x}"
showOp indent (Reshape _ to x) = "Reshape {to = \{to}} \{showValue indent x}"
showOp indent (Slice starts stops strides x) =
  "Slice {starts = \{starts}, stops = \{stops}, strides = \{strides}} \{showValue indent x}"
showOp indent (DynamicSlice starts sizes x) =
  "DynamicSlice {starts = \{showValueList indent starts}, sizes = \{sizes}} \{showValue indent x}"
showOp indent (Concat axis xs) = "Concat {axis = \{axis}} \{showValueList indent $ toList xs}"
showOp indent (Transpose ordering x) = "Transpose {ordering = \{ordering}} \{showValue indent x}"
showOp indent (Broadcast _ from to x) =
  "Broadcast {from = \{from}, to = \{to}} \{showValue indent x}"
showOp indent (Map f xs _ _) = "Map {f = \{showFn indent f}} \{showValueList indent $ toList xs}"
showOp indent (Reduce op neutrals axes xs) =
  "Reduce {op = \{showFn indent op}, inits = \{showValueList indent $ toList neutrals}," ++
  " axes = \{axes}} \{showValueList indent $ toList xs}"
-- The stability flag of `Sort` is not printed.
showOp indent (Sort f axis _ xs) =
  "Sort {f = \{showFn indent f}, axis = \{axis}} \{showValue indent xs}"
showOp indent (Reverse axes x) = "Reverse \{axes} \{showValue indent x}"
showOp indent (BinaryElementwise op x y) = "\{show op} \{showValue indent x} \{showValue indent y}"
showOp indent (UnaryElementwise op x) = "\{show op} \{showValue indent x}"
showOp indent (Select p t f) =
  "Select {predicate = \{showValue indent p}, onTrue = \{showValue indent t}," ++
  " onFalse = \{showValue indent f}}"
showOp indent (While c b is) =
  "While {condition = \{showFn indent c}, body = \{showFn indent b}," ++
  " initials = \{showValueList indent $ toList is}}"
showOp indent (If _ p ft ff) =
  "If {predicate = \{showValue indent p}, onTrue = \{showFn indent ft}," ++
  " onFalse = \{showFn indent ff}}"
showOp indent (DotGeneral lBatch lContract rBatch rContract _ x y) =
  "DotGeneral {lBatch = \{lBatch}, lContract = \{lContract}," ++
  " rBatch = \{rBatch}, rContract = \{rContract}} \{showValue indent x} \{showValue indent y}"
showOp indent (Cholesky x) = "Cholesky \{showValue indent x}"
showOp indent (TriangularSolve x y isLower) =
  "TriangularSolve {isLower = \{show isLower}} \{showValue indent x} \{showValue indent y}"
showOp indent (Rng state shape) = "Rng {state = \{showValue indent state}, shape = \{show shape}}"
-- `Show` for ops: prints via `showOp` starting at indentation zero. The call is
-- wrapped in `assert_total`, so the totality checker accepts it without
-- inspecting `showOp`.
export
Show Op where
  show op = assert_total (showOp 0 op)
-- NOTE(review): the implementation header for this clause is not visible in
-- this listing; since it matches on `V`, it presumably belongs to a `Show`
-- implementation for `Value` — confirm.
-- Prints `V`, then the value's `Nat` component, then its op via `show`.
240 | show (V pos x) = "V \{show pos} \{show x}"