0 | {--
  1 | Copyright (C) 2022  Joel Berkeley
  2 |
  3 | This program is free software: you can redistribute it and/or modify
  4 | it under the terms of the GNU Affero General Public License as published
  5 | by the Free Software Foundation, either version 3 of the License, or
  6 | (at your option) any later version.
  7 |
  8 | This program is distributed in the hope that it will be useful,
  9 | but WITHOUT ANY WARRANTY; without even the implied warranty of
 10 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 11 | GNU Affero General Public License for more details.
 12 |
 13 | You should have received a copy of the GNU Affero General Public License
 14 | along with this program.  If not, see <https://www.gnu.org/licenses/>.
 15 | --}
 16 | ||| For internal spidr use only.
 17 | module Compiler.Xla.Literal
 18 |
 19 | import Compiler.Xla.Shape
 20 | import Compiler.Xla.ShapeUtil
 21 | import Compiler.Xla.XlaData
 22 | import Compiler.FFI
 23 | import Types
 24 |
 25 | ffi : String -> String
 26 | ffi = libxla "c/xla/literal.h"
 27 |
 28 | namespace Xla
 29 |   public export
 30 |   data Literal : Type where
 31 |     MkLiteral : GCAnyPtr -> Literal
 32 |
 33 | %foreign (ffi "Literal_delete")
 34 | prim__delete : AnyPtr -> PrimIO ()
 35 |
 36 | %foreign (ffi "Literal_new")
 37 | prim__allocLiteral : GCAnyPtr -> PrimIO AnyPtr
 38 |
 39 | export
 40 | allocLiteral : HasIO io => Xla.Shape -> io Literal
 41 | allocLiteral (MkShape shape) = do
 42 |   litPtr <- primIO $ prim__allocLiteral shape
 43 |   litPtr <- onCollectAny' litPtr (primIO . prim__delete)
 44 |   pure (MkLiteral litPtr)
 45 |
 46 | export
 47 | %foreign (ffi "Literal_size_bytes")
 48 | prim__literalSizeBytes : GCAnyPtr -> Int
 49 |
 50 | export
 51 | %foreign (ffi "Literal_untyped_data")
 52 | prim__literalUntypedData : GCAnyPtr -> AnyPtr
 53 |
 54 | namespace Bool
 55 |   %foreign (ffi "Literal_Set_bool")
 56 |   prim__literalSetBool : GCAnyPtr -> GCPtr Int -> Int -> Int -> PrimIO ()
 57 |
 58 |   export
 59 |   set : Literal -> List Nat -> Bool -> IO ()
 60 |   set (MkLiteral lit) idxs value = do
 61 |     MkIntArray idxsArrayPtr <- mkIntArray idxs
 62 |     primIO $ prim__literalSetBool lit idxsArrayPtr (cast $ length idxs) (boolToCInt value)
 63 |
 64 |   %foreign (ffi "Literal_Get_bool")
 65 |   literalGetBool : GCAnyPtr -> GCPtr Int -> Int -> Int
 66 |
 67 |   export
 68 |   get : Literal -> List Nat -> Bool
 69 |   get (MkLiteral lit) idxs = unsafePerformIO $ do
 70 |     MkIntArray idxsArrayPtr <- mkIntArray idxs
 71 |     pure $ cIntToBool $ literalGetBool lit idxsArrayPtr (cast $ length idxs)
 72 |
 73 | namespace Double
 74 |   %foreign (ffi "Literal_Set_double")
 75 |   prim__literalSetDouble : GCAnyPtr -> GCPtr Int -> Int -> Double -> PrimIO ()
 76 |
 77 |   export
 78 |   set : Literal -> List Nat -> Double -> IO ()
 79 |   set (MkLiteral lit) idxs value = do
 80 |     MkIntArray idxsArrayPtr <- mkIntArray idxs
 81 |     primIO $ prim__literalSetDouble lit idxsArrayPtr (cast $ length idxs) value
 82 |
 83 |   %foreign (ffi "Literal_Get_double")
 84 |   literalGetDouble : GCAnyPtr -> GCPtr Int -> Int -> Double
 85 |
 86 |   export
 87 |   get : Literal -> List Nat -> Double
 88 |   get (MkLiteral lit) idxs = unsafePerformIO $ do
 89 |     MkIntArray idxsArrayPtr <- mkIntArray idxs
 90 |     pure $ literalGetDouble lit idxsArrayPtr (cast $ length idxs)
 91 |
 92 | namespace Int32t
 93 |   %foreign (ffi "Literal_Set_int32_t")
 94 |   prim__literalSetInt32t : GCAnyPtr -> GCPtr Int -> Int -> Int32 -> PrimIO ()
 95 |
 96 |   export
 97 |   set : Literal -> List Nat -> Int32 -> IO ()
 98 |   set (MkLiteral lit) idxs value = do
 99 |     MkIntArray idxsArrayPtr <- mkIntArray idxs
100 |     primIO $ prim__literalSetInt32t lit idxsArrayPtr (cast $ length idxs) value
101 |
102 |   %foreign (ffi "Literal_Get_int32_t")
103 |   literalGetInt32t : GCAnyPtr -> GCPtr Int -> Int -> Int32
104 |
105 |   export
106 |   get : Literal -> List Nat -> Int32
107 |   get (MkLiteral lit) idxs = unsafePerformIO $ do
108 |     MkIntArray idxsArrayPtr <- mkIntArray idxs
109 |     pure $ cast $ literalGetInt32t lit idxsArrayPtr (cast $ length idxs)
110 |
111 | namespace UInt32t
112 |   %foreign (ffi "Literal_Set_uint32_t")
113 |   prim__literalSetUInt32t : GCAnyPtr -> GCPtr Int -> Int -> Bits32 -> PrimIO ()
114 |
115 |   export
116 |   set : Literal -> List Nat -> Nat -> IO ()
117 |   set (MkLiteral lit) idxs value = do
118 |     MkIntArray idxsArrayPtr <- mkIntArray idxs
119 |     primIO $ prim__literalSetUInt32t lit idxsArrayPtr (cast $ length idxs) (cast value)
120 |
121 |   %foreign (ffi "Literal_Get_uint32_t")
122 |   literalGetUInt32t : GCAnyPtr -> GCPtr Int -> Int -> Bits32
123 |
124 |   export
125 |   get : Literal -> List Nat -> Nat
126 |   get (MkLiteral lit) idxs = unsafePerformIO $ do
127 |     MkIntArray idxsArrayPtr <- mkIntArray idxs
128 |     pure $ cast $ literalGetUInt32t lit idxsArrayPtr (cast $ length idxs)
129 |
130 | namespace UInt64t
131 |   %foreign (ffi "Literal_Set_uint64_t")
132 |   prim__literalSetUInt64t : GCAnyPtr -> GCPtr Int -> Int -> Bits64 -> PrimIO ()
133 |
134 |   export
135 |   set : Literal -> List Nat -> Nat -> IO ()
136 |   set (MkLiteral lit) idxs value = do
137 |     MkIntArray idxsArrayPtr <- mkIntArray idxs
138 |     primIO $ prim__literalSetUInt64t lit idxsArrayPtr (cast $ length idxs) (cast value)
139 |
140 |   %foreign (ffi "Literal_Get_uint64_t")
141 |   literalGetUInt64t : GCAnyPtr -> GCPtr Int -> Int -> Bits64
142 |
143 |   export
144 |   get : Literal -> List Nat -> Nat
145 |   get (MkLiteral lit) idxs = unsafePerformIO $ do
146 |     MkIntArray idxsArrayPtr <- mkIntArray idxs
147 |     pure $ cast $ literalGetUInt64t lit idxsArrayPtr (cast $ length idxs)
148 |