17 | module Compiler.Stablehlo.Dialect.StablehloAttrs
20 | import Compiler.MLIR.IR.MLIRContext
||| Build the foreign-function specifier (via `libxla`) for a C symbol declared
||| in the StablehloAttrs header; used as the argument of `%foreign` pragmas.
ffi : String -> String
ffi fnName = libxla "c/stablehlo/dialect/StablehloAttrs.h" fnName
-- Foreign binding to the C function `DotDimensionNumbersAttr_delete` from the
-- StablehloAttrs header. Takes the raw (untracked) attribute pointer; it is
-- used as the release action when wrapping attributes in a GC pointer.
25 | %foreign (ffi "DotDimensionNumbersAttr_delete")
26 | prim__deleteDotDimensionNumbersAttr : AnyPtr -> PrimIO ()
-- Foreign binding to the C function `DotDimensionNumbersAttr_get`.
-- Each visible `GCPtr Int64 -> Bits64` pair is an array pointer followed by
-- its element count, in the order lhs batching, rhs batching, lhs contracting,
-- rhs contracting (matching the call site in the `DotDimensionNumbersAttr`
-- namespace).
-- NOTE(review): this excerpt omits the remaining parameter(s) and the result
-- type; presumably the result is a raw `AnyPtr` (the caller hands it to a
-- finalizer taking `AnyPtr`) — confirm.
28 | %foreign (ffi "DotDimensionNumbersAttr_get")
29 | prim__dotDimensionNumbersAttrGet :
31 | GCPtr Int64 -> Bits64 ->
32 | GCPtr Int64 -> Bits64 ->
33 | GCPtr Int64 -> Bits64 ->
34 | GCPtr Int64 -> Bits64 ->
||| Opaque handle to a StableHLO dot-dimension-numbers attribute, wrapping a
||| garbage-collected foreign pointer.
data DotDimensionNumbersAttr : Type where
  MkDotDimensionNumbersAttr : GCAnyPtr -> DotDimensionNumbersAttr
-- Construction of `DotDimensionNumbersAttr` values from Idris-side lists of
-- dimensions.
-- NOTE(review): this excerpt omits the constructor function's name, most of
-- its type signature, and one argument of the foreign call; the comments
-- below describe only the visible lines.
40 | namespace DotDimensionNumbersAttr
49 | io DotDimensionNumbersAttr
52 | lhsBatchingDimensions
53 | rhsBatchingDimensions
54 | lhsContractingDimensions
55 | rhsContractingDimensions = do
-- Marshal each dimension list through `mkInt64Array` (elements converted with
-- `cast`), keeping the `GCPtr Int64` that the foreign call expects.
56 | MkInt64Array lhsBatchingDimensions' <- mkInt64Array $
map cast lhsBatchingDimensions
57 | MkInt64Array rhsBatchingDimensions' <- mkInt64Array $
map cast rhsBatchingDimensions
58 | MkInt64Array lhsContractingDimensions' <- mkInt64Array $
map cast lhsContractingDimensions
59 | MkInt64Array rhsContractingDimensions' <- mkInt64Array $
map cast rhsContractingDimensions
-- Call the C constructor, passing each array followed by its element count
-- (the original list's length, converted to Bits64 with `cast`).
60 | attr <- primIO $
prim__dotDimensionNumbersAttrGet
62 | lhsBatchingDimensions' (cast $
length lhsBatchingDimensions)
63 | rhsBatchingDimensions' (cast $
length rhsBatchingDimensions)
64 | lhsContractingDimensions' (cast $
length lhsContractingDimensions)
65 | rhsContractingDimensions' (cast $
length rhsContractingDimensions)
-- Wrap the raw handle, supplying `prim__deleteDotDimensionNumbersAttr` as its
-- release action (presumably run when the GC pointer is collected — confirm
-- against `onCollectAny'`). Note: `attr` is rebound here, shadowing the raw
-- pointer with the wrapped one.
66 | attr <- onCollectAny' attr (primIO . prim__deleteDotDimensionNumbersAttr)
67 | pure (MkDotDimensionNumbersAttr attr)