17 | module Compiler.Xla.Literal
19 | import Compiler.Xla.Shape
20 | import Compiler.Xla.ShapeUtil
21 | import Compiler.Xla.XlaData
-- Builds the `%foreign` specification for the C symbol `symbol` by
-- delegating to `libxla` with this module's header path. `libxla` is defined
-- elsewhere; the exact format of the resulting spec string is not visible here.
ffi : String -> String
ffi symbol = libxla "c/xla/literal.h" symbol
-- Handle to a native XLA literal. The pointer is a `GCAnyPtr`, so its
-- release is driven by the Idris garbage collector; `allocLiteral` is where
-- the release callback is attached.
data Literal = MkLiteral GCAnyPtr
-- Foreign binding to `Literal_delete`. `allocLiteral` passes it (lifted with
-- `primIO`) to `onCollectAny'` as the callback for the raw literal pointer.
-- Presumably this is the destructor for objects made by `Literal_new`;
-- confirm against c/xla/literal.h.
%foreign (ffi "Literal_delete")
prim__delete : AnyPtr -> PrimIO ()
-- Foreign binding to `Literal_new`. Takes the GC-managed shape pointer and
-- returns a raw, unmanaged `AnyPtr` to the new literal. Use `allocLiteral`,
-- which wraps the result in a `GCAnyPtr`, rather than calling this directly.
%foreign (ffi "Literal_new")
prim__allocLiteral : GCAnyPtr -> PrimIO AnyPtr
-- Allocate a native literal with the given shape. The raw pointer from
-- `Literal_new` is wrapped via `onCollectAny'`, with `prim__delete` as the
-- callback (presumably run when the handle is collected; `onCollectAny'` is
-- defined elsewhere).
allocLiteral : HasIO io => Xla.Shape -> io Literal
allocLiteral (MkShape shapePtr) = do
  raw <- primIO (prim__allocLiteral shapePtr)
  MkLiteral <$> onCollectAny' raw (primIO . prim__delete)
-- Foreign binding to `Literal_size_bytes`. It is declared pure (no `PrimIO`).
-- Presumably it reports the byte size of the literal's data buffer; confirm
-- against c/xla/literal.h.
%foreign (ffi "Literal_size_bytes")
prim__literalSizeBytes : GCAnyPtr -> Int64
-- Foreign binding to `Literal_untyped_data`. It is declared pure and returns
-- a plain `AnyPtr` (not a `GCAnyPtr`), so the result does not keep the
-- literal alive.
-- NOTE(review): this presumably points into the literal's own buffer and is
-- only valid while the owning `GCAnyPtr` is reachable; confirm.
%foreign (ffi "Literal_untyped_data")
prim__literalUntypedData : GCAnyPtr -> AnyPtr
-- Foreign binding to `Literal_Set_bool`. Arguments, in order: the literal,
-- the index buffer, the number of indices, and the value encoded as a C int
-- (the `Bool` `set` wrapper converts it with `boolToCInt`).
%foreign (ffi "Literal_Set_bool")
prim__literalSetBool : GCAnyPtr -> GCPtr Int64 -> Bits64 -> Int -> PrimIO ()
-- Write a `Bool` into the literal at the element addressed by the index
-- array. The value crosses the FFI boundary as a C int.
set : Literal -> Int64Array -> Bool -> IO ()
set (MkLiteral ptr) (MkInt64Array ixs nIxs) flag =
  let cFlag = boolToCInt flag in primIO (prim__literalSetBool ptr ixs nIxs cFlag)
-- Foreign binding to `Literal_Get_bool`. It returns the element as a C int,
-- which the `Bool` `get` wrapper decodes with `cIntToBool`.
-- NOTE(review): this is declared pure even though it reads memory that the
-- `PrimIO` setters mutate; confirm such calls cannot be shared or reordered
-- across a `set`.
%foreign (ffi "Literal_Get_bool")
prim__literalGetBool : GCAnyPtr -> GCPtr Int64 -> Bits64 -> Int
-- Read the `Bool` stored at the element addressed by the index array.
-- NOTE(review): this has a pure signature over mutable native memory.
get : Literal -> Int64Array -> Bool
get (MkLiteral ptr) (MkInt64Array ixs nIxs) =
  let raw = prim__literalGetBool ptr ixs nIxs in cIntToBool raw
-- Foreign binding to `Literal_Set_double`. Arguments, in order: the literal,
-- the index buffer, the number of indices, and the value.
%foreign (ffi "Literal_Set_double")
prim__literalSetDouble : GCAnyPtr -> GCPtr Int64 -> Bits64 -> Double -> PrimIO ()
-- Write a `Double` into the literal at the element addressed by the index
-- array.
set : Literal -> Int64Array -> Double -> IO ()
set lit idx x =
  case (lit, idx) of
    (MkLiteral ptr, MkInt64Array ixs nIxs) =>
      primIO (prim__literalSetDouble ptr ixs nIxs x)
-- Foreign binding to `Literal_Get_double`. Arguments, in order: the literal,
-- the index buffer, and the number of indices.
-- NOTE(review): this is declared pure although the setters mutate the same
-- memory via `PrimIO`.
%foreign (ffi "Literal_Get_double")
prim__literalGetDouble : GCAnyPtr -> GCPtr Int64 -> Bits64 -> Double
-- Read the `Double` stored at the element addressed by the index array.
-- NOTE(review): this has a pure signature over mutable native memory.
get : Literal -> Int64Array -> Double
get lit idx =
  case (lit, idx) of
    (MkLiteral ptr, MkInt64Array ixs nIxs) => prim__literalGetDouble ptr ixs nIxs
-- Foreign binding to `Literal_Set_int32_t`. Arguments, in order: the
-- literal, the index buffer, the number of indices, and the value.
%foreign (ffi "Literal_Set_int32_t")
prim__literalSetInt32t : GCAnyPtr -> GCPtr Int64 -> Bits64 -> Int32 -> PrimIO ()
-- Write an `Int32` into the literal at the element addressed by the index
-- array.
set : Literal -> Int64Array -> Int32 -> IO ()
set (MkLiteral ptr) (MkInt64Array ixs nIxs) v =
  primIO (prim__literalSetInt32t ptr ixs nIxs v)
-- Foreign binding to `Literal_Get_int32_t`. Arguments, in order: the
-- literal, the index buffer, and the number of indices.
-- NOTE(review): this is declared pure although the setters mutate the same
-- memory via `PrimIO`.
%foreign (ffi "Literal_Get_int32_t")
prim__literalGetInt32t : GCAnyPtr -> GCPtr Int64 -> Bits64 -> Int32
-- Read the `Int32` stored at the element addressed by the index array.
-- NOTE(review): this has a pure signature over mutable native memory.
get : Literal -> Int64Array -> Int32
get (MkLiteral ptr) (MkInt64Array ixs nIxs) =
  prim__literalGetInt32t ptr ixs nIxs
-- Foreign binding to `Literal_Set_int64_t`. Arguments, in order: the
-- literal, the index buffer, the number of indices, and the value.
%foreign (ffi "Literal_Set_int64_t")
prim__literalSetInt64t : GCAnyPtr -> GCPtr Int64 -> Bits64 -> Int64 -> PrimIO ()
-- Write an `Int64` into the literal at the element addressed by the index
-- array.
set : Literal -> Int64Array -> Int64 -> IO ()
set (MkLiteral ptr) (MkInt64Array ixs nIxs) v =
  let action = prim__literalSetInt64t ptr ixs nIxs v in primIO action
-- Foreign binding to `Literal_Get_int64_t`. Arguments, in order: the
-- literal, the index buffer, and the number of indices.
-- NOTE(review): this is declared pure although the setters mutate the same
-- memory via `PrimIO`.
%foreign (ffi "Literal_Get_int64_t")
prim__literalGetInt64t : GCAnyPtr -> GCPtr Int64 -> Bits64 -> Int64
-- Read the `Int64` stored at the element addressed by the index array.
-- NOTE(review): this has a pure signature over mutable native memory.
get : Literal -> Int64Array -> Int64
get (MkLiteral ptr) (MkInt64Array ixs nIxs) =
  let v = prim__literalGetInt64t ptr ixs nIxs in v
-- Foreign binding to `Literal_Set_uint32_t`. Arguments, in order: the
-- literal, the index buffer, the number of indices, and the value (`Bits32`
-- is used for the unsigned 32-bit element type).
%foreign (ffi "Literal_Set_uint32_t")
prim__literalSetUInt32t : GCAnyPtr -> GCPtr Int64 -> Bits64 -> Bits32 -> PrimIO ()
-- Write an unsigned 32-bit value (`Bits32`) into the literal at the element
-- addressed by the index array.
set : Literal -> Int64Array -> Bits32 -> IO ()
set lit idx w =
  case (lit, idx) of
    (MkLiteral ptr, MkInt64Array ixs nIxs) =>
      primIO (prim__literalSetUInt32t ptr ixs nIxs w)
-- Foreign binding to `Literal_Get_uint32_t`. Arguments, in order: the
-- literal, the index buffer, and the number of indices.
-- NOTE(review): this is declared pure although the setters mutate the same
-- memory via `PrimIO`.
%foreign (ffi "Literal_Get_uint32_t")
prim__literalGetUInt32t : GCAnyPtr -> GCPtr Int64 -> Bits64 -> Bits32
-- Read the unsigned 32-bit value (`Bits32`) stored at the element addressed
-- by the index array.
-- NOTE(review): this has a pure signature over mutable native memory.
get : Literal -> Int64Array -> Bits32
get lit idx =
  case (lit, idx) of
    (MkLiteral ptr, MkInt64Array ixs nIxs) => prim__literalGetUInt32t ptr ixs nIxs
-- Foreign binding to `Literal_Set_uint64_t`. Arguments, in order: the
-- literal, the index buffer, the number of indices, and the value (`Bits64`
-- is used for the unsigned 64-bit element type).
%foreign (ffi "Literal_Set_uint64_t")
prim__literalSetUInt64t : GCAnyPtr -> GCPtr Int64 -> Bits64 -> Bits64 -> PrimIO ()
-- Write an unsigned 64-bit value (`Bits64`) into the literal at the element
-- addressed by the index array.
set : Literal -> Int64Array -> Bits64 -> IO ()
set (MkLiteral ptr) (MkInt64Array ixs nIxs) w =
  primIO (prim__literalSetUInt64t ptr ixs nIxs w)
-- Foreign binding to `Literal_Get_uint64_t`. Arguments, in order: the
-- literal, the index buffer, and the number of indices.
-- NOTE(review): this is declared pure although the setters mutate the same
-- memory via `PrimIO`.
%foreign (ffi "Literal_Get_uint64_t")
prim__literalGetUInt64t : GCAnyPtr -> GCPtr Int64 -> Bits64 -> Bits64
-- Read the unsigned 64-bit value (`Bits64`) stored at the element addressed
-- by the index array.
-- NOTE(review): this has a pure signature over mutable native memory.
get : Literal -> Int64Array -> Bits64
get (MkLiteral ptr) (MkInt64Array ixs nIxs) =
  prim__literalGetUInt64t ptr ixs nIxs