17 | module Compiler.Eval
21 | import Compiler.Enzyme.MLIR.Dialect.Ops
22 | import Compiler.LLVM.ADT.APFloat
23 | import Compiler.LLVM.ADT.APInt
24 | import Compiler.LLVM.Support.RawOStream
25 | import Compiler.MLIR.Dialect.Func.IR.FuncOps
26 | import Compiler.MLIR.IR
27 | import Compiler.MLIR.Pass.PassManager
28 | import Compiler.Stablehlo.Dialect.ChloOps
29 | import Compiler.Stablehlo.Dialect.Serialization
30 | import Compiler.Stablehlo.Dialect.StablehloAttrs
31 | import Compiler.Stablehlo.Dialect.StablehloEnums
32 | import Compiler.Stablehlo.Dialect.StablehloOps
33 | import Compiler.Stablehlo.Dialect.Version
34 | import Compiler.Xla.Client.ExecutableBuildOptions
35 | import Compiler.Xla.HLO.Translate.HloToMhlo.HloUtils
36 | import Compiler.Xla.PJRT.C.PjrtCApi
37 | import Compiler.Xla.PJRT.PjrtExecutable
38 | import Compiler.Xla.Shape
39 | import Compiler.Xla.ShapeUtil
42 | import Compiler.LiteralRW
50 | = OutOfBounds Nat Nat
53 | | MlirPassError String
54 | | InvalidHloError String
-- What a node of the expression graph is bound to once it has been emitted.
56 | data BoundSet : Type where
-- The node is backed by a block; `ivalue` reads its values with `getArgument`.
57 | Parameters : Block -> BoundSet
-- The node is backed by a value that has an `Op` implementation; `ivalue` reads its
-- values with `getOperation` followed by `getOpResult`.
58 | OpLike : {auto iface : Op a} -> a -> BoundSet
-- One message per error constructor; `PjrtErr` reuses the wrapped error's own `show`.
62 | show (OutOfBounds idx size) = "Index \{show idx} is out of bounds for array of size \{show size}"
63 | show (ValueNotFound idx) = "Value not found at index \{show idx}"
64 | show (PjrtErr err) = show err
65 | show (MlirPassError err) = "MlirPassError: \{err}"
66 | show (InvalidHloError err) = "InvalidHloError: \{err}"
-- `IO` computations that can stop early with an `Err`.
69 | ErrIO : Type -> Type
70 | ErrIO = EitherT Err IO
-- Store `x` in slot `idx` of `cache`.
--
-- When the array rejects the write, fail with `OutOfBounds`, reporting `idx` together with
-- `max cache`.
set : IOArray a -> Nat -> a -> ErrIO ()
set cache idx x = do
  stored <- writeArray cache (cast idx) x
  if stored
    then pure ()
    else throwE $ OutOfBounds idx (cast (max cache))
-- Read slot `idx` of `cache`, returning the stored value when there is one.
77 | get : IOArray a -> Nat -> ErrIO a
79 | Nothing <- readArray cache (cast idx) | Just x => right x
80 | let max = cast (max cache)
-- An empty read is reported as `OutOfBounds` when `idx` is at or beyond `max cache`,
-- and as `ValueNotFound` otherwise.
81 | left $
if idx >= max then OutOfBounds idx max else ValueNotFound idx
-- Build the ranked tensor type for a `TensorType`, from its shape and the MLIR type of
-- its element dtype.
itype : MLIRContext -> ValueType -> ErrIO Type_
itype ctx (TensorType shape dtype) = do
  elemTy <- mlirType ctx {dtype}
  cast <$> RankedTensorType.get shape elemTy
-- Convert every entry with `itype`, in order, and pack the results into a `TypeRange`.
itypes : MLIRContext -> Vect n ValueType -> ErrIO TypeRange
itypes ctx types = do
  lowered <- traverse (itype ctx) (toList types)
  mkTypeRange lowered
-- Shape of the action `interpret` runs last: it receives the builder, the location and the
-- range of result values. Declared at quantity 0, so it is only usable in types.
89 | 0 Finalizer : Type -> Type
90 | Finalizer a = OpBuilder -> Location -> ValueRange -> ErrIO a
94 | IOArray BoundSet -> MLIRContext -> Location -> Block -> Fn arity -> Finalizer a -> ErrIO ()
-- Emit `f` into `block`: register the block, add its arguments, emit every environment
-- entry, then hand the result values to `finalizer`.
95 | interpret cache ctx uloc block f finalizer = do
96 | builder <- atBlockEnd block
-- Cache the block under the function's tag, so `get cache f.tag` returns `Parameters block`.
97 | set cache f.tag (Parameters block)
-- One block argument per parameter type; the values returned by `addArgument` are discarded.
98 | ignore $
for f.paramTypes $
\t => addArgument block !(itype ctx t) uloc
-- Emit each environment entry with `iop` and cache the result under the entry's index.
99 | traverse_ (\(i, expr) => set cache i !(iop expr)) (toList f.env)
100 | results <- traverse ivalue f.results
-- Whatever the finalizer returns is discarded.
101 | ignore $
finalizer builder uloc !(mkValueRange $
toList results)
105 | iop : {auto builder : OpBuilder} -> Op -> ErrIO BoundSet
-- Resolve value number `pos` of `op`: a block argument when `op` is bound to `Parameters`,
-- otherwise result `pos` of the emitted operation.
ivalue : {auto builder : OpBuilder} -> IR.Value -> ErrIO Value.Value
ivalue (V pos op) = do
  bound <- iop op
  case bound of
    Parameters block => cast <$> getArgument block pos
    OpLike x {iface} => do
      operation <- getOperation x
      cast <$> getOpResult operation pos
112 | iop (BoundSet x) = get cache x
113 | iop (Grad shape f x) = do
114 | revInit <- ivalue $
V 0 $
Lit {shape = [], dtype = F64} 1.0
115 | args <- mkValueRange [!(ivalue x), cast revInit]
116 | retTys <- mkTypeRange [!(itype ctx $
TensorType shape F64)]
117 | op <- AutoDiffRegionOp.create builder ctx uloc retTys args EnzymeActive EnzymeActivenoneed
118 | body <- emplaceBlock $
getBody op
119 | interpret cache ctx uloc body f YieldOp.create
121 | iop (MinValue dtype) = do
123 | if isSigned {dtype}
124 | then getSignedMinValue (numBits {dtype})
125 | else getMinValue (numBits {dtype})
126 | type <- cast <$> RankedTensorType.get [] !(mlirType ctx {dtype})
127 | attr <- APInt.get type apInt
128 | OpLike <$> ConstantOp.create builder uloc attr
129 | iop (MaxValue dtype) = do
131 | if isSigned {dtype}
132 | then getSignedMaxValue (numBits {dtype})
133 | else getMaxValue (numBits {dtype})
134 | type <- cast <$> RankedTensorType.get [] !(mlirType ctx {dtype})
135 | attr <- APInt.get type apInt
136 | OpLike <$> ConstantOp.create builder uloc attr
-- Rank-0 `F64` constant whose value is `getLargest True`.
iop MinFiniteFloat = do
  elemTy <- mlirType ctx {dtype = F64}
  scalarTy <- cast <$> RankedTensorType.get [] elemTy
  extreme <- getLargest True
  attr <- APFloat.get scalarTy extreme
  OpLike <$> ConstantOp.create builder uloc attr
-- Rank-0 `F64` constant whose value is `getLargest False`.
iop MaxFiniteFloat = do
  elemTy <- mlirType ctx {dtype = F64}
  scalarTy <- cast <$> RankedTensorType.get [] elemTy
  extreme <- getLargest False
  attr <- APFloat.get scalarTy extreme
  OpLike <$> ConstantOp.create builder uloc attr
-- Write the literal out, wrap it in a dense-elements attribute and emit it as a constant.
iop (Lit {shape, dtype} lit) = do
  written <- write {dtype} lit
  attr <- createDenseElementsAttrFromLiteral written builder
  OpLike <$> ConstantOp.create builder uloc attr
148 | iop (Broadcast {dtype} from to x) =
149 | if elem 0 to && from /= to
151 | shape <- mkShape {dtype} to
152 | literal <- allocLiteral shape
153 | attr <- createDenseElementsAttrFromLiteral literal builder
154 | OpLike <$> ConstantOp.create builder uloc attr
156 | let broadcastDims = Prelude.map (+ length to `minus` length from) $
List.range $
length from
158 | resTy <- itype ctx $
TensorType to dtype
159 | OpLike <$> BroadcastInDimOp.create builder uloc resTy !(ivalue x) broadcastDims
160 | iop (UnaryElementwise f x) = do
161 | let W
mkop @{iface} = f.create
162 | OpLike <$> mkop builder uloc !(ivalue x)
166 | data Wrap : Type where
167 | W : forall a . (OpBuilder -> Location -> Value.Value -> ErrIO a) -> Op a => Wrap
169 | (.create) : UnaryOp -> Wrap
171 | Abs => W AbsOp.create
172 | Ceil => W CeilOp.create
173 | Cos => W CosineOp.create
174 | Exp => W ExpOp.create
175 | Floor => W FloorOp.create
176 | Log => W LogOp.create
177 | Logistic => W LogisticOp.create
178 | Not => W NotOp.create
179 | Neg => W NegOp.create
180 | Sin => W SineOp.create
181 | Sqrt => W SqrtOp.create
182 | Tan => W TanOp.create
183 | Tanh => W TanhOp.create
185 | Acos => W AcosOp.create
186 | Acosh => W AcoshOp.create
187 | Asin => W AsinOp.create
188 | Asinh => W AsinhOp.create
189 | Atan => W AtanOp.create
190 | Atanh => W AtanhOp.create
191 | Cosh => W CoshOp.create
192 | Sinh => W SinhOp.create
193 | Erf => W ErfOp.create
194 | ErfInv => W ErfInvOp.create
195 | Square => W SquareOp.create
-- Emit a `ConvertOp` producing a ranked tensor of `resultShape` with element type `dtype`.
iop (Convert {dtype} resultShape x) = do
  elemTy <- mlirType ctx {dtype}
  resultType <- cast <$> RankedTensorType.get resultShape elemTy
  operand <- ivalue x
  OpLike <$> ConvertOp.create builder uloc resultType operand
-- NOTE(review): this clause emits `ConvertOp`, exactly like the `Convert` clause above it.
-- StableHLO has a separate `bitcast_convert` operation for reinterpreting bits, so this is
-- presumably a placeholder or a copy-paste slip — confirm against the available op bindings.
199 | iop (BitCastConvert {dtype} resultShape x) = do
200 | resultType <- cast <$> RankedTensorType.get resultShape !(mlirType ctx {dtype})
201 | OpLike <$> ConvertOp.create builder uloc resultType !(ivalue x)
202 | iop (BinaryElementwise f lhs rhs) = do
203 | let W
mkop @{iface} = f.create
204 | OpLike <$> mkop builder uloc !(ivalue lhs) !(ivalue rhs)
208 | data Wrap : Type where
209 | W : forall a . (OpBuilder -> Location -> Value.Value -> Value.Value -> ErrIO a) -> Op a => Wrap
211 | (.create) : BinaryOp -> Wrap
213 | Compare direction => W (\b, l, x, y => CompareOp.create b l x y (cast direction))
214 | Add => W AddOp.create
215 | Div => W DivOp.create
216 | Max => W MaxOp.create
217 | Min => W MinOp.create
218 | Mul => W MulOp.create
219 | Pow => W PowOp.create
220 | Rem => W RemOp.create
221 | Sub => W SubtractOp.create
222 | And => W AndOp.create
223 | Or => W OrOp.create
224 | ShiftRightLogical => W ShiftRightLogicalOp.create
225 | iop (If resTy pred true false) = do
226 | op <- IfOp.create builder uloc !(itype ctx resTy) !(ivalue pred)
227 | bodyT <- emplaceBlock $
getTrueBranch op
228 | bodyF <- emplaceBlock $
getFalseBranch op
229 | interpret cache ctx uloc bodyT true StablehloOps.ReturnOp.create
230 | interpret cache ctx uloc bodyF false StablehloOps.ReturnOp.create
232 | iop (While cond body inits) = do
233 | inits <- mkValueRange =<< traverse ivalue (toList inits)
234 | op <- WhileOp.create builder uloc inits
235 | cond' <- emplaceBlock $
getCond op
236 | interpret cache ctx uloc cond' cond StablehloOps.ReturnOp.create
237 | body' <- emplaceBlock $
getBody op
238 | interpret cache ctx uloc body' body StablehloOps.ReturnOp.create
240 | iop (Reduce body inits axes xs) = do
241 | inits <- mkValueRange !(traverse ivalue $
toList inits)
242 | xs <- mkValueRange !(traverse ivalue $
toList xs)
243 | op <- ReduceOp.create builder uloc xs inits axes
244 | body' <- emplaceBlock $
getBody op
245 | interpret cache ctx uloc body' body StablehloOps.ReturnOp.create
-- Emit a `SliceOp` over the operand with the given starts, stops and strides.
iop (Slice starts stops strides x) = do
  operand <- ivalue x
  OpLike <$> SliceOp.create builder uloc operand starts stops strides
-- Emit a `DynamicSliceOp`; the start indices are themselves values, resolved before the operand.
iop (DynamicSlice starts sizes x) = do
  startVals <- traverse ivalue starts
  startRange <- mkValueRange startVals
  operand <- ivalue x
  OpLike <$> DynamicSliceOp.create builder uloc operand startRange sizes
-- Emit a `CholeskyOp` on the operand, passing `True` as its final argument.
iop (Cholesky x) = do
  operand <- ivalue x
  OpLike <$> CholeskyOp.create builder uloc operand True
-- Resolve every operand in order and emit a `ConcatenateOp` along `axis`.
iop (Concat axis xs) = do
  operands <- traverse ivalue (toList xs)
  operandRange <- mkValueRange operands
  OpLike <$> ConcatenateOp.create builder uloc operandRange axis
-- Emit an `IotaOp` whose result is a ranked tensor of `shape` with element type `dtype`.
iop (Iota {dtype} shape dim) = do
  elemTy <- mlirType ctx {dtype}
  resultType <- cast <$> RankedTensorType.get shape elemTy
  OpLike <$> IotaOp.create builder uloc resultType dim
-- Emit a `DotGeneralOp`: build the dimension-numbers attribute and result type first, then
-- resolve the left and right operands in that order.
iop (DotGeneral lb rb lc rc resultType lhs rhs) = do
  dimNumbers <- DotDimensionNumbersAttr.get ctx lb rb lc rc
  outTy <- itype ctx resultType
  lhsVal <- ivalue lhs
  rhsVal <- ivalue rhs
  OpLike <$> DotGeneralOp.create builder uloc outTy lhsVal rhsVal dimNumbers
263 | iop (Map f xs resTy dims) = do
264 | xs <- mkValueRange =<< traverse ivalue (toList xs)
265 | op <- MapOp.create builder uloc !(itype ctx resTy) xs dims
266 | f' <- emplaceBlock $
getComputation op
267 | interpret cache ctx uloc f' f StablehloOps.ReturnOp.create
-- Emit a `ReshapeOp` whose result is a ranked tensor of shape `to` with element type `dtype`.
iop (Reshape {dtype} to x) = do
  elemTy <- mlirType ctx {dtype}
  resultType <- cast <$> RankedTensorType.get to elemTy
  operand <- ivalue x
  OpLike <$> ReshapeOp.create builder uloc resultType operand
-- Emit a `SelectOp`; the predicate and both branches are resolved left to right.
iop (Select pred true false) = do
  predVal <- ivalue pred
  onTrue <- ivalue true
  onFalse <- ivalue false
  OpLike <$> SelectOp.create builder uloc predVal onTrue onFalse
274 | iop (Sort comp axis isStable x) = do
275 | op <- SortOp.create builder uloc !(ivalue x) axis isStable
276 | comp' <- emplaceBlock $
getComparator op
277 | comparator <- interpret cache ctx uloc comp' comp StablehloOps.ReturnOp.create
-- Emit a `ReverseOp` on the operand over `axes`.
iop (Reverse axes x) = do
  operand <- ivalue x
  OpLike <$> ReverseOp.create builder uloc operand axes
-- Emit a `TransposeOp` on the operand with the given `ordering`.
iop (Transpose ordering x) = do
  operand <- ivalue x
  OpLike <$> TransposeOp.create builder uloc operand ordering
-- Emit a `TriangularSolveOp` on `a` and `b`; `lower` is forwarded, and the remaining
-- arguments are fixed at `True`, `False` and `NoTranspose`.
iop (TriangularSolve a b lower) = do
  lhs <- ivalue a
  rhs <- ivalue b
  OpLike <$> TriangularSolveOp.create builder uloc lhs rhs True lower False NoTranspose
-- Emit a `RngBitGeneratorOp` with the `ThreeFry` algorithm; the state type is a `[2]`
-- tensor of `U64`.
iop (Rng state resultType) = do
  stateTy <- itype ctx (TensorType [2] U64)
  outTy <- itype ctx resultType
  stateVal <- ivalue state
  OpLike <$> RngBitGeneratorOp.create builder uloc stateTy outTy ThreeFry stateVal
291 | execute : Device -> Fn 0 -> {outputs : _} -> Vect outputs Xla.Shape -> ErrIO $
Vect outputs Literal
292 | execute (MkDevice mlirCtx passManager api client) f shapes = do
293 | uloc <- UnknownLoc.get mlirCtx
294 | cache <- newArray $
cast $
counter f.env
295 | moduleOp <- ModuleOp.create uloc "root"
296 | fnType <- FunctionType.get mlirCtx !(itypes mlirCtx f.paramTypes) !(itypes mlirCtx f.resultTypes)
297 | fn <- FuncOp.create uloc "main" fnType
298 | interpret cache mlirCtx uloc !(addEntryBlock fn) f FuncOps.ReturnOp.create
299 | pushBack moduleOp !(getOperation fn)
301 | True <- run passManager !(getOperation moduleOp)
302 | | _ => throwE $
MlirPassError "Failed to run MLIR passes"
304 | version <- toString !getCurrentVersion
305 | True <- serializePortableArtifact moduleOp version !(rawStringOStream code)
306 | | False => throwE $
InvalidHloError "Failed to serialize MLIR for version \{c_str version}"
307 | bimapEitherT PjrtErr id $
do
308 | executableBuildOptions <- mkExecutableBuildOptions
309 | compileOptions <- serializeAsString !(mkCompileOptions executableBuildOptions)
310 | program <- mkPjrtProgram code
311 | loadedExec <- pjrtClientCompile api client program compileOptions
314 | delete compileOptions
315 | delete executableBuildOptions
317 | buffers <- pjrtLoadedExecutableExecute api loadedExec outputs
318 | pjrtLoadedExecutableDestroy api loadedExec
320 | for (zip buffers shapes) $
\(buffer, shape) => do
321 | literal <- allocLiteral shape
326 | event <- pjrtBufferToHostBuffer api buffer literal
327 | pjrtEventAwait api event
328 | pjrtEventDestroy api event
329 | pjrtBufferDestroy api buffer