0 | {--
  1 | Copyright (C) 2025  Joel Berkeley
  2 |
  3 | This program is free software: you can redistribute it and/or modify
  4 | it under the terms of the GNU Affero General Public License as published
  5 | by the Free Software Foundation, either version 3 of the License, or
  6 | (at your option) any later version.
  7 |
  8 | This program is distributed in the hope that it will be useful,
  9 | but WITHOUT ANY WARRANTY; without even the implied warranty of
 10 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 11 | GNU Affero General Public License for more details.
 12 |
 13 | You should have received a copy of the GNU Affero General Public License
 14 | along with this program.  If not, see <https://www.gnu.org/licenses/>.
 15 | --}
 16 | ||| For internal spidr use only.
 17 | module Compiler.Enzyme.MLIR.Dialect.Ops
 18 |
 19 | import Compiler.FFI
 20 | import Compiler.MLIR.IR.Builders
 21 | import Compiler.MLIR.IR.Location
 22 | import Compiler.MLIR.IR.MLIRContext
 23 | import Compiler.MLIR.IR.Operation
 24 | import Compiler.MLIR.IR.OpDefinition
 25 | import Compiler.MLIR.IR.Region
 26 | import Compiler.MLIR.IR.TypeRange
 27 | import Compiler.MLIR.IR.ValueRange
 28 |
 29 | ffi : String -> String
 30 | ffi = libxla "c/Enzyme/MLIR/Dialect/Ops.h"
 31 |
 32 | %foreign (ffi "AutoDiffRegionOp_delete")
 33 | prim__deleteAutoDiffRegionOp : AnyPtr -> PrimIO ()
 34 |
 35 | %foreign (ffi "AutoDiffRegionOp_create")
 36 | prim__opBuilderCreateAutoDiffRegionOp :
 37 |   GCAnyPtr -> AnyPtr -> GCAnyPtr -> GCAnyPtr -> GCAnyPtr -> Int -> Int -> PrimIO AnyPtr
 38 |
 39 | public export
 40 | data AutoDiffRegionOp = MkAutoDiffRegionOp GCAnyPtr
 41 |
 42 | public export
 43 | data Activity
 44 |   = EnzymeActive
 45 |   | EnzymeDup
 46 |   | EnzymeConst
 47 |   | EnzymeDupnoneed
 48 |   | EnzymeActivenoneed
 49 |   | EnzymeConstnoneed
 50 |
 51 | Cast Activity Int where
 52 |   cast = \case
 53 |     EnzymeActive       => 0
 54 |     EnzymeDup          => 1
 55 |     EnzymeConst        => 2
 56 |     EnzymeDupnoneed    => 3
 57 |     EnzymeActivenoneed => 4
 58 |     EnzymeConstnoneed  => 5
 59 |
 60 | namespace AutoDiffRegionOp
 61 |   export
 62 |   create :
 63 |     HasIO io =>
 64 |     OpBuilder ->
 65 |     MLIRContext ->
 66 |     Location ->
 67 |     TypeRange ->
 68 |     ValueRange ->
 69 |     Activity ->
 70 |     Activity ->
 71 |     io AutoDiffRegionOp
 72 |   create
 73 |     (MkOpBuilder builder)
 74 |     (MkMLIRContext ctx)
 75 |     (MkLocation location)
 76 |     (MkTypeRange outputs)
 77 |     (MkValueRange inputs)
 78 |     activity
 79 |     retActivity = do
 80 |       op <- primIO $ prim__opBuilderCreateAutoDiffRegionOp
 81 |         builder ctx location outputs inputs (cast activity) (cast retActivity)
 82 |       op <- onCollectAny' op (primIO . prim__deleteAutoDiffRegionOp)
 83 |       pure (MkAutoDiffRegionOp op)
 84 |
 85 | %foreign (ffi "AutoDiffRegionOp_getBody")
 86 | prim__autoDiffRegionOpGetBody : GCAnyPtr -> AnyPtr
 87 |
 88 | export
 89 | getBody : AutoDiffRegionOp -> Region
 90 | getBody (MkAutoDiffRegionOp op) = MkRegion $ prim__autoDiffRegionOpGetBody op
 91 |
 92 | %foreign (ffi "AutoDiffRegionOp_getOperation")
 93 | prim__autoDiffRegionOpGetOperation : GCAnyPtr -> PrimIO AnyPtr
 94 |
 95 | export
 96 | Op AutoDiffRegionOp where
 97 |   getOperation (MkAutoDiffRegionOp op) =
 98 |     MkOperation <$> (primIO $ prim__autoDiffRegionOpGetOperation op)
 99 |
100 | %foreign (ffi "YieldOp_delete")
101 | prim__deleteYieldOp : AnyPtr -> PrimIO ()
102 |
103 | %foreign (ffi "YieldOp_create")
104 | prim__opBuilderCreateYieldOp : GCAnyPtr -> GCAnyPtr -> GCAnyPtr -> PrimIO AnyPtr
105 |
106 | public export
107 | data YieldOp = MkYieldOp GCAnyPtr
108 |
109 | namespace YieldOp
110 |   export
111 |   create : HasIO io => OpBuilder -> Location -> ValueRange -> io YieldOp
112 |   create (MkOpBuilder builder) (MkLocation location) (MkValueRange operands) = do
113 |     op <- primIO $ prim__opBuilderCreateYieldOp builder location operands
114 |     op <- onCollectAny' op (primIO . prim__deleteYieldOp)
115 |     pure (MkYieldOp op)
116 |