17 | module Compiler.Stablehlo.Dialect.StablehloAttrs
20 | import Compiler.MLIR.IR.MLIRContext
-- Helper used inside the `%foreign` pragmas of this module: `ffi` is `libxla`
-- partially applied to the header path below, so the remaining `String`
-- argument is the C symbol name (e.g. "DotDimensionNumbersAttr_get").
-- NOTE(review): `libxla` is defined outside this view; presumably it produces
-- the foreign-function specifier string for the XLA C library -- confirm.
22 | ffi : String -> String
23 | ffi = libxla "c/stablehlo/dialect/StablehloAttrs.h"
-- Foreign binding for the symbol "DotDimensionNumbersAttr_delete".
-- Takes an `AnyPtr` and yields no value. In this module it is the callback
-- handed to `onCollectAny'` for the pointer that ends up wrapped in
-- `MkDotDimensionNumbersAttr`.
25 | %foreign (ffi "DotDimensionNumbersAttr_delete")
26 | prim__deleteDotDimensionNumbersAttr : AnyPtr -> PrimIO ()
-- Foreign binding for the symbol "DotDimensionNumbersAttr_get".
-- The four `GCPtr Int64 -> Bits64` pairs below are filled, at the call site in
-- this module, with the two fields unpacked from `MkInt64Array` for the lhs
-- batching, rhs batching, lhs contracting and rhs contracting dimensions, in
-- that order; the call site passes `ctx` ahead of them.
-- NOTE(review): the `Bits64` of each pair is presumably the element count of
-- the preceding array (bound as `lbLen`, `rbLen`, ... by the caller) -- confirm
-- against the C header.
28 | %foreign (ffi "DotDimensionNumbersAttr_get")
29 | prim__dotDimensionNumbersAttrGet :
31 | GCPtr Int64 -> Bits64 ->
32 | GCPtr Int64 -> Bits64 ->
33 | GCPtr Int64 -> Bits64 ->
34 | GCPtr Int64 -> Bits64 ->
-- Handle type for a dot-dimension-numbers attribute. Its single constructor
-- wraps the `GCAnyPtr` that `onCollectAny'` returns in the namespace below.
38 | data DotDimensionNumbersAttr = MkDotDimensionNumbersAttr GCAnyPtr
-- Constructor function for `DotDimensionNumbersAttr`.
-- Given the four dimension lists named in the clause head, it marshals each
-- one with `mkInt64Array`, calls the foreign getter with `ctx` followed by the
-- four pointer/length pairs, attaches `prim__deleteDotDimensionNumbersAttr` to
-- the result through `onCollectAny'`, and returns the wrapped handle in `io`.
-- NOTE(review): `ctx`, `mkInt64Array`, `MkInt64Array` and `onCollectAny'` are
-- bound/defined outside the lines visible here -- verify their contracts there.
40 | namespace DotDimensionNumbersAttr
49 | io DotDimensionNumbersAttr
52 | lhsBatchingDimensions
53 | rhsBatchingDimensions
54 | lhsContractingDimensions
55 | rhsContractingDimensions = do
-- Convert every element with `cast` and marshal each list; each match binds
-- the array pointer and the second field (named as a length) of the result.
56 | MkInt64Array lb lbLen <- mkInt64Array $
map cast lhsBatchingDimensions
57 | MkInt64Array rb rbLen <- mkInt64Array $
map cast rhsBatchingDimensions
58 | MkInt64Array lc lcLen <- mkInt64Array $
map cast lhsContractingDimensions
59 | MkInt64Array rc rcLen <- mkInt64Array $
map cast rhsContractingDimensions
-- Argument order is lhs batching, rhs batching, lhs contracting, rhs
-- contracting, matching the order of the pairs in the foreign signature.
60 | attr <- primIO $
prim__dotDimensionNumbersAttrGet ctx lb lbLen rb rbLen lc lcLen rc rcLen
-- Rebinds `attr` (shadowing the raw result) to the value returned by
-- `onCollectAny'`, with the foreign delete function as the callback.
-- NOTE(review): presumably the callback runs when the handle is garbage
-- collected, as with base's `onCollectAny` -- confirm.
61 | attr <- onCollectAny' attr (primIO . prim__deleteDotDimensionNumbersAttr)
62 | pure (MkDotDimensionNumbersAttr attr)