0 | {--
  1 | Copyright (C) 2025  Joel Berkeley
  2 |
  3 | This program is free software: you can redistribute it and/or modify
  4 | it under the terms of the GNU Affero General Public License as published
  5 | by the Free Software Foundation, either version 3 of the License, or
  6 | (at your option) any later version.
  7 |
  8 | This program is distributed in the hope that it will be useful,
  9 | but WITHOUT ANY WARRANTY; without even the implied warranty of
 10 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 11 | GNU Affero General Public License for more details.
 12 |
 13 | You should have received a copy of the GNU Affero General Public License
 14 | along with this program.  If not, see <https://www.gnu.org/licenses/>.
 15 | --}
 16 | ||| For internal spidr use only.
 17 | module Compiler.MLIR.IR.BuiltinTypes
 18 |
 19 | import Compiler.MLIR.IR.BuiltinTypeInterfaces
 20 | import Compiler.MLIR.IR.MLIRContext
 21 | import Compiler.MLIR.IR.Operation
 22 | import Compiler.MLIR.IR.Types
 23 | import Compiler.MLIR.IR.TypeRange
 24 | import Compiler.FFI
 25 |
 26 | ffi : String -> String
 27 | ffi = libxla "c/mlir/IR/BuiltinTypes.h"
 28 |
 29 | public export
 30 | data FloatType = MkFloatType GCAnyPtr
 31 |
 32 | export
 33 | %foreign (ffi "FloatType_delete")
 34 | prim__deleteFloatType : AnyPtr -> PrimIO ()
 35 |
 36 | %foreign (ffi "set_array_FloatType")
 37 | prim__setArrayFloatType : AnyPtr -> Bits64 -> GCAnyPtr -> PrimIO ()
 38 |
 39 | export
 40 | Cast FloatType Type_ where
 41 |   cast (MkFloatType t) = MkType_ t prim__setArrayFloatType
 42 |
 43 | %foreign (ffi "Float64Type_get")
 44 | prim__float64TypeGet : AnyPtr -> PrimIO AnyPtr
 45 |
 46 | namespace Float64Type
 47 |   export
 48 |   get : HasIO io => MLIRContext -> io FloatType
 49 |   get (MkMLIRContext ctx) = do
 50 |     ftype <- primIO $ prim__float64TypeGet ctx
 51 |     ftype <- onCollectAny' ftype (primIO . prim__deleteFloatType)
 52 |     pure (MkFloatType ftype)
 53 |
 54 | public export
 55 | data FunctionType = MkFunctionType GCAnyPtr
 56 |
 57 | %foreign (ffi "FunctionType_delete")
 58 | prim__deleteFunctionType : AnyPtr -> PrimIO ()
 59 |
 60 | %foreign (ffi "FunctionType_get")
 61 | prim__functionTypeGet : AnyPtr -> GCAnyPtr -> GCAnyPtr -> PrimIO AnyPtr
 62 |
 63 | namespace FunctionType
 64 |   export
 65 |   get : HasIO io => MLIRContext -> TypeRange -> TypeRange -> io FunctionType
 66 |   get (MkMLIRContext ctx) (MkTypeRange inputs) (MkTypeRange results) = do
 67 |     ftype <- primIO $ prim__functionTypeGet ctx inputs results
 68 |     ftype <- onCollectAny' ftype (primIO . prim__deleteFunctionType)
 69 |     pure (MkFunctionType ftype)
 70 |
 71 | public export
 72 | data IntegerType = MkIntegerType GCAnyPtr
 73 |
 74 | export
 75 | %foreign (ffi "IntegerType_delete")
 76 | prim__deleteIntegerType : AnyPtr -> PrimIO ()
 77 |
 78 | %foreign (ffi "set_array_IntegerType")
 79 | prim__setArrayIntegerType : AnyPtr -> Bits64 -> GCAnyPtr -> PrimIO ()
 80 |
 81 | export
 82 | Cast IntegerType Type_ where
 83 |   cast (MkIntegerType t) = MkType_ t prim__setArrayIntegerType
 84 |
 85 | %foreign (ffi "IntegerType_get")
 86 | prim__integerTypeGet : AnyPtr -> Bits64 -> Int -> PrimIO AnyPtr
 87 |
 88 | public export
 89 | data SignednessSemantics = Signless | Signed | Unsigned
 90 |
 91 | export
 92 | Cast SignednessSemantics Int where
 93 |   cast Signless = 0
 94 |   cast Signed = 1
 95 |   cast Unsigned = 2
 96 |
 97 | namespace IntegerType
 98 |   export
 99 |   get : HasIO io => MLIRContext -> Nat -> SignednessSemantics -> io IntegerType
100 |   get (MkMLIRContext ctx) width signedness = do
101 |     ftype <- primIO $ prim__integerTypeGet ctx (cast width) (cast signedness)
102 |     ftype <- onCollectAny' ftype (primIO . prim__deleteIntegerType)
103 |     pure (MkIntegerType ftype)
104 |
105 | public export
106 | data RankedTensorType = MkRankedTensorType GCAnyPtr
107 |
108 | %foreign (ffi "set_array_RankedTensorType")
109 | prim__setArrayRankedTensorType : AnyPtr -> Bits64 -> GCAnyPtr -> PrimIO ()
110 |
111 | %foreign (ffi "RankedTensorType_delete")
112 | prim__deleteRankedTensorType : AnyPtr -> PrimIO ()
113 |
114 | namespace ShapedType
115 |   export
116 |   Cast RankedTensorType ShapedType where
117 |     cast (MkRankedTensorType t) = MkShapedType t
118 |
119 | namespace Type_
120 |   export
121 |   Cast RankedTensorType Type_ where
122 |     cast (MkRankedTensorType t) = MkType_ t prim__setArrayRankedTensorType
123 |
124 | %foreign (ffi "RankedTensorType_get")
125 | prim__rankedTensorTypeGet : GCPtr Int64 -> Bits64 -> GCAnyPtr -> PrimIO AnyPtr
126 |
127 | namespace RankedTensorType
128 |   export
129 |   get : HasIO io => List Nat -> Types.Type_ -> io RankedTensorType
130 |   get shape (MkType_ elementType _) = do
131 |     MkInt64Array arr <- mkInt64Array (map cast shape)
132 |     rtt <- primIO $ prim__rankedTensorTypeGet arr (cast $ length shape) elementType
133 |     rtt <- onCollectAny' rtt (primIO . prim__deleteRankedTensorType)
134 |     pure (MkRankedTensorType rtt)
135 |