22 | import Compiler.Enzyme.MLIR.Implementations.CoreDialectsAutoDiffImplementations
23 | import Compiler.Enzyme.MLIR.Dialect.Dialect
24 | import Compiler.Enzyme.MLIR.Passes.Passes
25 | import Compiler.EnzymeJAX.Src.EnzymeAD.JAX.Implementations.XLADerivatives
26 | import Compiler.EnzymeJAX.Src.EnzymeAD.JAX.Passes.Passes
27 | import Compiler.MLIR.Dialect.Func.IR.FuncOps
28 | import public Compiler.MLIR.IR.DialectRegistry
29 | import public Compiler.MLIR.IR.MLIRContext
30 | import public Compiler.MLIR.Pass.PassManager
31 | import Compiler.MLIR.Transforms.Passes
32 | import Compiler.Stablehlo.Dialect.ChloOps
33 | import Compiler.Stablehlo.Dialect.StablehloOps
34 | import Compiler.Stablehlo.Transforms.Passes
35 | import Compiler.Xla.PJRT.C.PjrtCApi
42 | constructor MkDevice
43 | mlirCtx : MLIRContext
44 | passManager : PassManager
||| Prepare an MLIR context for compilation.
|||
||| Loads the func, StableHLO, CHLO and Enzyme dialects into `ctx` (in that
||| order). It then builds a fresh dialect registry holding the interfaces
||| installed by `registerCoreDialectAutodiffInterfaces` and
||| `registerStableHLODialectAutoDiffInterface`, and appends that registry
||| to `ctx`.
initializeMLIRContext : HasIO io => MLIRContext -> io ()
initializeMLIRContext ctx = do
  -- Eagerly load every dialect, one after another.
  loadDialectFuncDialect ctx
  loadDialectStablehloDialect ctx
  loadDialectChloDialect ctx
  loadDialectEnzymeDialect ctx
  -- Interface registrations go into a separate registry that is merged into
  -- the context last.
  -- NOTE(review): MLIR's C++ `MLIRContext::appendDialectRegistry` also applies
  -- registry extensions to dialects that are already loaded. This assumes the
  -- binding wraps that call; confirm.
  mkDialectRegistry >>= \registry => do
    registerCoreDialectAutodiffInterfaces registry
    registerStableHLODialectAutoDiffInterface registry
    appendDialectRegistry ctx registry
||| Populate `pm` with the compilation pipeline.
|||
||| In pipeline order, the passes are:
||| - CHLO-to-StableHLO legalization;
||| - Enzyme outlining from regions;
||| - differentiation, with a post-processing pipeline of `canonicalize`
|||   and `remove-unnecessary-enzyme-ops`;
||| - arith raising.
|||
||| As a side effect, it also calls `registerCanonicalizerPass` and
||| `registerRemoveUnusedEnzymeOpsPass`. These take no pass manager argument,
||| so they presumably register into a global pass registry (confirm).
initializePassManager : HasIO io => PassManager -> io ()
initializePassManager pm = do
  -- The textual pipeline handed to the differentiation pass names these
  -- passes, so they are registered before that pass is added.
  registerCanonicalizerPass
  registerRemoveUnusedEnzymeOpsPass
  let postDifferentiation = "canonicalize,remove-unnecessary-enzyme-ops"
  addChloLegalizeToStablehloPass pm
  addOutlineEnzymeFromRegionPass pm
  addDifferentiatePass pm postDifferentiation
  addArithRaisingPass pm
||| Build a `Device` backed by the given PJRT API.
|||
||| Steps, in order:
||| 1. Create a new MLIR context and initialise it with
|||    `initializeMLIRContext`.
||| 2. Create a pass manager on that context and populate it with
|||    `initializePassManager`.
||| 3. Create a PJRT client from `api`.
|||
||| The device bundles the context, the pass manager, `api` and the client.
createDevice : PjrtApi -> Pjrt Device
createDevice api = do
  ctx <- mkMLIRContext
  initializeMLIRContext ctx
  pm <- mkPassManager ctx
  initializePassManager pm
  -- The PJRT client is created only after all MLIR setup has finished.
  client <- pjrtClientCreate api
  pure (MkDevice ctx pm api client)