17 | module Compiler.Enzyme.MLIR.Dialect.Ops
20 | import Compiler.MLIR.IR.Builders
21 | import Compiler.MLIR.IR.Location
22 | import Compiler.MLIR.IR.MLIRContext
23 | import Compiler.MLIR.IR.Operation
24 | import Compiler.MLIR.IR.OpDefinition
25 | import Compiler.MLIR.IR.Region
26 | import Compiler.MLIR.IR.TypeRange
27 | import Compiler.MLIR.IR.ValueRange
-- Builds the foreign-function spec for a C symbol exported through the
-- "c/Enzyme/MLIR/Dialect/Ops.h" wrapper header. The spec string itself is
-- produced by `libxla`, which is defined elsewhere.
ffi : String -> String
ffi symbol = libxla "c/Enzyme/MLIR/Dialect/Ops.h" symbol
-- Native release function for an AutoDiffRegionOp handle. The AutoDiffRegionOp
-- `create` wrapper installs it as the finalizer of the handle it returns.
%foreign (ffi "AutoDiffRegionOp_delete")
prim__deleteAutoDiffRegionOp :
  AnyPtr ->
  PrimIO ()
-- Native constructor for an AutoDiffRegionOp. It returns an unmanaged pointer,
-- which the caller wraps with a GC finalizer. Argument roles below follow the
-- only visible call site.
%foreign (ffi "AutoDiffRegionOp_create")
prim__opBuilderCreateAutoDiffRegionOp :
  GCAnyPtr ->  -- op builder
  AnyPtr ->    -- context handle (`ctx` at the call site)
  GCAnyPtr ->  -- location
  GCAnyPtr ->  -- output type range
  GCAnyPtr ->  -- input value range
  Int ->       -- activity, encoded via `cast`
  Int ->       -- return activity, encoded via `cast`
  PrimIO AnyPtr
-- Idris-side handle to a native AutoDiffRegionOp. The wrapped pointer is
-- GC-managed, with a finalizer attached by `create`.
data AutoDiffRegionOp : Type where
  MkAutoDiffRegionOp : GCAnyPtr -> AutoDiffRegionOp
48 | | EnzymeActivenoneed
51 | Cast Activity Int where
56 | EnzymeDupnoneed => 3
57 | EnzymeActivenoneed => 4
58 | EnzymeConstnoneed => 5
60 | namespace AutoDiffRegionOp
73 | (MkOpBuilder builder)
75 | (MkLocation location)
76 | (MkTypeRange outputs)
77 | (MkValueRange inputs)
80 | op <- primIO $
prim__opBuilderCreateAutoDiffRegionOp
81 | builder ctx location outputs inputs (cast activity) (cast retActivity)
82 | op <- onCollectAny' op (primIO . prim__deleteAutoDiffRegionOp)
83 | pure (MkAutoDiffRegionOp op)
-- Native accessor for the body region of an AutoDiffRegionOp. It is declared
-- pure (no PrimIO), so results may be shared or reordered by the compiler.
%foreign (ffi "AutoDiffRegionOp_getBody")
prim__autoDiffRegionOpGetBody :
  GCAnyPtr ->
  AnyPtr
-- Returns the body region of the op as a `Region`.
-- NOTE(review): the Region wraps a raw AnyPtr and keeps no reference to the
-- op's GCAnyPtr. Presumably it is only valid while the op itself is alive;
-- confirm against the C side before relying on it after the op is dropped.
getBody : AutoDiffRegionOp -> Region
getBody regionOp =
  case regionOp of
    MkAutoDiffRegionOp handle => MkRegion (prim__autoDiffRegionOpGetBody handle)
-- Native accessor returning the generic operation pointer underlying an
-- AutoDiffRegionOp. It is effectful (PrimIO), unlike the body accessor.
%foreign (ffi "AutoDiffRegionOp_getOperation")
prim__autoDiffRegionOpGetOperation :
  GCAnyPtr ->
  PrimIO AnyPtr
-- Exposes an AutoDiffRegionOp as a generic `Operation` by asking the native
-- side for its underlying operation pointer.
Op AutoDiffRegionOp where
  getOperation (MkAutoDiffRegionOp handle) = do
    raw <- primIO (prim__autoDiffRegionOpGetOperation handle)
    pure (MkOperation raw)
-- Native release function for a YieldOp handle. The YieldOp `create` wrapper
-- installs it as the finalizer of the handle it returns.
%foreign (ffi "YieldOp_delete")
prim__deleteYieldOp :
  AnyPtr ->
  PrimIO ()
-- Native constructor for a YieldOp. It returns an unmanaged pointer. Argument
-- roles follow the visible call site.
%foreign (ffi "YieldOp_create")
prim__opBuilderCreateYieldOp :
  GCAnyPtr ->  -- op builder
  GCAnyPtr ->  -- location
  GCAnyPtr ->  -- operand value range
  PrimIO AnyPtr
-- Idris-side handle to a native YieldOp. The wrapped pointer is GC-managed,
-- with a finalizer attached by `create`.
data YieldOp : Type where
  MkYieldOp : GCAnyPtr -> YieldOp
-- Creates a YieldOp through the builder at `location`, yielding `operands`.
-- The raw pointer returned by the native constructor is registered with the
-- GC so that `prim__deleteYieldOp` runs when the handle is collected.
-- NOTE(review): if the builder inserts the op into a block that owns it,
-- deleting it again from the finalizer could double-free. Confirm what
-- YieldOp_delete does on the C side.
create : HasIO io => OpBuilder -> Location -> ValueRange -> io YieldOp
create (MkOpBuilder builder) (MkLocation location) (MkValueRange operands) = do
  raw <- primIO (prim__opBuilderCreateYieldOp builder location operands)
  MkYieldOp <$> onCollectAny' raw (primIO . prim__deleteYieldOp)