0 | {--
 1 | Copyright (C) 2024  Joel Berkeley
 2 |
 3 | This program is free software: you can redistribute it and/or modify
 4 | it under the terms of the GNU Affero General Public License as published
 5 | by the Free Software Foundation, either version 3 of the License, or
 6 | (at your option) any later version.
 7 |
 8 | This program is distributed in the hope that it will be useful,
 9 | but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11 | GNU Affero General Public License for more details.
12 |
13 | You should have received a copy of the GNU Affero General Public License
14 | along with this program.  If not, see <https://www.gnu.org/licenses/>.
15 | --}
16 | ||| `Device` management. A `Device` encapsulates the graph compiler and hardware support.
17 | |||
18 | ||| Beyond the *type* `Device` itself, you will not likely need this module unless you are
19 | ||| developing a custom plugin.
20 | module Device
21 |
22 | import Compiler.Enzyme.MLIR.Implementations.CoreDialectsAutoDiffImplementations
23 | import Compiler.Enzyme.MLIR.Dialect.Dialect
24 | import Compiler.Enzyme.MLIR.Passes.Passes
25 | import Compiler.EnzymeJAX.Src.EnzymeAD.JAX.Implementations.XLADerivatives
26 | import Compiler.EnzymeJAX.Src.EnzymeAD.JAX.Passes.Passes
27 | import Compiler.MLIR.Dialect.Func.IR.FuncOps
28 | import public Compiler.MLIR.IR.DialectRegistry
29 | import public Compiler.MLIR.IR.MLIRContext
30 | import public Compiler.MLIR.Pass.PassManager
31 | import Compiler.MLIR.Transforms.Passes
32 | import Compiler.Stablehlo.Dialect.ChloOps
33 | import Compiler.Stablehlo.Dialect.StablehloOps
34 | import Compiler.Stablehlo.Transforms.Passes
35 | import Compiler.Xla.PJRT.C.PjrtCApi
36 |
37 | ||| A PJRT "device". These are required to run spidr graphs, and are provided by your plugin.
38 | |||
39 | ||| Plugin developers should create one via `createDevice`. Everyone else should use their plugin.
40 | public export
41 | record Device where
42 |   constructor MkDevice
43 |   mlirCtx : MLIRContext
44 |   passManager : PassManager
45 |   api : PjrtApi
46 |   client : PjrtClient
47 |
48 | initializeMLIRContext : HasIO io => MLIRContext -> io ()
49 | initializeMLIRContext ctx = do
50 |   loadDialectFuncDialect ctx
51 |   loadDialectStablehloDialect ctx
52 |   loadDialectChloDialect ctx
53 |   loadDialectEnzymeDialect ctx
54 |
55 |   reg <- mkDialectRegistry
56 |   registerCoreDialectAutodiffInterfaces reg
57 |   registerStableHLODialectAutoDiffInterface reg
58 |   appendDialectRegistry ctx reg
59 |
60 | initializePassManager : HasIO io => PassManager -> io ()
61 | initializePassManager pm = do
62 |   addChloLegalizeToStablehloPass pm
63 |   addOutlineEnzymeFromRegionPass pm
64 |   registerCanonicalizerPass
65 |   registerRemoveUnusedEnzymeOpsPass
66 |   addDifferentiatePass pm "canonicalize,remove-unnecessary-enzyme-ops"
67 |   addArithRaisingPass pm
68 |
69 | ||| Create and initialize a `Device`. For use by plugin developers.
70 | export
71 | createDevice : PjrtApi -> Pjrt Device
72 | createDevice api = do
73 |   mlirCtx <- mkMLIRContext
74 |   initializeMLIRContext mlirCtx
75 |   passManager <- mkPassManager mlirCtx
76 |   initializePassManager passManager
77 |   pure $ MkDevice mlirCtx passManager api !(pjrtClientCreate api)
78 |