0 | {--
  1 | Copyright (C) 2024  Joel Berkeley
  2 |
  3 | This program is free software: you can redistribute it and/or modify
  4 | it under the terms of the GNU Affero General Public License as published
  5 | by the Free Software Foundation, either version 3 of the License, or
  6 | (at your option) any later version.
  7 |
  8 | This program is distributed in the hope that it will be useful,
  9 | but WITHOUT ANY WARRANTY; without even the implied warranty of
 10 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 11 | GNU Affero General Public License for more details.
 12 |
 13 | You should have received a copy of the GNU Affero General Public License
 14 | along with this program.  If not, see <https://www.gnu.org/licenses/>.
 15 | --}
 16 | ||| For internal spidr use, and use by plugin developers.
 17 | |||
 18 | ||| The Idris API for PJRT.
 19 | module Compiler.Xla.PJRT.C.PjrtCApi
 20 |
 21 | import public Control.Monad.Either
 22 | import Derive.Prelude
 23 | import Language.Reflection
 24 |
 25 | import Compiler.FFI
 26 | import Compiler.Xla.Literal
 27 | import Types
 28 | import Util
 29 |
 30 | %language ElabReflection
 31 |
 32 | ffi : String -> String
 33 | ffi = libxla "c/xla/pjrt/c/pjrt_c_api.h"
 34 |
 35 | ||| For use by plugin developers.
 36 | |||
 37 | ||| A minimal wrapper round a C `PJRT_Api` struct pointer. The memory should be owned by the
 38 | ||| code producing the pointer.
 39 | public export
 40 | data PjrtApi = MkPjrtApi AnyPtr
 41 |
 42 | ||| The cause of a `PjrtError`.
 43 | public export
 44 | data PjrtErrorCode =
 45 |     PJRT_Error_Code_CANCELLED
 46 |   | PJRT_Error_Code_UNKNOWN
 47 |   | PJRT_Error_Code_INVALID_ARGUMENT
 48 |   | PJRT_Error_Code_DEADLINE_EXCEEDED
 49 |   | PJRT_Error_Code_NOT_FOUND
 50 |   | PJRT_Error_Code_ALREADY_EXISTS
 51 |   | PJRT_Error_Code_PERMISSION_DENIED
 52 |   | PJRT_Error_Code_RESOURCE_EXHAUSTED
 53 |   | PJRT_Error_Code_FAILED_PRECONDITION
 54 |   | PJRT_Error_Code_ABORTED
 55 |   | PJRT_Error_Code_OUT_OF_RANGE
 56 |   | PJRT_Error_Code_UNIMPLEMENTED
 57 |   | PJRT_Error_Code_INTERNAL
 58 |   | PJRT_Error_Code_UNAVAILABLE
 59 |   | PJRT_Error_Code_DATA_LOSS
 60 |   | PJRT_Error_Code_UNAUTHENTICATED
 61 |
 62 | %runElab derive "PjrtErrorCode" [Show]
 63 |
 64 | ||| Indicates an error in the PJRT C layer, either due to internal errors or user error.
 65 | public export
 66 | record PjrtError where
 67 |   constructor MkPjrtError
 68 |
 69 |   ||| The error message.
 70 |   message : String
 71 |
 72 |   ||| The error cause code, if one exists.
 73 |   code : Maybe PjrtErrorCode
 74 |
 75 | export
 76 | Show PjrtError where
 77 |   show e =
 78 |     let code = case e.code of
 79 |           Nothing => "unknown"
 80 |           Just c => show c
 81 |      in "PjrtError (error code \{code})\n\{e.message}"
 82 |
 83 | %foreign (ffi "PJRT_Error_Destroy_Args_delete")
 84 | prim__deletePjrtErrorDestroyArgs : AnyPtr -> PrimIO ()
 85 |
 86 | %foreign (ffi "PJRT_Error_Destroy_Args_new")
 87 | prim__mkPjrtErrorDestroyArgs : AnyPtr -> PrimIO AnyPtr
 88 |
 89 | %foreign (ffi "pjrt_error_destroy")
 90 | prim__pjrtErrorDestroy : AnyPtr -> AnyPtr -> PrimIO ()
 91 |
 92 | destroyPjrtError : HasIO io => AnyPtr -> AnyPtr -> io ()
 93 | destroyPjrtError api err = do
 94 |   args <- primIO $ prim__mkPjrtErrorDestroyArgs err
 95 |   primIO $ prim__pjrtErrorDestroy api args
 96 |   primIO $ prim__deletePjrtErrorDestroyArgs args
 97 |
 98 | %foreign (ffi "PJRT_Error_Message_Args_delete")
 99 | prim__deletePjrtErrorMessageArgs : AnyPtr -> PrimIO ()
100 |
101 | %foreign (ffi "PJRT_Error_Message_Args_new")
102 | prim__mkPjrtErrorMessageArgs : AnyPtr -> PrimIO AnyPtr
103 |
104 | %foreign (ffi "PJRT_Error_Message_Args_message")
105 | prim__pjrtErrorMessageArgsMessage : AnyPtr -> PrimIO String
106 |
107 | %foreign (ffi "pjrt_error_message")
108 | prim__pjrtErrorMessage : AnyPtr -> AnyPtr -> PrimIO ()
109 |
110 | pjrtErrorMessage : HasIO io => AnyPtr -> AnyPtr -> io String
111 | pjrtErrorMessage api err = do
112 |   args <- primIO $ prim__mkPjrtErrorMessageArgs err
113 |   primIO $ prim__pjrtErrorMessage api args
114 |   msg <- primIO $ prim__pjrtErrorMessageArgsMessage args
115 |   primIO $ prim__deletePjrtErrorMessageArgs args
116 |   pure msg
117 |
118 | %foreign (ffi "PJRT_Error_GetCode_Args_delete")
119 | prim__deletePjrtErrorGetCodeArgs : AnyPtr -> PrimIO ()
120 |
121 | %foreign (ffi "PJRT_Error_GetCode_Args_new")
122 | prim__mkPjrtErrorGetCodeArgs : AnyPtr -> PrimIO AnyPtr
123 |
124 | %foreign (ffi "PJRT_Error_GetCode_Args_code")
125 | prim__pjrtErrorGetCodeArgsCode : AnyPtr -> Int
126 |
127 | %foreign (ffi "pjrt_error_getcode")
128 | prim__pjrtErrorGetCode : AnyPtr -> AnyPtr -> PrimIO AnyPtr
129 |
130 | pjrtErrorCodeFromCInt : Int -> PjrtErrorCode
131 | pjrtErrorCodeFromCInt = \case
132 |   1  => PJRT_Error_Code_CANCELLED
133 |   2  => PJRT_Error_Code_UNKNOWN
134 |   3  => PJRT_Error_Code_INVALID_ARGUMENT
135 |   4  => PJRT_Error_Code_DEADLINE_EXCEEDED
136 |   5  => PJRT_Error_Code_NOT_FOUND
137 |   6  => PJRT_Error_Code_ALREADY_EXISTS
138 |   7  => PJRT_Error_Code_PERMISSION_DENIED
139 |   8  => PJRT_Error_Code_RESOURCE_EXHAUSTED
140 |   9  => PJRT_Error_Code_FAILED_PRECONDITION
141 |   10 => PJRT_Error_Code_ABORTED
142 |   11 => PJRT_Error_Code_OUT_OF_RANGE
143 |   12 => PJRT_Error_Code_UNIMPLEMENTED
144 |   13 => PJRT_Error_Code_INTERNAL
145 |   14 => PJRT_Error_Code_UNAVAILABLE
146 |   15 => PJRT_Error_Code_DATA_LOSS
147 |   16 => PJRT_Error_Code_UNAUTHENTICATED
148 |   n  => assert_total $ idris_crash
149 |     "Unexpected PJRT_Error_Code value received through FFI from XLA: \{show n}"
150 |
151 | ||| A `Pjrt a` produces an `a` or an error from the PJRT layer.
152 | public export 0
153 | Pjrt : Type -> Type
154 | Pjrt = EitherT PjrtError IO
155 |
156 | try : AnyPtr -> AnyPtr -> a -> Pjrt a
157 | try api err onOk = if (isNullPtr err) then right onOk else do
158 |   msg <- pjrtErrorMessage api err
159 |   args <- primIO $ prim__mkPjrtErrorGetCodeArgs err
160 |   getCodeErr <- primIO $ prim__pjrtErrorGetCode api args
161 |   code <- if (isNullPtr getCodeErr) then pure Nothing else do
162 |     let code = prim__pjrtErrorGetCodeArgsCode args
163 |     destroyPjrtError api getCodeErr
164 |     pure $ Just code
165 |   primIO $ prim__deletePjrtErrorGetCodeArgs args
166 |   destroyPjrtError api err
167 |   left $ MkPjrtError msg $ map pjrtErrorCodeFromCInt code
168 |
169 | ||| For internal spidr use only.
170 | export
171 | data PjrtEvent = MkPjrtEvent AnyPtr
172 |
173 | %foreign (ffi "PJRT_Event_Destroy_Args_delete")
174 | prim__deletePjrtEventDestroyArgs : AnyPtr -> PrimIO ()
175 |
176 | %foreign (ffi "PJRT_Event_Destroy_Args_new")
177 | prim__mkPjrtEventDestroyArgs : AnyPtr -> PrimIO AnyPtr
178 |
179 | %foreign (ffi "pjrt_event_destroy")
180 | prim__pjrtEventDestroy : AnyPtr -> AnyPtr -> PrimIO AnyPtr
181 |
182 | %foreign (ffi "PJRT_Event_Await_Args_delete")
183 | prim__deletePjrtEventAwaitArgs : AnyPtr -> PrimIO ()
184 |
185 | %foreign (ffi "PJRT_Event_Await_Args_new")
186 | prim__mkPjrtEventAwaitArgs : AnyPtr -> PrimIO AnyPtr
187 |
188 | %foreign (ffi "pjrt_event_await")
189 | prim__pjrtEventAwait : AnyPtr -> AnyPtr -> PrimIO AnyPtr
190 |
191 | ||| For internal spidr use only.
192 | export
193 | pjrtEventAwait : PjrtApi -> PjrtEvent -> Pjrt ()
194 | pjrtEventAwait (MkPjrtApi api) (MkPjrtEvent event) = do
195 |   args <- primIO $ prim__mkPjrtEventAwaitArgs event
196 |   err <- primIO $ prim__pjrtEventAwait api args
197 |   primIO $ prim__deletePjrtEventAwaitArgs args
198 |   try api err ()
199 |
200 | ||| For use by plugin developers.
201 | export
202 | data PjrtClient = MkPjrtClient GCAnyPtr
203 |
204 | %foreign (ffi "PJRT_Client_Create_Args_delete")
205 | prim__deletePjrtClientCreateArgs : AnyPtr -> PrimIO ()
206 |
207 | %foreign (ffi "PJRT_Client_Create_Args_new")
208 | prim__mkPjrtClientCreateArgs : PrimIO AnyPtr
209 |
210 | %foreign (ffi "PJRT_Client_Create_Args_client")
211 | prim__pjrtClientCreateArgsClient : AnyPtr -> AnyPtr
212 |
213 | %foreign (ffi "pjrt_client_create")
214 | prim__pjrtClientCreate : AnyPtr -> AnyPtr -> PrimIO AnyPtr
215 |
216 | %foreign (ffi "PJRT_Client_Destroy_Args_delete")
217 | prim__deletePjrtClientDestroyArgs : AnyPtr -> PrimIO ()
218 |
219 | %foreign (ffi "PJRT_Client_Destroy_Args_new")
220 | prim__mkPjrtClientDestroyArgs : AnyPtr -> PrimIO AnyPtr
221 |
222 | %foreign (ffi "pjrt_client_destroy")
223 | prim__pjrtClientDestroy : AnyPtr -> AnyPtr -> PrimIO AnyPtr
224 |
225 | handleErrOnDestroy : HasIO io => AnyPtr -> AnyPtr -> String -> io ()
226 | handleErrOnDestroy api err target = unless (isNullPtr err) $ do
227 |   msg <- pjrtErrorMessage api err
228 |   args <- primIO $ prim__mkPjrtErrorGetCodeArgs err
229 |   getCodeErr <- primIO $ prim__pjrtErrorGetCode api args
230 |   if (isNullPtr getCodeErr) then do
231 |       let code = pjrtErrorCodeFromCInt $ prim__pjrtErrorGetCodeArgsCode args
232 |       printLn "WARN: Failed to destroy \{target} with error code \{show code}; message: \{msg}"
233 |     else do
234 |       printLn "WARN: Failed to fetch error code"
235 |       printLn "WARN: Failed to destroy \{target} with unknown error code; message: \{msg}"
236 |       destroyPjrtError api getCodeErr
237 |   primIO $ prim__deletePjrtErrorGetCodeArgs args
238 |   destroyPjrtError api err
239 |
240 | ||| For use by plugin developers.
241 | |||
242 | ||| Create a `PjrtClient`.
243 | export
244 | pjrtClientCreate : PjrtApi -> Pjrt PjrtClient
245 | pjrtClientCreate (MkPjrtApi api) = do
246 |   args <- primIO prim__mkPjrtClientCreateArgs
247 |   err <- primIO $ prim__pjrtClientCreate api args
248 |   let client = prim__pjrtClientCreateArgsClient args
249 |   primIO $ prim__deletePjrtClientCreateArgs args
250 |   try api err =<< do
251 |     client <- onCollectAny' client destroy
252 |     pure $ MkPjrtClient client
253 |
254 |     where
255 |
256 |     destroy : AnyPtr -> IO ()
257 |     destroy client = do
258 |       args <- primIO $ prim__mkPjrtClientDestroyArgs client
259 |       err <- primIO $ prim__pjrtClientDestroy api args
260 |       primIO $ prim__deletePjrtClientDestroyArgs args
261 |       handleErrOnDestroy api err "PJRT_Client"
262 |
263 | ||| For internal spidr use only.
264 | export
265 | data PjrtProgram = MkPjrtProgram AnyPtr
266 |
267 | %foreign (ffi "PJRT_Program_delete")
268 | prim__deletePjrtProgram : AnyPtr -> PrimIO ()
269 |
270 | namespace PjrtProgram
271 |   export
272 |   delete : HasIO io => PjrtProgram -> io ()
273 |   delete (MkPjrtProgram p) = primIO $ prim__deletePjrtProgram p
274 |
275 | %foreign (ffi "PJRT_Program_new")
276 | prim__mkPjrtProgram : Ptr Char -> Bits64 -> PrimIO AnyPtr
277 |
278 | ||| For internal spidr use only.
279 | |||
280 | ||| The `CppString` must live as long as the `PjrtProgram`.
281 | ||| It is up to the caller to deallocate the `PjrtProgram`.
282 | export
283 | mkPjrtProgram : HasIO io => CppString -> io PjrtProgram
284 | mkPjrtProgram (MkCppString code) = MkPjrtProgram <$> (
285 |     primIO $ prim__mkPjrtProgram (prim__stringData code) (prim__stringSize code)
286 |   )
287 |
288 | %foreign (ffi "PJRT_Client_Compile_Args_delete")
289 | prim__deletePjrtClientCompileArgs : AnyPtr -> PrimIO ()
290 |
291 | %foreign (ffi "PJRT_Client_Compile_Args_new")
292 | prim__mkPjrtClientCompileArgs : GCAnyPtr -> AnyPtr -> Ptr Char -> Bits64 -> PrimIO AnyPtr
293 |
294 | %foreign (ffi "PJRT_Client_Compile_Args_executable")
295 | prim__pjrtClientCompileArgsExecutable : AnyPtr -> AnyPtr
296 |
297 | %foreign (ffi "pjrt_client_compile")
298 | prim__pjrtClientCompile : AnyPtr -> AnyPtr -> PrimIO AnyPtr
299 |
300 | %foreign (ffi "PJRT_LoadedExecutable_Destroy_Args_delete")
301 | prim__deletePjrtLoadedExecutableDestroyArgs : AnyPtr -> PrimIO ()
302 |
303 | %foreign (ffi "PJRT_LoadedExecutable_Destroy_Args_new")
304 | prim__mkPjrtLoadedExecutableDestroyArgs : AnyPtr -> PrimIO AnyPtr
305 |
306 | %foreign (ffi "pjrt_loadedexecutable_destroy")
307 | prim__pjrtLoadedExecutableDestroy : AnyPtr -> AnyPtr -> PrimIO AnyPtr
308 |
309 | ||| For internal spidr use only.
310 | export
311 | data PjrtLoadedExecutable = MkPjrtLoadedExecutable AnyPtr
312 |
313 | ||| For internal spidr use only.
314 | export
315 | pjrtLoadedExecutableDestroy : HasIO io => PjrtApi -> PjrtLoadedExecutable -> io ()
316 | pjrtLoadedExecutableDestroy (MkPjrtApi api) (MkPjrtLoadedExecutable executable) = do
317 |   args <- primIO $ prim__mkPjrtLoadedExecutableDestroyArgs executable
318 |   err <- primIO $ prim__pjrtLoadedExecutableDestroy api args
319 |   primIO $ prim__deletePjrtLoadedExecutableDestroyArgs args
320 |   handleErrOnDestroy api err "PJRT_LoadedExecutable"
321 |
322 | ||| For internal spidr use only.
323 | |||
324 | ||| It is up to the caller to deallocate the `PjrtLoadedExecutable`.
325 | export
326 | pjrtClientCompile : PjrtApi -> PjrtClient -> PjrtProgram -> CppString -> Pjrt PjrtLoadedExecutable
327 | pjrtClientCompile
328 |   (MkPjrtApi api) (MkPjrtClient client) (MkPjrtProgram program) (MkCppString options) = do
329 |     args <- primIO $ prim__mkPjrtClientCompileArgs
330 |       client program (prim__stringData options) (prim__stringSize options)
331 |     err <- primIO $ prim__pjrtClientCompile api args
332 |     let executable = prim__pjrtClientCompileArgsExecutable args
333 |     primIO $ prim__deletePjrtClientCompileArgs args
334 |     try api err $ MkPjrtLoadedExecutable executable
335 |
336 | %foreign (ffi "PJRT_ExecuteOptions_delete")
337 | prim__deletePjrtExecuteOptions : AnyPtr -> PrimIO ()
338 |
339 | %foreign (ffi "PJRT_ExecuteOptions_new")
340 | prim__mkPjrtExecuteOptions : PrimIO AnyPtr
341 |
342 | %foreign (ffi "PJRT_LoadedExecutable_Execute_Args_delete")
343 | prim__deletePjrtLoadedExecutableExecuteArgs : AnyPtr -> PrimIO ()
344 |
345 | %foreign (ffi "PJRT_LoadedExecutable_Execute_Args_new")
346 | prim__mkPjrtLoadedExecutableExecuteArgs : AnyPtr -> AnyPtr -> AnyPtr -> PrimIO AnyPtr
347 |
348 | %foreign (ffi "pjrt_loadedexecutable_execute")
349 | prim__pjrtLoadedExecutableExecute : AnyPtr -> AnyPtr -> PrimIO AnyPtr
350 |
351 | %foreign (ffi "PJRT_Buffer_Destroy_Args_delete")
352 | prim__deletePjrtBufferDestroyArgs : AnyPtr -> PrimIO ()
353 |
354 | %foreign (ffi "PJRT_Buffer_Destroy_Args_new")
355 | prim__mkPjrtBufferDestroyArgs : AnyPtr -> PrimIO AnyPtr
356 |
357 | %foreign (ffi "pjrt_buffer_destroy")
358 | prim__pjrtBufferDestroy : AnyPtr -> AnyPtr -> PrimIO AnyPtr
359 |
360 | ||| For internal spidr use only.
361 | export
362 | data PjrtBuffer = MkPjrtBuffer AnyPtr
363 |
364 | ||| For internal spidr use only.
365 | export
366 | pjrtBufferDestroy : HasIO io => PjrtApi -> PjrtBuffer -> io ()
367 | pjrtBufferDestroy (MkPjrtApi api) (MkPjrtBuffer buffer) = do
368 |   args <- primIO $ prim__mkPjrtBufferDestroyArgs buffer
369 |   err <- primIO $ prim__pjrtBufferDestroy api args
370 |   primIO $ prim__deletePjrtBufferDestroyArgs args
371 |   handleErrOnDestroy api err "PJRT_Buffer"
372 |
373 | ||| For internal spidr use only.
374 | |||
375 | ||| It is up to the caller to deallocate the `PjrtBuffer`s.
376 | export
377 | pjrtLoadedExecutableExecute :
378 |   PjrtApi -> PjrtLoadedExecutable -> (outputs : Nat) -> Pjrt (Vect outputs PjrtBuffer)
379 | pjrtLoadedExecutableExecute (MkPjrtApi api) (MkPjrtLoadedExecutable executable) outputs = do
380 |   outputListsInner <- malloc (cast outputs * cast sizeofPtr)
381 |   outputLists <- malloc $ cast sizeofPtr
382 |   primIO $ prim__setArrayPtr outputLists 0 outputListsInner
383 |   options <- primIO prim__mkPjrtExecuteOptions
384 |   args <- primIO $ prim__mkPjrtLoadedExecutableExecuteArgs executable options outputLists
385 |   err <- primIO $ prim__pjrtLoadedExecutableExecute api args
386 |   primIO $ prim__deletePjrtLoadedExecutableExecuteArgs args
387 |   primIO $ prim__deletePjrtExecuteOptions options
388 |   let buffers = map (\o => MkPjrtBuffer $ prim__index (cast o) outputListsInner) (range outputs)
389 |   free outputLists
390 |   free outputListsInner
391 |   try api err buffers
392 |
393 | %foreign (ffi "PJRT_Buffer_ToHostBuffer_Args_delete")
394 | prim__deletePjrtBufferToHostBufferArgs : AnyPtr -> PrimIO ()
395 |
396 | %foreign (ffi "PJRT_Buffer_ToHostBuffer_Args_new")
397 | prim__mkPjrtBufferToHostBufferArgs : AnyPtr -> AnyPtr -> Int -> PrimIO AnyPtr
398 |
399 | %foreign (ffi "PJRT_Buffer_ToHostBuffer_Args_event")
400 | prim__pjrtBufferToHostBufferArgsEvent : AnyPtr -> AnyPtr
401 |
402 | %foreign (ffi "pjrt_buffer_tohostbuffer")
403 | prim__pjrtBufferToHostBuffer : AnyPtr -> AnyPtr -> PrimIO AnyPtr
404 |
405 | ||| For internal spidr use only.
406 | export
407 | pjrtEventDestroy : HasIO io => PjrtApi -> PjrtEvent -> io ()
408 | pjrtEventDestroy (MkPjrtApi api) (MkPjrtEvent event) = do
409 |   args <- primIO $ prim__mkPjrtEventDestroyArgs event
410 |   err <- primIO $ prim__pjrtEventDestroy api args
411 |   primIO $ prim__deletePjrtEventDestroyArgs args
412 |   handleErrOnDestroy api err "PJRT_Event"
413 |
414 | ||| For internal spidr use only.
415 | |||
416 | ||| It is up to the caller to deallocate the `PjrtEvent`.
417 | export
418 | pjrtBufferToHostBuffer : PjrtApi -> PjrtBuffer -> Literal -> Pjrt PjrtEvent
419 | pjrtBufferToHostBuffer (MkPjrtApi api) (MkPjrtBuffer buffer) (MkLiteral literal) = do
420 |   let untypedData = prim__literalUntypedData literal
421 |       sizeBytes = prim__literalSizeBytes literal
422 |   args <- primIO $ prim__mkPjrtBufferToHostBufferArgs buffer untypedData sizeBytes
423 |   err <- primIO $ prim__pjrtBufferToHostBuffer api args
424 |   let event = prim__pjrtBufferToHostBufferArgsEvent args
425 |   primIO $ prim__deletePjrtBufferToHostBufferArgs args
426 |   try api err $ MkPjrtEvent event
427 |