0 | {--
 1 | Copyright (C) 2026  Joel Berkeley
 2 |
 3 | This program is free software: you can redistribute it and/or modify
 4 | it under the terms of the GNU Affero General Public License as published
 5 | by the Free Software Foundation, either version 3 of the License, or
 6 | (at your option) any later version.
 7 |
 8 | This program is distributed in the hope that it will be useful,
 9 | but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11 | GNU Affero General Public License for more details.
12 |
13 | You should have received a copy of the GNU Affero General Public License
14 | along with this program.  If not, see <https://www.gnu.org/licenses/>.
15 | --}
16 | ||| For internal spidr use only.
17 | module Compiler.Stablehlo.Dialect.StablehloAttrs
18 |
19 | import Compiler.FFI
20 | import Compiler.MLIR.IR.MLIRContext
21 |
22 | ffi : String -> String
23 | ffi = libxla "c/stablehlo/dialect/StablehloAttrs.h"
24 |
25 | %foreign (ffi "DotDimensionNumbersAttr_delete")
26 | prim__deleteDotDimensionNumbersAttr : AnyPtr -> PrimIO ()
27 |
28 | %foreign (ffi "DotDimensionNumbersAttr_get")
29 | prim__dotDimensionNumbersAttrGet :
30 |   AnyPtr ->
31 |   GCPtr Int64 -> Bits64 ->
32 |   GCPtr Int64 -> Bits64 ->
33 |   GCPtr Int64 -> Bits64 ->
34 |   GCPtr Int64 -> Bits64 ->
35 |   PrimIO AnyPtr
36 |
37 | public export
38 | data DotDimensionNumbersAttr = MkDotDimensionNumbersAttr GCAnyPtr
39 |
40 | namespace DotDimensionNumbersAttr
41 |   export
42 |   get :
43 |     HasIO io =>
44 |     MLIRContext ->
45 |     List Nat ->
46 |     List Nat ->
47 |     List Nat ->
48 |     List Nat ->
49 |     io DotDimensionNumbersAttr
50 |   get
51 |     (MkMLIRContext ctx)
52 |     lhsBatchingDimensions
53 |     rhsBatchingDimensions
54 |     lhsContractingDimensions
55 |     rhsContractingDimensions = do
56 |       MkInt64Array lhsBatchingDimensions' <- mkInt64Array $ map cast lhsBatchingDimensions
57 |       MkInt64Array rhsBatchingDimensions' <- mkInt64Array $ map cast rhsBatchingDimensions
58 |       MkInt64Array lhsContractingDimensions' <- mkInt64Array $ map cast lhsContractingDimensions
59 |       MkInt64Array rhsContractingDimensions' <- mkInt64Array $ map cast rhsContractingDimensions
60 |       attr <- primIO $ prim__dotDimensionNumbersAttrGet
61 |         ctx
62 |         lhsBatchingDimensions' (cast $ length lhsBatchingDimensions)
63 |         rhsBatchingDimensions' (cast $ length rhsBatchingDimensions)
64 |         lhsContractingDimensions' (cast $ length lhsContractingDimensions)
65 |         rhsContractingDimensions' (cast $ length rhsContractingDimensions)
66 |       attr <- onCollectAny' attr (primIO . prim__deleteDotDimensionNumbersAttr)
67 |       pure (MkDotDimensionNumbersAttr attr)
68 |