17 | module Compiler.Xla.Literal
19 | import Compiler.Xla.Shape
20 | import Compiler.Xla.ShapeUtil
21 | import Compiler.Xla.XlaData
-- Turns a C symbol name into the foreign-function spec used by the `%foreign`
-- pragmas in this module, by applying `libxla` to the header path
-- "c/xla/literal.h" and then to the symbol name.
ffi : String -> String
ffi symbol = libxla "c/xla/literal.h" symbol
||| An XLA literal: a wrapper around a garbage-collected pointer (`GCAnyPtr`)
||| to the underlying native object.
data Literal = MkLiteral GCAnyPtr
-- Binding to the C symbol `Literal_delete`. It takes a raw (untracked)
-- pointer; `allocLiteral` installs it as the collection callback for
-- literals it allocates.
%foreign (ffi "Literal_delete")
prim__delete :
     AnyPtr
  -> PrimIO ()
-- Binding to the C symbol `Literal_new`. It takes the pointer held by a shape
-- and returns a raw `AnyPtr` that is not yet tied to the garbage collector.
%foreign (ffi "Literal_new")
prim__allocLiteral : GCAnyPtr -> PrimIO AnyPtr

||| Allocate a new literal for the given shape.
|||
||| The raw pointer returned by `Literal_new` is wrapped with `onCollectAny'`,
||| using `prim__delete` as the collection callback, before being packed into
||| a `Literal`.
allocLiteral : HasIO io => Xla.Shape -> io Literal
allocLiteral (MkShape shape) = do
  rawPtr <- primIO (prim__allocLiteral shape)
  MkLiteral <$> onCollectAny' rawPtr (primIO . prim__delete)
-- Binding to the C symbol `Literal_size_bytes`. It is declared pure (no
-- `PrimIO`), so it can be called from non-IO code. The result is an `Int`,
-- presumably a size in bytes per the symbol name; TODO confirm against
-- c/xla/literal.h.
%foreign (ffi "Literal_size_bytes")
prim__literalSizeBytes :
     GCAnyPtr
  -> Int
-- Binding to the C symbol `Literal_untyped_data`. It is declared pure.
-- NOTE(review): the result is a plain `AnyPtr` with no GC link to the
-- `GCAnyPtr` argument. It presumably points into the literal's own storage
-- and so dangles once that literal is collected. Keep the literal reachable
-- for as long as this pointer is used.
%foreign (ffi "Literal_untyped_data")
prim__literalUntypedData :
     GCAnyPtr
  -> AnyPtr
-- Binding to the C symbol `Literal_Set_bool`. Arguments: the literal, the
-- packed index array, the number of indices, and the value encoded as an `Int`.
%foreign (ffi "Literal_Set_bool")
prim__literalSetBool : GCAnyPtr -> GCPtr Int -> Int -> Int -> PrimIO ()

||| Write a boolean `value` at the multi-dimensional index `idxs` of a literal.
|||
||| The indices are packed with `mkIntArray` and their count is passed
||| alongside them. The boolean is encoded with `boolToCInt`.
set : Literal -> List Nat -> Bool -> IO ()
set (MkLiteral lit) idxs value = do
  MkIntArray idxPtr <- mkIntArray idxs
  let count : Int = cast (length idxs)
  primIO (prim__literalSetBool lit idxPtr count (boolToCInt value))
-- Binding to the C symbol `Literal_Get_bool`. It is declared pure and returns
-- the element encoded as an `Int`.
%foreign (ffi "Literal_Get_bool")
literalGetBool : GCAnyPtr -> GCPtr Int -> Int -> Int

||| Read the boolean element at the multi-dimensional index `idxs` of a literal.
|||
||| Packing the indices with `mkIntArray` is effectful, so this runs under
||| `unsafePerformIO`. The foreign getter itself is declared pure, and its
||| result is decoded with `cIntToBool`.
get : Literal -> List Nat -> Bool
get (MkLiteral lit) idxs = unsafePerformIO $ do
  MkIntArray idxPtr <- mkIntArray idxs
  let count : Int = cast (length idxs)
  pure (cIntToBool (literalGetBool lit idxPtr count))
-- Binding to the C symbol `Literal_Set_double`. Arguments: the literal, the
-- packed index array, the number of indices, and the value.
%foreign (ffi "Literal_Set_double")
prim__literalSetDouble : GCAnyPtr -> GCPtr Int -> Int -> Double -> PrimIO ()

||| Write a `Double` `value` at the multi-dimensional index `idxs` of a literal.
|||
||| The indices are packed with `mkIntArray` and their count is passed
||| alongside them.
set : Literal -> List Nat -> Double -> IO ()
set (MkLiteral lit) idxs value = do
  MkIntArray idxPtr <- mkIntArray idxs
  let count : Int = cast (length idxs)
  primIO (prim__literalSetDouble lit idxPtr count value)
-- Binding to the C symbol `Literal_Get_double`. It is declared pure.
%foreign (ffi "Literal_Get_double")
literalGetDouble : GCAnyPtr -> GCPtr Int -> Int -> Double

||| Read the `Double` element at the multi-dimensional index `idxs` of a literal.
|||
||| Packing the indices with `mkIntArray` is effectful, so this runs under
||| `unsafePerformIO`. The foreign getter itself is declared pure.
get : Literal -> List Nat -> Double
get (MkLiteral lit) idxs = unsafePerformIO $ do
  MkIntArray idxPtr <- mkIntArray idxs
  let count : Int = cast (length idxs)
  pure (literalGetDouble lit idxPtr count)
-- Binding to the C symbol `Literal_Set_int32_t`. Arguments: the literal, the
-- packed index array, the number of indices, and the value.
%foreign (ffi "Literal_Set_int32_t")
prim__literalSetInt32t : GCAnyPtr -> GCPtr Int -> Int -> Int32 -> PrimIO ()

||| Write an `Int32` `value` at the multi-dimensional index `idxs` of a literal.
|||
||| The indices are packed with `mkIntArray` and their count is passed
||| alongside them.
set : Literal -> List Nat -> Int32 -> IO ()
set (MkLiteral lit) idxs value = do
  MkIntArray idxPtr <- mkIntArray idxs
  let count : Int = cast (length idxs)
  primIO (prim__literalSetInt32t lit idxPtr count value)
-- Binding to the C symbol `Literal_Get_int32_t`. It is declared pure and
-- already returns an `Int32`.
%foreign (ffi "Literal_Get_int32_t")
literalGetInt32t : GCAnyPtr -> GCPtr Int -> Int -> Int32

||| Read the `Int32` element at the multi-dimensional index `idxs` of a literal.
|||
||| Packing the indices with `mkIntArray` is effectful, so this runs under
||| `unsafePerformIO`. The foreign getter itself is declared pure.
get : Literal -> List Nat -> Int32
get (MkLiteral lit) idxs = unsafePerformIO $ do
  MkIntArray idxsArrayPtr <- mkIntArray idxs
  -- The getter's result type already matches `Int32`, so no `cast` is applied
  -- (the previous Int32 -> Int32 cast was a no-op).
  pure $ literalGetInt32t lit idxsArrayPtr (cast $ length idxs)
-- Binding to the C symbol `Literal_Set_uint32_t`. Arguments: the literal, the
-- packed index array, the number of indices, and the value as `Bits32`.
%foreign (ffi "Literal_Set_uint32_t")
prim__literalSetUInt32t : GCAnyPtr -> GCPtr Int -> Int -> Bits32 -> PrimIO ()

||| Write a natural-number `value` at the multi-dimensional index `idxs` of a
||| literal, stored through the `uint32_t` setter.
|||
||| NOTE(review): `value` is narrowed with `cast` from `Nat` to `Bits32`.
||| Values above 2^32 - 1 are presumably wrapped rather than rejected; confirm
||| whether callers guarantee the range.
set : Literal -> List Nat -> Nat -> IO ()
set (MkLiteral lit) idxs value = do
  MkIntArray idxPtr <- mkIntArray idxs
  let count : Int = cast (length idxs)
      narrowed : Bits32 = cast value
  primIO (prim__literalSetUInt32t lit idxPtr count narrowed)
-- Binding to the C symbol `Literal_Get_uint32_t`. It is declared pure and
-- returns a `Bits32`.
%foreign (ffi "Literal_Get_uint32_t")
literalGetUInt32t : GCAnyPtr -> GCPtr Int -> Int -> Bits32

||| Read the element at the multi-dimensional index `idxs` of a literal through
||| the `uint32_t` getter, widened from `Bits32` to `Nat`.
|||
||| Packing the indices with `mkIntArray` is effectful, so this runs under
||| `unsafePerformIO`. The foreign getter itself is declared pure.
get : Literal -> List Nat -> Nat
get (MkLiteral lit) idxs = unsafePerformIO $ do
  MkIntArray idxPtr <- mkIntArray idxs
  let count : Int = cast (length idxs)
      raw : Bits32 = literalGetUInt32t lit idxPtr count
  pure (cast raw)
-- Binding to the C symbol `Literal_Set_uint64_t`. Arguments: the literal, the
-- packed index array, the number of indices, and the value as `Bits64`.
%foreign (ffi "Literal_Set_uint64_t")
prim__literalSetUInt64t : GCAnyPtr -> GCPtr Int -> Int -> Bits64 -> PrimIO ()

||| Write a natural-number `value` at the multi-dimensional index `idxs` of a
||| literal, stored through the `uint64_t` setter.
|||
||| NOTE(review): `value` is narrowed with `cast` from `Nat` to `Bits64`.
||| Values above 2^64 - 1 are presumably wrapped rather than rejected; confirm
||| whether callers guarantee the range.
set : Literal -> List Nat -> Nat -> IO ()
set (MkLiteral lit) idxs value = do
  MkIntArray idxPtr <- mkIntArray idxs
  let count : Int = cast (length idxs)
      narrowed : Bits64 = cast value
  primIO (prim__literalSetUInt64t lit idxPtr count narrowed)
-- Binding to the C symbol `Literal_Get_uint64_t`. It is declared pure and
-- returns a `Bits64`.
%foreign (ffi "Literal_Get_uint64_t")
literalGetUInt64t : GCAnyPtr -> GCPtr Int -> Int -> Bits64

||| Read the element at the multi-dimensional index `idxs` of a literal through
||| the `uint64_t` getter, widened from `Bits64` to `Nat`.
|||
||| Packing the indices with `mkIntArray` is effectful, so this runs under
||| `unsafePerformIO`. The foreign getter itself is declared pure.
get : Literal -> List Nat -> Nat
get (MkLiteral lit) idxs = unsafePerformIO $ do
  MkIntArray idxPtr <- mkIntArray idxs
  let count : Int = cast (length idxs)
      raw : Bits64 = literalGetUInt64t lit idxPtr count
  pure (cast raw)