17 | module Compiler.MLIR.IR.BuiltinTypes
19 | import Compiler.MLIR.IR.BuiltinTypeInterfaces
20 | import Compiler.MLIR.IR.MLIRContext
21 | import Compiler.MLIR.IR.Operation
22 | import Compiler.MLIR.IR.Types
23 | import Compiler.MLIR.IR.TypeRange
||| Build the foreign-function specification string for a symbol declared in
||| the C header `c/mlir/IR/BuiltinTypes.h`, by delegating to `libxla`.
ffi : String -> String
ffi fname = libxla "c/mlir/IR/BuiltinTypes.h" fname
||| An MLIR floating-point type, held behind a garbage-collected pointer.
data FloatType = MkFloatType GCAnyPtr

-- Releases a raw float-type pointer; registered as the collection hook
-- for pointers wrapped in `FloatType`.
%foreign (ffi "FloatType_delete")
prim__deleteFloatType : AnyPtr -> PrimIO ()

-- Array setter handed to `MkType_`; presumably stores the type at the given
-- index of a C array of types. NOTE(review): confirm against the C header.
%foreign (ffi "set_array_FloatType")
prim__setArrayFloatType : AnyPtr -> Bits64 -> GCAnyPtr -> PrimIO ()

-- View a float type as a generic `Type_`, reusing the same pointer and
-- pairing it with the float-specific array setter.
Cast FloatType Type_ where
  cast (MkFloatType ptr) = MkType_ ptr prim__setArrayFloatType
-- Raw constructor for the 64-bit float type; takes the context's raw pointer.
%foreign (ffi "Float64Type_get")
prim__float64TypeGet : AnyPtr -> PrimIO AnyPtr

namespace Float64Type
  ||| Obtain the 64-bit floating-point type for the given MLIR context.
  ||| The returned pointer is freed with `prim__deleteFloatType` when collected.
  get : HasIO io => MLIRContext -> io FloatType
  get (MkMLIRContext ctx) = do
    raw <- primIO $ prim__float64TypeGet ctx
    MkFloatType <$> onCollectAny' raw (primIO . prim__deleteFloatType)
||| An MLIR function type, held behind a garbage-collected pointer.
data FunctionType = MkFunctionType GCAnyPtr

-- Releases a raw function-type pointer; registered as the collection hook
-- in `FunctionType.get`.
%foreign (ffi "FunctionType_delete")
prim__deleteFunctionType : AnyPtr -> PrimIO ()

-- Raw constructor: context pointer, input type range, result type range.
%foreign (ffi "FunctionType_get")
prim__functionTypeGet : AnyPtr -> GCAnyPtr -> GCAnyPtr -> PrimIO AnyPtr

namespace FunctionType
  ||| Build a function type in the given context from a range of input types
  ||| and a range of result types. The returned pointer is freed with
  ||| `prim__deleteFunctionType` when collected.
  get : HasIO io => MLIRContext -> TypeRange -> TypeRange -> io FunctionType
  get (MkMLIRContext ctx) (MkTypeRange ins) (MkTypeRange outs) = do
    raw <- primIO $ prim__functionTypeGet ctx ins outs
    MkFunctionType <$> onCollectAny' raw (primIO . prim__deleteFunctionType)
||| An MLIR integer type, held behind a garbage-collected pointer.
data IntegerType = MkIntegerType GCAnyPtr

-- Releases a raw integer-type pointer; registered as the collection hook
-- in `IntegerType.get`.
%foreign (ffi "IntegerType_delete")
prim__deleteIntegerType : AnyPtr -> PrimIO ()

-- Array setter handed to `MkType_`; presumably stores the type at the given
-- index of a C array of types. NOTE(review): confirm against the C header.
%foreign (ffi "set_array_IntegerType")
prim__setArrayIntegerType : AnyPtr -> Bits64 -> GCAnyPtr -> PrimIO ()

-- View an integer type as a generic `Type_`, reusing the same pointer and
-- pairing it with the integer-specific array setter.
Cast IntegerType Type_ where
  cast (MkIntegerType ptr) = MkType_ ptr prim__setArrayIntegerType

-- Raw constructor: context pointer, width, and signedness encoded as an `Int`.
%foreign (ffi "IntegerType_get")
prim__integerTypeGet : AnyPtr -> Bits64 -> Int -> PrimIO AnyPtr
||| Signedness of an integer type, as passed to `IntegerType.get`.
data SignednessSemantics
  = Signless
  | Signed
  | Unsigned
92 | Cast SignednessSemantics Int where
namespace IntegerType
  ||| Obtain an integer type of the given width (presumably in bits, as in
  ||| MLIR's `IntegerType`) and signedness for the given context. The width is
  ||| converted to `Bits64` for the foreign call. The returned pointer is
  ||| freed with `prim__deleteIntegerType` when collected.
  get : HasIO io => MLIRContext -> Nat -> SignednessSemantics -> io IntegerType
  get (MkMLIRContext ctx) width signedness = do
    raw <- primIO $ prim__integerTypeGet ctx (cast width) (cast signedness)
    MkIntegerType <$> onCollectAny' raw (primIO . prim__deleteIntegerType)
||| An MLIR ranked tensor type, held behind a garbage-collected pointer.
data RankedTensorType = MkRankedTensorType GCAnyPtr

-- Array setter handed to `MkType_`; presumably stores the type at the given
-- index of a C array of types. NOTE(review): confirm against the C header.
%foreign (ffi "set_array_RankedTensorType")
prim__setArrayRankedTensorType : AnyPtr -> Bits64 -> GCAnyPtr -> PrimIO ()

-- Releases a raw ranked-tensor-type pointer; registered as the collection
-- hook in `RankedTensorType.get`.
%foreign (ffi "RankedTensorType_delete")
prim__deleteRankedTensorType : AnyPtr -> PrimIO ()

namespace ShapedType
  -- View a ranked tensor type as a `ShapedType`, sharing the same pointer.
  Cast RankedTensorType ShapedType where
    cast (MkRankedTensorType ptr) = MkShapedType ptr

-- View a ranked tensor type as a generic `Type_`, reusing the same pointer
-- and pairing it with the ranked-tensor-specific array setter.
Cast RankedTensorType Type_ where
  cast (MkRankedTensorType ptr) = MkType_ ptr prim__setArrayRankedTensorType
-- Raw constructor: pointer to the dimension array, number of dimensions,
-- and the element type's pointer.
%foreign (ffi "RankedTensorType_get")
prim__rankedTensorTypeGet : GCPtr Int64 -> Bits64 -> GCAnyPtr -> PrimIO AnyPtr

namespace RankedTensorType
  ||| Build a ranked tensor type whose dimensions are `shape` and whose element
  ||| type is the given `Type_`. Dimensions are `Nat`, so only non-negative
  ||| extents can be expressed. No context is passed; presumably the C side
  ||| takes it from the element type (NOTE(review): confirm). The returned
  ||| pointer is freed with `prim__deleteRankedTensorType` when collected.
  get : HasIO io => List Nat -> Types.Type_ -> io RankedTensorType
  get shape (MkType_ elementType _) = do
    MkInt64Array dims <- mkInt64Array (map cast shape)
    let rank = the Bits64 (cast (length shape))
    raw <- primIO $ prim__rankedTensorTypeGet dims rank elementType
    MkRankedTensorType <$> onCollectAny' raw (primIO . prim__deleteRankedTensorType)