0 | {--
 1 | Copyright (C) 2021  Joel Berkeley
 2 |
 3 | This program is free software: you can redistribute it and/or modify
 4 | it under the terms of the GNU Affero General Public License as published
 5 | by the Free Software Foundation, either version 3 of the License, or
 6 | (at your option) any later version.
 7 |
 8 | This program is distributed in the hope that it will be useful,
 9 | but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11 | GNU Affero General Public License for more details.
12 |
13 | You should have received a copy of the GNU Affero General Public License
14 | along with this program.  If not, see <https://www.gnu.org/licenses/>.
15 | --}
16 | ||| For internal spidr use only.
17 | module Compiler.Xla.XlaData
18 |
19 | import Compiler.MLIR.IR.BuiltinTypes
20 | import Compiler.MLIR.IR.MLIRContext
21 | import Compiler.MLIR.IR.Types
22 |
23 | export
24 | interface Primitive dtype where
25 |   repr : String
26 |   numBits : Bits16
27 |   isSigned : Bool
28 |   xlaIdentifier : Int
29 |   mlirType : HasIO io => MLIRContext -> io Type_
30 |
31 | export data PRED : Type where
32 |
33 | export
34 | Primitive PRED where
35 |   repr = "PRED"
36 |   numBits = 1
37 |   isSigned = True
38 |   xlaIdentifier = 1
39 |   mlirType ctx = cast <$> IntegerType.get ctx 1 Signless
40 |
41 | export data S32 : Type where
42 |
43 | export
44 | Primitive S32 where
45 |   repr = "S32"
46 |   numBits = 32
47 |   isSigned = True
48 |   xlaIdentifier = 4
49 |   mlirType ctx = cast <$> IntegerType.get ctx 32 Signless
50 |
51 | export data S64 : Type where
52 |
53 | export
54 | Primitive S64 where
55 |   repr = "S64"
56 |   numBits = 64
57 |   isSigned = True
58 |   xlaIdentifier = 5
59 |   mlirType ctx = cast <$> IntegerType.get ctx 64 Signless
60 |
61 | export data U32 : Type where
62 |
63 | export
64 | Primitive U32 where
65 |   repr = "U32"
66 |   numBits = 32
67 |   isSigned = False
68 |   xlaIdentifier = 8
69 |   mlirType ctx = cast <$> IntegerType.get ctx 32 Unsigned
70 |
71 | export data U64 : Type where
72 |
73 | export
74 | Primitive U64 where
75 |   repr = "U64"
76 |   numBits = 64
77 |   isSigned = False
78 |   xlaIdentifier = 9
79 |   mlirType ctx = cast <$> IntegerType.get ctx 64 Unsigned
80 |
81 | export data F64 : Type where
82 |
83 | export
84 | Primitive F64 where
85 |   repr = "F64"
86 |   numBits = 64
87 |   isSigned = True
88 |   xlaIdentifier = 12
89 |   mlirType ctx = cast <$> Float64Type.get ctx
90 |