0 | {--
  1 | Copyright (C) 2022  Joel Berkeley
  2 |
  3 | This program is free software: you can redistribute it and/or modify
  4 | it under the terms of the GNU Affero General Public License as published
  5 | by the Free Software Foundation, either version 3 of the License, or
  6 | (at your option) any later version.
  7 |
  8 | This program is distributed in the hope that it will be useful,
  9 | but WITHOUT ANY WARRANTY; without even the implied warranty of
 10 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 11 | GNU Affero General Public License for more details.
 12 |
 13 | You should have received a copy of the GNU Affero General Public License
 14 | along with this program.  If not, see <https://www.gnu.org/licenses/>.
 15 | --}
 16 | ||| For internal spidr use only.
 17 | module Compiler.FFI
 18 |
 19 | import public System.FFI
 20 | import Util
 21 |
 22 | public export
 23 | libxla : String -> String -> String
 24 | libxla header fname = "C:\{fname},libc_xla,\{header}"
 25 |
 26 | ffi : String -> String
 27 | ffi = libxla "c/ffi.h"
 28 |
 29 | public export
 30 | data CppString = MkCppString AnyPtr
 31 |
 32 | export
 33 | onCollectAny' : HasIO io => AnyPtr -> (AnyPtr -> IO ()) -> io GCAnyPtr
 34 | onCollectAny' x f = onCollectAny x f
 35 |
 36 | export
 37 | %foreign (ffi "string_new")
 38 | prim__mkString : PrimIO AnyPtr
 39 |
 40 | ||| It is up to the caller to `delete` the string.
 41 | export
 42 | cppString : HasIO io => io CppString
 43 | cppString = MkCppString <$> primIO prim__mkString
 44 |
 45 | export
 46 | %foreign (ffi "string_delete")
 47 | prim__stringDelete : AnyPtr -> PrimIO ()
 48 |
 49 | namespace CppString
 50 |   export
 51 |   delete : HasIO io => CppString -> io ()
 52 |   delete (MkCppString str) = primIO $ prim__stringDelete str
 53 |
 54 | %foreign (ffi "string_c_str")
 55 | prim__stringCStr : AnyPtr -> String
 56 |
 57 | export
 58 | c_str : CppString -> String
 59 | c_str (MkCppString str) = prim__stringCStr str
 60 |
 61 | export
 62 | %foreign (ffi "string_data")
 63 | prim__stringData : AnyPtr -> Ptr Char
 64 |
 65 | export
 66 | %foreign (ffi "string_size")
 67 | prim__stringSize : AnyPtr -> Bits64
 68 |
 69 | export
 70 | %foreign (ffi "idx")
 71 | prim__index : Int -> AnyPtr -> AnyPtr
 72 |
 73 | export
 74 | cIntToBool : Int -> Bool
 75 | cIntToBool 0 = False
 76 | cIntToBool 1 = True
 77 | cIntToBool x =
 78 |   let msg = "Internal error: expected 0 or 1 from XLA C API for boolean conversion, got " ++ show x
 79 |   in (assert_total idris_crash) msg
 80 |
 81 | %foreign (ffi "isnull")
 82 | prim__isNullPtr : AnyPtr -> Int
 83 |
 84 | export
 85 | isNullPtr : AnyPtr -> Bool
 86 | isNullPtr ptr = cIntToBool $ prim__isNullPtr ptr
 87 |
 88 | export
 89 | boolToCInt : Bool -> Int
 90 | boolToCInt True = 1
 91 | boolToCInt False = 0
 92 |
 93 | public export
 94 | data IntArray : Type where
 95 |   MkIntArray : GCPtr Int -> IntArray
 96 |
 97 | %foreign (ffi "sizeof_int")
 98 | sizeofInt : Bits64
 99 |
100 | %foreign (ffi "set_array_int")
101 | prim__setArrayInt : Ptr Int -> Int -> Int -> PrimIO ()
102 |
103 | export
104 | mkIntArray : (HasIO io, Cast a Int) => List a -> io IntArray
105 | mkIntArray xs = do
106 |   ptr <- malloc (cast (length xs) * cast sizeofInt)
107 |   let ptr = prim__castPtr ptr
108 |   traverse_ (\(idx, x) => primIO $ prim__setArrayInt ptr (cast idx) (cast x)) (enumerate xs)
109 |   ptr <- onCollect ptr (free . prim__forgetPtr)
110 |   pure (MkIntArray ptr)
111 |
112 | public export
113 | data Int64Array : Type where
114 |   MkInt64Array : GCPtr Int64 -> Int64Array
115 |
116 | %foreign (ffi "sizeof_int64_t")
117 | sizeofInt64 : Bits64
118 |
119 | %foreign (ffi "set_array_int64_t")
120 | prim__setArrayInt64 : Ptr Int64 -> Bits64 -> Int64 -> PrimIO ()
121 |
122 | export
123 | mkInt64Array : HasIO io => List Int64 -> io Int64Array
124 | mkInt64Array xs = do
125 |   ptr <- malloc (cast (length xs) * cast sizeofInt64)
126 |   let ptr = prim__castPtr ptr
127 |   traverse_ (\(idx, x) => primIO $ prim__setArrayInt64 ptr (cast idx) (cast x)) (enumerate xs)
128 |   ptr <- onCollect ptr (free . prim__forgetPtr)
129 |   pure (MkInt64Array ptr)
130 |
131 | export
132 | %foreign (ffi "sizeof_ptr")
133 | sizeofPtr : Bits64
134 |
135 | export
136 | %foreign (ffi "set_array_ptr")
137 | prim__setArrayPtr : AnyPtr -> Int -> AnyPtr -> PrimIO ()
138 |