17 | module Compiler.Eval
21 | import Compiler.Enzyme.MLIR.Dialect.Ops
22 | import Compiler.LLVM.ADT.APFloat
23 | import Compiler.LLVM.ADT.APInt
24 | import Compiler.LLVM.Support.RawOStream
25 | import Compiler.MLIR.Dialect.Func.IR.FuncOps
26 | import Compiler.MLIR.IR
27 | import Compiler.MLIR.Pass.PassManager
28 | import Compiler.Stablehlo.Dialect.ChloOps
29 | import Compiler.Stablehlo.Dialect.Serialization
30 | import Compiler.Stablehlo.Dialect.StablehloAttrs
31 | import Compiler.Stablehlo.Dialect.StablehloEnums
32 | import Compiler.Stablehlo.Dialect.StablehloOps
33 | import Compiler.Stablehlo.Dialect.Version
34 | import Compiler.Xla.Client.ExecutableBuildOptions
35 | import Compiler.Xla.HLO.Translate.HloToMhlo.HloUtils
36 | import Compiler.Xla.PJRT.C.PjrtCApi
37 | import Compiler.Xla.PJRT.PjrtExecutable
38 | import Compiler.Xla.Shape
39 | import Compiler.Xla.ShapeUtil
40 | import Compiler.Xla.XlaData
41 | import Compiler.DType
44 | import Compiler.LiteralRW
52 | = OutOfBounds Nat Nat
55 | | MlirPassError String
56 | | InvalidHloError String
||| What a cached IR node has been lowered to: either the block whose
||| arguments stand for a function's parameters, or an MLIR operation whose
||| results stand for the node's values.
data BoundSet : Type where
  ||| A lowered function's block; value `pos` is its `pos`-th argument.
  Parameters : (params : Block) -> BoundSet
  ||| An MLIR operation, packaged with its `Op` implementation so its
  ||| results can be read back later.
  OpLike : {auto iface : Op a} -> (op : a) -> BoundSet
64 | show (OutOfBounds idx size) = "Index \{show idx} is out of bounds for array of size \{show size}"
65 | show (ValueNotFound idx) = "Value not found at index \{show idx}"
66 | show (PjrtErr err) = show err
67 | show (MlirPassError err) = "MlirPassError: \{err}"
68 | show (InvalidHloError err) = "InvalidHloError: \{err}"
-- The monad used throughout evaluation: `IO` extended with short-circuiting
-- failure that carries an `Err` (via the `EitherT` transformer).
-- Written point-free, so `ErrIO` itself has kind `Type -> Type`.
71 | ErrIO : Type -> Type
72 | ErrIO = EitherT Err IO
||| Store `x` at position `idx` of `cache`.
|||
||| If `writeArray` reports failure (returns `False`), this fails with
||| `OutOfBounds`, carrying `idx` and the array's size (`max cache`).
set : IOArray a -> Nat -> a -> ErrIO ()
set cache idx x = do
  written <- writeArray cache (cast idx) x
  if written
    then right ()
    else left (OutOfBounds idx (cast (max cache)))
-- Read the value stored at position `idx` of `cache`.
-- Fails with `OutOfBounds idx size` when `idx` is at least the array's size
-- (`max cache`), and with `ValueNotFound idx` when `idx` is in range but the
-- slot holds nothing.
-- NOTE(review): the clause head (listing line 80) is missing from this
-- excerpt; the body refers to `cache` and `idx`, presumably bound there.
79 | get : IOArray a -> Nat -> ErrIO a
81 | Nothing <- readArray cache (cast idx) | Just x => right x
-- NOTE(review): this local `max` shadows the `max` it is computed from;
-- a name such as `size` would read more clearly.
82 | let max = cast (max cache)
83 | left $
if idx >= max then OutOfBounds idx max else ValueNotFound idx
||| Lower an IR `ValueType` to an MLIR ranked tensor type with the same
||| shape and the MLIR element type for its dtype.
itype : MLIRContext -> ValueType -> ErrIO Type_
itype ctx (TensorType shape dtype) = do
  elemTy <- mlirType ctx dtype
  tensorTy <- RankedTensorType.get shape elemTy
  pure (cast tensorTy)
||| Lower each `ValueType`, in order, and collect them into a `TypeRange`.
itypes : MLIRContext -> Vect n ValueType -> ErrIO TypeRange
itypes ctx types = do
  lowered <- for (toList types) (itype ctx)
  mkTypeRange lowered
||| How a lowered block is closed off. It is called with the block's
||| `OpBuilder`, a location and the block's lowered result values.
||| `interpret` passes e.g. `YieldOp.create` or `StablehloOps.ReturnOp.create`
||| here and discards the returned value.
||| Erased (quantity 0), so it can only be used in types.
0 Finalizer : Type -> Type
Finalizer a =
  (builder : OpBuilder) -> (loc : Location) -> (results : ValueRange) -> ErrIO a
96 | IOArray BoundSet -> MLIRContext -> Location -> Block -> Fn arity -> Finalizer a -> ErrIO ()
-- Lower the function `f` into `block`. This proceeds in four steps:
--   1. Register the block under `f.tag`.
--   2. Add one block argument per parameter type.
--   3. Lower every node of `f.env`, caching each result under its index.
--   4. Pass the lowered `f.results` to `finalizer` to close the block.
-- `cache`, `ctx` and `uloc` are also read by the local `iop`/`ivalue`
-- helpers, so these parameter names must not change.
interpret cache ctx uloc block f finalizer = do
  -- `iop`/`ivalue` take an auto-implicit `OpBuilder`; this binding is what
  -- they pick up, so it must stay in scope.
  builder <- atBlockEnd block
  -- References to `f.tag` resolve to this block, whose arguments are the
  -- function's parameters.
  set cache f.tag (Parameters block)
  -- `for_` discards the created arguments directly instead of building a
  -- `Vect` of them only to `ignore` it.
  for_ f.paramTypes $ \t => addArgument block !(itype ctx t) uloc
  traverse_ (\(i, expr) => set cache i !(iop expr)) (toList f.env)
  results <- traverse ivalue f.results
  ignore $ finalizer builder uloc !(mkValueRange (toList results))
-- Lower an IR node to the MLIR entity it is bound to.
iop : {auto builder : OpBuilder} -> Op -> ErrIO BoundSet

-- Lower an IR value: the `pos`-th output of the node `op`.
ivalue : {auto builder : OpBuilder} -> IR.Value -> ErrIO Value.Value
ivalue (V pos op) = do
  bound <- iop op
  case bound of
    -- A function's parameters: take the `pos`-th block argument.
    Parameters block => do
      arg <- getArgument block pos
      pure (cast arg)
    -- An operation: take its `pos`-th result. Binding `iface` brings the
    -- `Op` implementation into scope for `getOperation`.
    OpLike x {iface} => do
      operation <- getOperation x
      res <- getOpResult operation pos
      pure (cast res)
-- A reference to an already-lowered node: look it up in the cache.
iop (BoundSet idx) = get cache idx
115 | iop (Grad shape f x) = do
116 | revInit <- ivalue $
V 0 $
Lit [] F64 1.0
117 | args <- mkValueRange [!(ivalue x), cast revInit]
118 | retTys <- mkTypeRange [!(itype ctx $
TensorType shape F64)]
119 | op <- AutoDiffRegionOp.create builder ctx uloc retTys args EnzymeActive EnzymeActivenoneed
120 | body <- emplaceBlock $
getBody op
121 | interpret cache ctx uloc body f YieldOp.create
123 | iop (MinValue dtype) = do
126 | then getSignedMinValue (numBits dtype)
127 | else getMinValue (numBits dtype)
128 | type <- cast <$> RankedTensorType.get [] !(mlirType ctx dtype)
129 | attr <- APInt.get type apInt
130 | OpLike <$> ConstantOp.create builder uloc attr
131 | iop (MaxValue dtype) = do
134 | then getSignedMaxValue (numBits dtype)
135 | else getMaxValue (numBits dtype)
136 | type <- cast <$> RankedTensorType.get [] !(mlirType ctx dtype)
137 | attr <- APInt.get type apInt
138 | OpLike <$> ConstantOp.create builder uloc attr
-- The most negative finite F64 as a scalar constant. The `True` flag to
-- `getLargest` presumably selects the negated value (cf. LLVM
-- `APFloat::getLargest`'s `Negative` parameter).
iop MinFiniteFloat = do
  f64 <- mlirType ctx F64
  scalarTy <- RankedTensorType.get [] f64
  value <- getLargest True
  attr <- APFloat.get (cast scalarTy) value
  OpLike <$> ConstantOp.create builder uloc attr
-- The largest finite F64 as a scalar constant.
iop MaxFiniteFloat = do
  f64 <- mlirType ctx F64
  scalarTy <- RankedTensorType.get [] f64
  value <- getLargest False
  attr <- APFloat.get (cast scalarTy) value
  OpLike <$> ConstantOp.create builder uloc attr
-- A literal: convert it with `write`, then embed it as a dense constant.
iop (Lit shape dtype lit) = do
  literal <- write dtype lit
  attr <- createDenseElementsAttrFromLiteral literal builder
  OpLike <$> ConstantOp.create builder uloc attr
150 | iop (Broadcast dtype from to x) =
151 | if elem 0 to && from /= to
153 | shape <- mkShape to dtype
154 | literal <- allocLiteral shape
155 | attr <- createDenseElementsAttrFromLiteral literal builder
156 | OpLike <$> ConstantOp.create builder uloc attr
158 | let broadcastDims = Prelude.map (+ length to `minus` length from) $
List.range $
length from
160 | resTy <- itype ctx $
TensorType to dtype
161 | OpLike <$> BroadcastInDimOp.create builder uloc resTy !(ivalue x) broadcastDims
162 | iop (UnaryElementwise f x) = do
163 | let W
mkop @{iface} = f.create
164 | OpLike <$> mkop builder uloc !(ivalue x)
168 | data Wrap : Type where
169 | W : forall a . (OpBuilder -> Location -> Value.Value -> ErrIO a) -> Op a => Wrap
171 | (.create) : UnaryOp -> Wrap
173 | Abs => W AbsOp.create
174 | Ceil => W CeilOp.create
175 | Cos => W CosineOp.create
176 | Exp => W ExpOp.create
177 | Floor => W FloorOp.create
178 | Log => W LogOp.create
179 | Logistic => W LogisticOp.create
180 | Not => W NotOp.create
181 | Neg => W NegOp.create
182 | Sin => W SineOp.create
183 | Sqrt => W SqrtOp.create
184 | Tan => W TanOp.create
185 | Tanh => W TanhOp.create
187 | Acos => W AcosOp.create
188 | Acosh => W AcoshOp.create
189 | Asin => W AsinOp.create
190 | Asinh => W AsinhOp.create
191 | Atan => W AtanOp.create
192 | Atanh => W AtanhOp.create
193 | Cosh => W CoshOp.create
194 | Sinh => W SinhOp.create
195 | Erf => W ErfOp.create
196 | ErfInv => W ErfInvOp.create
197 | Square => W SquareOp.create
-- Convert the elements of `x` to `dtype`, giving a tensor of `resultShape`.
iop (Convert dtype resultShape x) = do
  elemTy <- mlirType ctx dtype
  resultType <- RankedTensorType.get resultShape elemTy
  operand <- ivalue x
  OpLike <$> ConvertOp.create builder uloc (cast resultType) operand
-- Reinterpret the bits of `x` as `dtype`, giving a tensor of `resultShape`.
-- This must lower to StableHLO's `bitcast_convert`. `ConvertOp` (used by the
-- `Convert` clause) performs a numeric value conversion, not a bit
-- reinterpretation, so it would silently compute the wrong result.
-- NOTE(review): assumes the StableHLO bindings expose
-- `BitcastConvertOp.create` with the same shape as `ConvertOp.create`.
iop (BitCastConvert dtype resultShape x) = do
  resultType <- cast <$> RankedTensorType.get resultShape !(mlirType ctx dtype)
  OpLike <$> BitcastConvertOp.create builder uloc resultType !(ivalue x)
204 | iop (BinaryElementwise f lhs rhs) = do
205 | let W
mkop @{iface} = f.create
206 | OpLike <$> mkop builder uloc !(ivalue lhs) !(ivalue rhs)
210 | data Wrap : Type where
211 | W : forall a . (OpBuilder -> Location -> Value.Value -> Value.Value -> ErrIO a) -> Op a => Wrap
213 | (.create) : BinaryOp -> Wrap
215 | Compare direction => W (\b, l, x, y => CompareOp.create b l x y (cast direction))
216 | Add => W AddOp.create
217 | Div => W DivOp.create
218 | Max => W MaxOp.create
219 | Min => W MinOp.create
220 | Mul => W MulOp.create
221 | Pow => W PowOp.create
222 | Rem => W RemOp.create
223 | Sub => W SubtractOp.create
224 | And => W AndOp.create
225 | Or => W OrOp.create
226 | ShiftRightLogical => W ShiftRightLogicalOp.create
227 | iop (If resTy pred true false) = do
228 | op <- IfOp.create builder uloc !(itype ctx resTy) !(ivalue pred)
229 | bodyT <- emplaceBlock $
getTrueBranch op
230 | bodyF <- emplaceBlock $
getFalseBranch op
231 | interpret cache ctx uloc bodyT true StablehloOps.ReturnOp.create
232 | interpret cache ctx uloc bodyF false StablehloOps.ReturnOp.create
234 | iop (While cond body inits) = do
235 | inits <- mkValueRange =<< traverse ivalue (toList inits)
236 | op <- WhileOp.create builder uloc inits
237 | cond' <- emplaceBlock $
getCond op
238 | interpret cache ctx uloc cond' cond StablehloOps.ReturnOp.create
239 | body' <- emplaceBlock $
getBody op
240 | interpret cache ctx uloc body' body StablehloOps.ReturnOp.create
242 | iop (Reduce body inits axes xs) = do
243 | inits <- mkValueRange !(traverse ivalue $
toList inits)
244 | xs <- mkValueRange !(traverse ivalue $
toList xs)
245 | op <- ReduceOp.create builder uloc xs inits axes
246 | body' <- emplaceBlock $
getBody op
247 | interpret cache ctx uloc body' body StablehloOps.ReturnOp.create
-- Static slice of `x` with the given start/stop/stride vectors.
iop (Slice starts stops strides x) = do
  operand <- ivalue x
  OpLike <$> SliceOp.create builder uloc operand starts stops strides
-- Dynamic slice: start indices are runtime values, lowered before `x`.
iop (DynamicSlice starts sizes x) = do
  startVals <- traverse ivalue starts
  startRange <- mkValueRange startVals
  operand <- ivalue x
  OpLike <$> DynamicSliceOp.create builder uloc operand startRange sizes
-- Cholesky factorisation; the `True` flag presumably requests the lower
-- triangular factor (confirm against the binding).
iop (Cholesky x) = do
  operand <- ivalue x
  OpLike <$> CholeskyOp.create builder uloc operand True
-- Concatenate `xs` along `axis`.
iop (Concat axis xs) = do
  operands <- traverse ivalue (toList xs)
  operandRange <- mkValueRange operands
  OpLike <$> ConcatenateOp.create builder uloc operandRange axis
-- A tensor of `shape`/`dtype` that counts along dimension `dim`.
iop (Iota shape dtype dim) = do
  elemTy <- mlirType ctx dtype
  resultType <- RankedTensorType.get shape elemTy
  OpLike <$> IotaOp.create builder uloc (cast resultType) dim
-- General dot product. `lb rb lc rc` are passed straight to
-- `DotDimensionNumbersAttr.get` (presumably lhs/rhs batching and contracting
-- dimensions). `lhs` is lowered before `rhs`.
iop (DotGeneral lb rb lc rc resultType lhs rhs) = do
  dims <- DotDimensionNumbersAttr.get ctx lb rb lc rc
  resTy <- itype ctx resultType
  l <- ivalue lhs
  r <- ivalue rhs
  OpLike <$> DotGeneralOp.create builder uloc resTy l r dims
265 | iop (Map f xs resTy dims) = do
266 | xs <- mkValueRange =<< traverse ivalue (toList xs)
267 | op <- MapOp.create builder uloc !(itype ctx resTy) xs dims
268 | f' <- emplaceBlock $
getComputation op
269 | interpret cache ctx uloc f' f StablehloOps.ReturnOp.create
-- Reshape `x` to the shape `to`, with element type `dtype`.
iop (Reshape dtype to x) = do
  elemTy <- mlirType ctx dtype
  resultType <- RankedTensorType.get to elemTy
  operand <- ivalue x
  OpLike <$> ReshapeOp.create builder uloc (cast resultType) operand
-- Elementwise choice between `true` and `false` according to `pred`.
iop (Select pred true false) = do
  cond <- ivalue pred
  onTrue <- ivalue true
  onFalse <- ivalue false
  OpLike <$> SelectOp.create builder uloc cond onTrue onFalse
276 | iop (Sort comp axis isStable x) = do
277 | op <- SortOp.create builder uloc !(ivalue x) axis isStable
278 | comp' <- emplaceBlock $
getComparator op
279 | comparator <- interpret cache ctx uloc comp' comp StablehloOps.ReturnOp.create
-- Reverse `x` along each of `axes`.
iop (Reverse axes x) = do
  operand <- ivalue x
  OpLike <$> ReverseOp.create builder uloc operand axes
-- Permute the dimensions of `x` according to `ordering`.
iop (Transpose ordering x) = do
  operand <- ivalue x
  OpLike <$> TransposeOp.create builder uloc operand ordering
-- Triangular solve with positional flags `True`, `lower`, `False`,
-- `NoTranspose`. These are presumably left_side, lower, unit_diagonal and
-- transpose_a; confirm against the binding.
iop (TriangularSolve a b lower) = do
  lhs <- ivalue a
  rhs <- ivalue b
  OpLike <$> TriangularSolveOp.create builder uloc lhs rhs True lower False NoTranspose
-- Random bits via ThreeFry. The generator state is a `[2]`-shaped U64 tensor.
iop (Rng state resultType) = do
  stateTy <- itype ctx (TensorType [2] U64)
  outTy <- itype ctx resultType
  st <- ivalue state
  OpLike <$> RngBitGeneratorOp.create builder uloc stateTy outTy ThreeFry st
293 | execute : Device -> Fn 0 -> {outputs : _} -> Vect outputs Xla.Shape -> ErrIO $
Vect outputs Literal
294 | execute (MkDevice mlirCtx passManager api client) f shapes = do
295 | uloc <- UnknownLoc.get mlirCtx
296 | cache <- newArray $
cast $
counter f.env
297 | moduleOp <- ModuleOp.create uloc "root"
298 | fnType <- FunctionType.get mlirCtx !(itypes mlirCtx f.paramTypes) !(itypes mlirCtx f.resultTypes)
299 | fn <- FuncOp.create uloc "main" fnType
300 | interpret cache mlirCtx uloc !(addEntryBlock fn) f FuncOps.ReturnOp.create
301 | pushBack moduleOp !(getOperation fn)
303 | True <- run passManager !(getOperation moduleOp)
304 | | _ => throwE $
MlirPassError "Failed to run MLIR passes"
306 | version <- toString !getCurrentVersion
307 | True <- serializePortableArtifact moduleOp version !(rawStringOStream code)
308 | | False => throwE $
InvalidHloError "Failed to serialize MLIR for version \{c_str version}"
309 | bimapEitherT PjrtErr id $
do
310 | executableBuildOptions <- mkExecutableBuildOptions
311 | compileOptions <- serializeAsString !(mkCompileOptions executableBuildOptions)
312 | program <- mkPjrtProgram code
313 | loadedExec <- pjrtClientCompile api client program compileOptions
316 | delete compileOptions
317 | delete executableBuildOptions
319 | buffers <- pjrtLoadedExecutableExecute api loadedExec outputs
320 | pjrtLoadedExecutableDestroy api loadedExec
322 | for (zip buffers shapes) $
\(buffer, shape) => do
323 | literal <- allocLiteral shape
328 | event <- pjrtBufferToHostBuffer api buffer literal
329 | pjrtEventAwait api event
330 | pjrtEventDestroy api event
331 | pjrtBufferDestroy api buffer