0 | {--
  1 | Copyright (C) 2022  Joel Berkeley
  2 |
  3 | This program is free software: you can redistribute it and/or modify
  4 | it under the terms of the GNU Affero General Public License as published
  5 | by the Free Software Foundation, either version 3 of the License, or
  6 | (at your option) any later version.
  7 |
  8 | This program is distributed in the hope that it will be useful,
  9 | but WITHOUT ANY WARRANTY; without even the implied warranty of
 10 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 11 | GNU Affero General Public License for more details.
 12 |
 13 | You should have received a copy of the GNU Affero General Public License
 14 | along with this program.  If not, see <https://www.gnu.org/licenses/>.
 15 | --}
 16 | ||| For internal spidr use only.
 17 | module Compiler.Eval
 18 |
 19 | import Data.IOArray
 20 |
 21 | import Compiler.Enzyme.MLIR.Dialect.Ops
 22 | import Compiler.LLVM.ADT.APFloat
 23 | import Compiler.LLVM.ADT.APInt
 24 | import Compiler.LLVM.Support.RawOStream
 25 | import Compiler.MLIR.Dialect.Func.IR.FuncOps
 26 | import Compiler.MLIR.IR
 27 | import Compiler.MLIR.Pass.PassManager
 28 | import Compiler.Stablehlo.Dialect.ChloOps
 29 | import Compiler.Stablehlo.Dialect.Serialization
 30 | import Compiler.Stablehlo.Dialect.StablehloAttrs
 31 | import Compiler.Stablehlo.Dialect.StablehloEnums
 32 | import Compiler.Stablehlo.Dialect.StablehloOps
 33 | import Compiler.Stablehlo.Dialect.Version
 34 | import Compiler.Xla.Client.ExecutableBuildOptions
 35 | import Compiler.Xla.HLO.Translate.HloToMhlo.HloUtils
 36 | import Compiler.Xla.PJRT.C.PjrtCApi
 37 | import Compiler.Xla.PJRT.PjrtExecutable
 38 | import Compiler.Xla.Shape
 39 | import Compiler.Xla.ShapeUtil
 40 | import Compiler.IR
 41 | import Compiler.FFI
 42 | import Compiler.LiteralRW
 43 | import Literal
 44 | import Primitive
 45 | import Util
 46 | import Device
 47 |
 48 | export
 49 | data Err
 50 |   = OutOfBounds Nat Nat
 51 |   | ValueNotFound Nat
 52 |   | PjrtErr PjrtError
 53 |   | MlirPassError String
 54 |   | InvalidHloError String
 55 |
 56 | data BoundSet : Type where
 57 |   Parameters : Block -> BoundSet
 58 |   OpLike : {auto iface : Op a} -> a -> BoundSet
 59 |
 60 | export
 61 | Show Err where
 62 |   show (OutOfBounds idx size) = "Index \{show idx} is out of bounds for array of size \{show size}"
 63 |   show (ValueNotFound idx) = "Value not found at index \{show idx}"
 64 |   show (PjrtErr err) = show err
 65 |   show (MlirPassError err) = "MlirPassError: \{err}"
 66 |   show (InvalidHloError err) = "InvalidHloError: \{err}"
 67 |
 68 | public export 0
 69 | ErrIO : Type -> Type
 70 | ErrIO = EitherT Err IO
 71 |
 72 | set : IOArray a -> Nat -> a -> ErrIO ()
 73 | set cache idx x = do
 74 |   False <- writeArray cache (cast idx) x | True => right ()
 75 |   left $ OutOfBounds idx (cast $ max cache)
 76 |
 77 | get : IOArray a -> Nat -> ErrIO a
 78 | get cache idx = do
 79 |   Nothing <- readArray cache (cast idx) | Just x => right x
 80 |   let max = cast (max cache)
 81 |   left $ if idx >= max then OutOfBounds idx max else ValueNotFound idx
 82 |
 83 | itype : MLIRContext -> ValueType -> ErrIO Type_
 84 | itype ctx (TensorType shape dtype) = cast <$> RankedTensorType.get shape !(mlirType ctx {dtype})
 85 |
 86 | itypes : MLIRContext -> Vect n ValueType -> ErrIO TypeRange
 87 | itypes ctx types = mkTypeRange =<< traverse (itype ctx) (toList types)
 88 |
 89 | 0 Finalizer : Type -> Type
 90 | Finalizer a = OpBuilder -> Location -> ValueRange -> ErrIO a
 91 |
 92 | covering
 93 | interpret :
 94 |   IOArray BoundSet -> MLIRContext -> Location -> Block -> Fn arity -> Finalizer a -> ErrIO ()
 95 | interpret cache ctx uloc block f finalizer = do
 96 |   builder <- atBlockEnd block
 97 |   set cache f.tag (Parameters block)
 98 |   ignore $ for f.paramTypes $ \t => addArgument block !(itype ctx t) uloc
 99 |   traverse_ (\(i, expr) => set cache i !(iop expr)) (toList f.env)
100 |   results <- traverse ivalue f.results
101 |   ignore $ finalizer builder uloc !(mkValueRange $ toList results)
102 |
103 |   where
104 |
105 |   iop : {auto builder : OpBuilder} -> Op -> ErrIO BoundSet
106 |
107 |   ivalue : {auto builder : OpBuilder} -> IR.Value -> ErrIO Value.Value
108 |   ivalue (V pos op) = iop op >>= \case
109 |     Parameters block => cast <$> getArgument block pos
110 |     OpLike x {iface} => cast <$> (flip getOpResult pos =<< getOperation x)
111 |
112 |   iop (BoundSet x) = get cache x
113 |   iop (Grad shape f x) = do
114 |     revInit <- ivalue $ V 0 $ Lit {shape = [], dtype = F64} 1.0
115 |     args <- mkValueRange [!(ivalue x), cast revInit]
116 |     retTys <- mkTypeRange [!(itype ctx $ TensorType shape F64)]
117 |     op <- AutoDiffRegionOp.create builder ctx uloc retTys args EnzymeActive EnzymeActivenoneed
118 |     body <- emplaceBlock $ getBody op
119 |     interpret cache ctx uloc body f YieldOp.create
120 |     pure $ OpLike op
121 |   iop (MinValue dtype) = do
122 |     apInt <-
123 |       if isSigned {dtype}
124 |       then getSignedMinValue (numBits {dtype})
125 |       else getMinValue (numBits {dtype})
126 |     type <- cast <$> RankedTensorType.get [] !(mlirType ctx {dtype})
127 |     attr <- APInt.get type apInt
128 |     OpLike <$> ConstantOp.create builder uloc attr
129 |   iop (MaxValue dtype) = do
130 |     apInt <-
131 |       if isSigned {dtype}
132 |       then getSignedMaxValue (numBits {dtype})
133 |       else getMaxValue (numBits {dtype})
134 |     type <- cast <$> RankedTensorType.get [] !(mlirType ctx {dtype})
135 |     attr <- APInt.get type apInt
136 |     OpLike <$> ConstantOp.create builder uloc attr
137 |   iop MinFiniteFloat = do
138 |     type <- cast <$> RankedTensorType.get [] !(mlirType ctx {dtype = F64})
139 |     attr <- APFloat.get type !(getLargest True)
140 |     OpLike <$> ConstantOp.create builder uloc attr
141 |   iop MaxFiniteFloat = do
142 |     type <- cast <$> RankedTensorType.get [] !(mlirType ctx {dtype = F64})
143 |     attr <- APFloat.get type !(getLargest False)
144 |     OpLike <$> ConstantOp.create builder uloc attr
145 |   iop (Lit {shape, dtype} lit) = do
146 |     attr <- createDenseElementsAttrFromLiteral !(write {dtype} lit) builder
147 |     OpLike <$> ConstantOp.create builder uloc attr
148 |   iop (Broadcast {dtype} from to x) =
149 |     if elem 0 to && from /= to
150 |       then do
151 |         shape <- mkShape {dtype} to
152 |         literal <- allocLiteral shape
153 |         attr <- createDenseElementsAttrFromLiteral literal builder
154 |         OpLike <$> ConstantOp.create builder uloc attr
155 |       else
156 |       let broadcastDims = Prelude.map (+ length to `minus` length from) $ List.range $ length from
157 |        in do
158 |         resTy <- itype ctx $ TensorType to dtype
159 |         OpLike <$> BroadcastInDimOp.create builder uloc resTy !(ivalue x) broadcastDims
160 |   iop (UnaryElementwise f x) = do
161 |     let W mkop @{iface} = f.create
162 |     OpLike <$> mkop builder uloc !(ivalue x)
163 |
164 |     where
165 |
166 |     data Wrap : Type where
167 |       W : forall a . (OpBuilder -> Location -> Value.Value -> ErrIO a) -> Op a => Wrap
168 |
169 |     (.create) : UnaryOp -> Wrap
170 |     (.create) = \case
171 |       Abs        => W AbsOp.create
172 |       Ceil       => W CeilOp.create
173 |       Cos        => W CosineOp.create
174 |       Exp        => W ExpOp.create
175 |       Floor      => W FloorOp.create
176 |       Log        => W LogOp.create
177 |       Logistic   => W LogisticOp.create
178 |       Not        => W NotOp.create
179 |       Neg        => W NegOp.create
180 |       Sin        => W SineOp.create
181 |       Sqrt       => W SqrtOp.create
182 |       Tan        => W TanOp.create
183 |       Tanh       => W TanhOp.create
184 |
185 |       Acos       => W AcosOp.create
186 |       Acosh      => W AcoshOp.create
187 |       Asin       => W AsinOp.create
188 |       Asinh      => W AsinhOp.create
189 |       Atan       => W AtanOp.create
190 |       Atanh      => W AtanhOp.create
191 |       Cosh       => W CoshOp.create
192 |       Sinh       => W SinhOp.create
193 |       Erf        => W ErfOp.create
194 |       ErfInv     => W ErfInvOp.create
195 |       Square     => W SquareOp.create
196 |   iop (Convert {dtype} resultShape x) = do
197 |     resultType <- cast <$> RankedTensorType.get resultShape !(mlirType ctx {dtype})
198 |     OpLike <$> ConvertOp.create builder uloc resultType !(ivalue x)
199 |   iop (BitCastConvert {dtype} resultShape x) = do
200 |     resultType <- cast <$> RankedTensorType.get resultShape !(mlirType ctx {dtype})
201 |     OpLike <$> ConvertOp.create builder uloc resultType !(ivalue x)
202 |   iop (BinaryElementwise f lhs rhs) = do
203 |     let W mkop @{iface} = f.create
204 |     OpLike <$> mkop builder uloc !(ivalue lhs) !(ivalue rhs)
205 |
206 |     where
207 |
208 |     data Wrap : Type where
209 |       W : forall a . (OpBuilder -> Location -> Value.Value -> Value.Value -> ErrIO a) -> Op a => Wrap
210 |
211 |     (.create) : BinaryOp -> Wrap
212 |     (.create) = \case
213 |       Compare direction => W (\b, l, x, y => CompareOp.create b l x y (cast direction))
214 |       Add => W AddOp.create
215 |       Div => W DivOp.create
216 |       Max => W MaxOp.create
217 |       Min => W MinOp.create
218 |       Mul => W MulOp.create
219 |       Pow => W PowOp.create
220 |       Rem => W RemOp.create
221 |       Sub => W SubtractOp.create
222 |       And => W AndOp.create
223 |       Or  => W OrOp.create
224 |       ShiftRightLogical => W ShiftRightLogicalOp.create
225 |   iop (If resTy pred true false) = do
226 |     op <- IfOp.create builder uloc !(itype ctx resTy) !(ivalue pred)
227 |     bodyT <- emplaceBlock $ getTrueBranch op
228 |     bodyF <- emplaceBlock $ getFalseBranch op
229 |     interpret cache ctx uloc bodyT true StablehloOps.ReturnOp.create
230 |     interpret cache ctx uloc bodyF false StablehloOps.ReturnOp.create
231 |     pure $ OpLike op
232 |   iop (While cond body inits) = do
233 |     inits <- mkValueRange =<< traverse ivalue (toList inits)
234 |     op <- WhileOp.create builder uloc inits
235 |     cond' <- emplaceBlock $ getCond op
236 |     interpret cache ctx uloc cond' cond StablehloOps.ReturnOp.create
237 |     body' <- emplaceBlock $ getBody op
238 |     interpret cache ctx uloc body' body StablehloOps.ReturnOp.create
239 |     pure $ OpLike op
240 |   iop (Reduce body inits axes xs) = do
241 |     inits <- mkValueRange !(traverse ivalue $ toList inits)
242 |     xs <- mkValueRange !(traverse ivalue $ toList xs)
243 |     op <- ReduceOp.create builder uloc xs inits axes
244 |     body' <- emplaceBlock $ getBody op
245 |     interpret cache ctx uloc body' body StablehloOps.ReturnOp.create
246 |     pure $ OpLike op
247 |   iop (Slice starts stops strides x) = do
248 |     OpLike <$> SliceOp.create builder uloc !(ivalue x) starts stops strides
249 |   iop (DynamicSlice starts sizes x) = do
250 |     starts <- mkValueRange !(traverse ivalue starts)
251 |     OpLike <$> DynamicSliceOp.create builder uloc !(ivalue x) starts sizes
252 |   iop (Cholesky x) = OpLike <$> CholeskyOp.create builder uloc !(ivalue x) True
253 |   iop (Concat axis xs) = do
254 |     xs <- mkValueRange =<< traverse ivalue (toList xs)
255 |     OpLike <$> ConcatenateOp.create builder uloc xs axis
256 |   iop (Iota {dtype} shape dim) = do
257 |     resultType <- cast <$> RankedTensorType.get shape !(mlirType ctx {dtype})
258 |     OpLike <$> IotaOp.create builder uloc resultType dim
259 |   iop (DotGeneral lb rb lc rc resultType lhs rhs) = do
260 |     ddn <- DotDimensionNumbersAttr.get ctx lb rb lc rc
261 |     resTy <- itype ctx resultType
262 |     OpLike <$> DotGeneralOp.create builder uloc resTy !(ivalue lhs) !(ivalue rhs) ddn
263 |   iop (Map f xs resTy dims) = do
264 |     xs <- mkValueRange =<< traverse ivalue (toList xs)
265 |     op <- MapOp.create builder uloc !(itype ctx resTy) xs dims
266 |     f' <- emplaceBlock $ getComputation op
267 |     interpret cache ctx uloc f' f StablehloOps.ReturnOp.create
268 |     pure $ OpLike op
269 |   iop (Reshape {dtype} to x) = do
270 |     resultType <- cast <$> RankedTensorType.get to !(mlirType ctx {dtype})
271 |     OpLike <$> ReshapeOp.create builder uloc resultType !(ivalue x)
272 |   iop (Select pred true false) = do
273 |     OpLike <$> SelectOp.create builder uloc !(ivalue pred) !(ivalue true) !(ivalue false)
274 |   iop (Sort comp axis isStable x) = do
275 |     op <- SortOp.create builder uloc !(ivalue x) axis isStable
276 |     comp' <- emplaceBlock $ getComparator op
277 |     comparator <- interpret cache ctx uloc comp' comp StablehloOps.ReturnOp.create
278 |     pure $ OpLike op
279 |   iop (Reverse axes x) = OpLike <$> ReverseOp.create builder uloc !(ivalue x) axes
280 |   iop (Transpose ordering x) = OpLike <$> TransposeOp.create builder uloc !(ivalue x) ordering
281 |   iop (TriangularSolve a b lower) =
282 |     OpLike <$> TriangularSolveOp.create
283 |       builder uloc !(ivalue a) !(ivalue b) True lower False NoTranspose
284 |   iop (Rng state resultType) = do
285 |     stateTy <- itype ctx $ TensorType [2] U64
286 |     OpLike <$> RngBitGeneratorOp.create
287 |       builder uloc stateTy !(itype ctx resultType) ThreeFry !(ivalue state)
288 |
289 | ||| It is up to the caller to free the `Literal`s.
290 | export covering
291 | execute : Device -> Fn 0 -> {outputs : _} -> Vect outputs Xla.Shape -> ErrIO $ Vect outputs Literal
292 | execute (MkDevice mlirCtx passManager api client) f shapes = do
293 |   uloc <- UnknownLoc.get mlirCtx
294 |   cache <- newArray $ cast $ counter f.env
295 |   moduleOp <- ModuleOp.create uloc "root"
296 |   fnType <- FunctionType.get mlirCtx !(itypes mlirCtx f.paramTypes) !(itypes mlirCtx f.resultTypes)
297 |   fn <- FuncOp.create uloc "main" fnType
298 |   interpret cache mlirCtx uloc !(addEntryBlock fn) f FuncOps.ReturnOp.create
299 |   pushBack moduleOp !(getOperation fn)
300 |
301 |   True <- run passManager !(getOperation moduleOp)
302 |     | _ => throwE $ MlirPassError "Failed to run MLIR passes"
303 |   code <- cppString
304 |   version <- toString !getCurrentVersion
305 |   True <- serializePortableArtifact moduleOp version !(rawStringOStream code)
306 |     | False => throwE $ InvalidHloError "Failed to serialize MLIR for version \{c_str version}"
307 |   bimapEitherT PjrtErr id $ do
308 |     executableBuildOptions <- mkExecutableBuildOptions
309 |     compileOptions <- serializeAsString !(mkCompileOptions executableBuildOptions)
310 |     program <- mkPjrtProgram code
311 |     loadedExec <- pjrtClientCompile api client program compileOptions
312 |     delete program
313 |     delete code
314 |     delete compileOptions
315 |     delete executableBuildOptions
316 |
317 |     buffers <- pjrtLoadedExecutableExecute api loadedExec outputs
318 |     pjrtLoadedExecutableDestroy api loadedExec
319 |
320 |     for (zip buffers shapes) $ \(buffer, shape) => do
321 |       literal <- allocLiteral shape
322 |       -- is this pure?
323 |       -- note we can probably avoid the difficulties around async
324 |       -- by awaiting the event in pjrtBufferToHostBuffer, thus
325 |       -- making that function synchronous
326 |       event <- pjrtBufferToHostBuffer api buffer literal
327 |       pjrtEventAwait api event
328 |       pjrtEventDestroy api event
329 |       pjrtBufferDestroy api buffer
330 |
331 |       pure literal
332 |