0 | {--
  1 | Copyright (C) 2022  Joel Berkeley
  2 |
  3 | This program is free software: you can redistribute it and/or modify
  4 | it under the terms of the GNU Affero General Public License as published
  5 | by the Free Software Foundation, either version 3 of the License, or
  6 | (at your option) any later version.
  7 |
  8 | This program is distributed in the hope that it will be useful,
  9 | but WITHOUT ANY WARRANTY; without even the implied warranty of
 10 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 11 | GNU Affero General Public License for more details.
 12 |
 13 | You should have received a copy of the GNU Affero General Public License
 14 | along with this program.  If not, see <https://www.gnu.org/licenses/>.
 15 | --}
 16 | ||| For internal spidr use only.
 17 | module Compiler.Xla.Literal
 18 |
 19 | import Compiler.Xla.Shape
 20 | import Compiler.Xla.ShapeUtil
 21 | import Compiler.Xla.XlaData
 22 | import Compiler.FFI
 23 | import Types
 24 |
 25 | ffi : String -> String
 26 | ffi = libxla "c/xla/literal.h"
 27 |
 28 | namespace Xla
 29 |   public export
 30 |   data Literal : Type where
 31 |     MkLiteral : GCAnyPtr -> Literal
 32 |
 33 | %foreign (ffi "Literal_delete")
 34 | prim__delete : AnyPtr -> PrimIO ()
 35 |
 36 | %foreign (ffi "Literal_new")
 37 | prim__allocLiteral : GCAnyPtr -> PrimIO AnyPtr
 38 |
 39 | export
 40 | allocLiteral : HasIO io => Xla.Shape -> io Literal
 41 | allocLiteral (MkShape shape) = do
 42 |   litPtr <- primIO $ prim__allocLiteral shape
 43 |   litPtr <- onCollectAny' litPtr (primIO . prim__delete)
 44 |   pure (MkLiteral litPtr)
 45 |
 46 | export
 47 | %foreign (ffi "Literal_size_bytes")
 48 | prim__literalSizeBytes : GCAnyPtr -> Int64
 49 |
 50 | export
 51 | %foreign (ffi "Literal_untyped_data")
 52 | prim__literalUntypedData : GCAnyPtr -> AnyPtr
 53 |
 54 | namespace Bool
 55 |   %foreign (ffi "Literal_Set_bool")
 56 |   prim__literalSetBool : GCAnyPtr -> GCPtr Int64 -> Bits64 -> Int -> PrimIO ()
 57 |
 58 |   export
 59 |   set : Literal -> Int64Array -> Bool -> IO ()
 60 |   set (MkLiteral lit) (MkInt64Array idxs idxsLen) value =
 61 |     primIO $ prim__literalSetBool lit idxs idxsLen (boolToCInt value)
 62 |
 63 |   %foreign (ffi "Literal_Get_bool")
 64 |   prim__literalGetBool : GCAnyPtr -> GCPtr Int64 -> Bits64 -> Int
 65 |
 66 |   export
 67 |   get : Literal -> Int64Array -> Bool
 68 |   get (MkLiteral lit) (MkInt64Array idxs idxsLen) =
 69 |     cIntToBool $ prim__literalGetBool lit idxs idxsLen
 70 |
 71 | namespace Double
 72 |   %foreign (ffi "Literal_Set_double")
 73 |   prim__literalSetDouble : GCAnyPtr -> GCPtr Int64 -> Bits64 -> Double -> PrimIO ()
 74 |
 75 |   export
 76 |   set : Literal -> Int64Array -> Double -> IO ()
 77 |   set (MkLiteral lit) (MkInt64Array idxs idxsLen) value =
 78 |     primIO $ prim__literalSetDouble lit idxs idxsLen value
 79 |
 80 |   %foreign (ffi "Literal_Get_double")
 81 |   prim__literalGetDouble : GCAnyPtr -> GCPtr Int64 -> Bits64 -> Double
 82 |
 83 |   export
 84 |   get : Literal -> Int64Array -> Double
 85 |   get (MkLiteral lit) (MkInt64Array idxs idxsLen) = prim__literalGetDouble lit idxs idxsLen
 86 |
 87 | namespace Int32t
 88 |   %foreign (ffi "Literal_Set_int32_t")
 89 |   prim__literalSetInt32t : GCAnyPtr -> GCPtr Int64 -> Bits64 -> Int32 -> PrimIO ()
 90 |
 91 |   export
 92 |   set : Literal -> Int64Array -> Int32 -> IO ()
 93 |   set (MkLiteral lit) (MkInt64Array idxs idxsLen) value =
 94 |     primIO $ prim__literalSetInt32t lit idxs idxsLen value
 95 |
 96 |   %foreign (ffi "Literal_Get_int32_t")
 97 |   prim__literalGetInt32t : GCAnyPtr -> GCPtr Int64 -> Bits64 -> Int32
 98 |
 99 |   export
100 |   get : Literal -> Int64Array -> Int32
101 |   get (MkLiteral lit) (MkInt64Array idxs idxsLen) = prim__literalGetInt32t lit idxs idxsLen
102 |
103 | namespace Int64t
104 |   %foreign (ffi "Literal_Set_int64_t")
105 |   prim__literalSetInt64t : GCAnyPtr -> GCPtr Int64 -> Bits64 -> Int64 -> PrimIO ()
106 |
107 |   export
108 |   set : Literal -> Int64Array -> Int64 -> IO ()
109 |   set (MkLiteral lit) (MkInt64Array idxs idxsLen) value =
110 |     primIO $ prim__literalSetInt64t lit idxs idxsLen value
111 |
112 |   %foreign (ffi "Literal_Get_int64_t")
113 |   prim__literalGetInt64t : GCAnyPtr -> GCPtr Int64 -> Bits64 -> Int64
114 |
115 |   export
116 |   get : Literal -> Int64Array -> Int64
117 |   get (MkLiteral lit) (MkInt64Array idxs idxsLen) = prim__literalGetInt64t lit idxs idxsLen
118 |
119 | namespace UInt32t
120 |   %foreign (ffi "Literal_Set_uint32_t")
121 |   prim__literalSetUInt32t : GCAnyPtr -> GCPtr Int64 -> Bits64 -> Bits32 -> PrimIO ()
122 |
123 |   export
124 |   set : Literal -> Int64Array -> Bits32 -> IO ()
125 |   set (MkLiteral lit) (MkInt64Array idxs idxsLen) value =
126 |     primIO $ prim__literalSetUInt32t lit idxs idxsLen value
127 |
128 |   %foreign (ffi "Literal_Get_uint32_t")
129 |   prim__literalGetUInt32t : GCAnyPtr -> GCPtr Int64 -> Bits64 -> Bits32
130 |
131 |   export
132 |   get : Literal -> Int64Array -> Bits32
133 |   get (MkLiteral lit) (MkInt64Array idxs idxsLen) = prim__literalGetUInt32t lit idxs idxsLen
134 |
135 | namespace UInt64t
136 |   %foreign (ffi "Literal_Set_uint64_t")
137 |   prim__literalSetUInt64t : GCAnyPtr -> GCPtr Int64 -> Bits64 -> Bits64 -> PrimIO ()
138 |
139 |   export
140 |   set : Literal -> Int64Array -> Bits64 -> IO ()
141 |   set (MkLiteral lit) (MkInt64Array idxs idxsLen) value =
142 |     primIO $ prim__literalSetUInt64t lit idxs idxsLen value
143 |
144 |   %foreign (ffi "Literal_Get_uint64_t")
145 |   prim__literalGetUInt64t : GCAnyPtr -> GCPtr Int64 -> Bits64 -> Bits64
146 |
147 |   export
148 |   get : Literal -> Int64Array -> Bits64
149 |   get (MkLiteral lit) (MkInt64Array idxs idxsLen) = prim__literalGetUInt64t lit idxs idxsLen
150 |