19 | module Compiler.Xla.PJRT.C.PjrtCApi
21 | import public Control.Monad.Either
22 | import Derive.Prelude
23 | import Language.Reflection
26 | import Compiler.Xla.Literal
30 | %language ElabReflection
||| Build the `%foreign` specifier for a C symbol of the XLA shim, tagged with the
||| path `c/xla/pjrt/c/pjrt_c_api.h` (the exact format is decided by `libxla`, defined elsewhere).
ffi : String -> String
ffi symbol = libxla "c/xla/pjrt/c/pjrt_c_api.h" symbol
||| Opaque pointer to the PJRT C API, passed as the first argument to every PJRT call in this module.
data PjrtApi : Type where
  MkPjrtApi : AnyPtr -> PjrtApi
||| Error codes reported by the PJRT C API (`PJRT_Error_Code`).
|||
||| Integer codes received through the FFI are decoded by `pjrtErrorCodeFromCInt`.
data PjrtErrorCode : Type where
  PJRT_Error_Code_CANCELLED : PjrtErrorCode
  PJRT_Error_Code_UNKNOWN : PjrtErrorCode
  PJRT_Error_Code_INVALID_ARGUMENT : PjrtErrorCode
  PJRT_Error_Code_DEADLINE_EXCEEDED : PjrtErrorCode
  PJRT_Error_Code_NOT_FOUND : PjrtErrorCode
  PJRT_Error_Code_ALREADY_EXISTS : PjrtErrorCode
  PJRT_Error_Code_PERMISSION_DENIED : PjrtErrorCode
  PJRT_Error_Code_RESOURCE_EXHAUSTED : PjrtErrorCode
  PJRT_Error_Code_FAILED_PRECONDITION : PjrtErrorCode
  PJRT_Error_Code_ABORTED : PjrtErrorCode
  PJRT_Error_Code_OUT_OF_RANGE : PjrtErrorCode
  PJRT_Error_Code_UNIMPLEMENTED : PjrtErrorCode
  PJRT_Error_Code_INTERNAL : PjrtErrorCode
  PJRT_Error_Code_UNAVAILABLE : PjrtErrorCode
  PJRT_Error_Code_DATA_LOSS : PjrtErrorCode
  PJRT_Error_Code_UNAUTHENTICATED : PjrtErrorCode

%runElab derive "PjrtErrorCode" [Show]
||| A failure reported by the PJRT C API, copied out of the C-side error object.
record PjrtError where
  constructor MkPjrtError

  ||| The message PJRT attached to the error.
  message : String

  ||| The PJRT error code, or `Nothing` if the code itself could not be retrieved.
  code : Maybe PjrtErrorCode
||| Render a `PjrtError` as its code (or "unknown") followed by its message on a new line.
Show PjrtError where
  show e =
    -- the code is optional: fetching it from PJRT can itself fail
    let code = maybe "unknown" show e.code
     in "PjrtError (error code \{code})\n\{e.message}"
-- Bindings used to release a PJRT error object (`PJRT_Error_Destroy`), listed in the order
-- they are used: allocate args, call, free args.

%foreign (ffi "PJRT_Error_Destroy_Args_new")
prim__mkPjrtErrorDestroyArgs : AnyPtr -> PrimIO AnyPtr

%foreign (ffi "pjrt_error_destroy")
prim__pjrtErrorDestroy : AnyPtr -> AnyPtr -> PrimIO ()

%foreign (ffi "PJRT_Error_Destroy_Args_delete")
prim__deletePjrtErrorDestroyArgs : AnyPtr -> PrimIO ()
||| Release a PJRT error object through `pjrt_error_destroy`.
|||
||| Takes the PJRT API pointer and the error pointer. The error must not be used afterwards.
destroyPjrtError : HasIO io => AnyPtr -> AnyPtr -> io ()
destroyPjrtError api err = do
  destroyArgs <- primIO (prim__mkPjrtErrorDestroyArgs err)
  primIO (prim__pjrtErrorDestroy api destroyArgs)
  primIO (prim__deletePjrtErrorDestroyArgs destroyArgs)
-- Bindings used to read the message of a PJRT error (`PJRT_Error_Message`), listed in the
-- order they are used: allocate args, call, read the result, free args.

%foreign (ffi "PJRT_Error_Message_Args_new")
prim__mkPjrtErrorMessageArgs : AnyPtr -> PrimIO AnyPtr

%foreign (ffi "pjrt_error_message")
prim__pjrtErrorMessage : AnyPtr -> AnyPtr -> PrimIO ()

%foreign (ffi "PJRT_Error_Message_Args_message")
prim__pjrtErrorMessageArgsMessage : AnyPtr -> PrimIO String

%foreign (ffi "PJRT_Error_Message_Args_delete")
prim__deletePjrtErrorMessageArgs : AnyPtr -> PrimIO ()
||| Fetch the message attached to a PJRT error.
|||
||| Takes the PJRT API pointer and the error pointer. The error object is NOT destroyed here;
||| the caller remains responsible for releasing it (e.g. with `destroyPjrtError`).
pjrtErrorMessage : HasIO io => AnyPtr -> AnyPtr -> io String
pjrtErrorMessage api err = do
  args <- primIO $ prim__mkPjrtErrorMessageArgs err
  primIO $ prim__pjrtErrorMessage api args
  -- read the message out before the args struct that holds it is freed
  msg <- primIO $ prim__pjrtErrorMessageArgsMessage args
  primIO $ prim__deletePjrtErrorMessageArgs args
  pure msg
-- Bindings used to read the code of a PJRT error (`PJRT_Error_GetCode`), listed in the order
-- they are used: allocate args, call, read the result, free args.
--
-- NOTE(review): `prim__pjrtErrorGetCodeArgsCode` is declared pure, but presumably reads a field
-- that `pjrt_error_getcode` writes, so it should only be evaluated after that call succeeds.

%foreign (ffi "PJRT_Error_GetCode_Args_new")
prim__mkPjrtErrorGetCodeArgs : AnyPtr -> PrimIO AnyPtr

%foreign (ffi "pjrt_error_getcode")
prim__pjrtErrorGetCode : AnyPtr -> AnyPtr -> PrimIO AnyPtr

%foreign (ffi "PJRT_Error_GetCode_Args_code")
prim__pjrtErrorGetCodeArgsCode : AnyPtr -> Int

%foreign (ffi "PJRT_Error_GetCode_Args_delete")
prim__deletePjrtErrorGetCodeArgs : AnyPtr -> PrimIO ()
||| Decode a `PJRT_Error_Code` integer received from the C API.
|||
||| Values 1 through 16 map onto the `PjrtErrorCode` constructors in declaration order. Any other
||| value is treated as a broken invariant and crashes the program.
pjrtErrorCodeFromCInt : Int -> PjrtErrorCode
pjrtErrorCodeFromCInt 1 = PJRT_Error_Code_CANCELLED
pjrtErrorCodeFromCInt 2 = PJRT_Error_Code_UNKNOWN
pjrtErrorCodeFromCInt 3 = PJRT_Error_Code_INVALID_ARGUMENT
pjrtErrorCodeFromCInt 4 = PJRT_Error_Code_DEADLINE_EXCEEDED
pjrtErrorCodeFromCInt 5 = PJRT_Error_Code_NOT_FOUND
pjrtErrorCodeFromCInt 6 = PJRT_Error_Code_ALREADY_EXISTS
pjrtErrorCodeFromCInt 7 = PJRT_Error_Code_PERMISSION_DENIED
pjrtErrorCodeFromCInt 8 = PJRT_Error_Code_RESOURCE_EXHAUSTED
pjrtErrorCodeFromCInt 9 = PJRT_Error_Code_FAILED_PRECONDITION
pjrtErrorCodeFromCInt 10 = PJRT_Error_Code_ABORTED
pjrtErrorCodeFromCInt 11 = PJRT_Error_Code_OUT_OF_RANGE
pjrtErrorCodeFromCInt 12 = PJRT_Error_Code_UNIMPLEMENTED
pjrtErrorCodeFromCInt 13 = PJRT_Error_Code_INTERNAL
pjrtErrorCodeFromCInt 14 = PJRT_Error_Code_UNAVAILABLE
pjrtErrorCodeFromCInt 15 = PJRT_Error_Code_DATA_LOSS
pjrtErrorCodeFromCInt 16 = PJRT_Error_Code_UNAUTHENTICATED
pjrtErrorCodeFromCInt n = assert_total $ idris_crash
  "Unexpected PJRT_Error_Code value received through FFI from XLA: \{show n}"
||| Computations that call into PJRT and may fail with a `PjrtError`.
Pjrt : Type -> Type
Pjrt = EitherT PjrtError IO
||| Interpret the error pointer returned by a PJRT C API call.
|||
||| If `err` is null the call succeeded and `onOk` is returned. Otherwise the error's message
||| and, when it can be fetched, its code are copied into a `PjrtError`. The PJRT error is then
||| destroyed, and the computation fails with that `PjrtError`.
try : AnyPtr -> AnyPtr -> a -> Pjrt a
try api err onOk = if isNullPtr err then right onOk else do
  msg <- pjrtErrorMessage api err
  args <- primIO $ prim__mkPjrtErrorGetCodeArgs err
  getCodeErr <- primIO $ prim__pjrtErrorGetCode api args
  -- A null `getCodeErr` means fetching the code succeeded, so the code field is valid.
  -- Otherwise the field was never written: release the secondary error and report no code.
  code <- if isNullPtr getCodeErr
    then pure $ Just $ prim__pjrtErrorGetCodeArgsCode args
    else do
      destroyPjrtError api getCodeErr
      pure Nothing
  primIO $ prim__deletePjrtErrorGetCodeArgs args
  destroyPjrtError api err
  left $ MkPjrtError msg $ map pjrtErrorCodeFromCInt code
||| Opaque pointer to a PJRT event. It is not collectable; release it with `pjrtEventDestroy`.
data PjrtEvent : Type where
  MkPjrtEvent : AnyPtr -> PjrtEvent
-- Bindings for `PJRT_Event_Destroy`: allocate args, call, free args.

%foreign (ffi "PJRT_Event_Destroy_Args_new")
prim__mkPjrtEventDestroyArgs : AnyPtr -> PrimIO AnyPtr

%foreign (ffi "pjrt_event_destroy")
prim__pjrtEventDestroy : AnyPtr -> AnyPtr -> PrimIO AnyPtr

%foreign (ffi "PJRT_Event_Destroy_Args_delete")
prim__deletePjrtEventDestroyArgs : AnyPtr -> PrimIO ()

-- Bindings for `PJRT_Event_Await`: allocate args, call, free args.

%foreign (ffi "PJRT_Event_Await_Args_new")
prim__mkPjrtEventAwaitArgs : AnyPtr -> PrimIO AnyPtr

%foreign (ffi "pjrt_event_await")
prim__pjrtEventAwait : AnyPtr -> AnyPtr -> PrimIO AnyPtr

%foreign (ffi "PJRT_Event_Await_Args_delete")
prim__deletePjrtEventAwaitArgs : AnyPtr -> PrimIO ()
||| Wait on a PJRT event via `pjrt_event_await`.
|||
||| Fails with a `PjrtError` if the await reports an error. The event itself is not destroyed;
||| use `pjrtEventDestroy` for that.
pjrtEventAwait : PjrtApi -> PjrtEvent -> Pjrt ()
pjrtEventAwait (MkPjrtApi api) (MkPjrtEvent event) = do
  args <- primIO $ prim__mkPjrtEventAwaitArgs event
  err <- primIO $ prim__pjrtEventAwait api args
  primIO $ prim__deletePjrtEventAwaitArgs args
  -- surface and release any error instead of silently dropping (and leaking) it
  try api err ()
||| A PJRT client. It holds a collectable pointer; `pjrtClientCreate` attaches a finalizer that
||| calls `pjrt_client_destroy`.
data PjrtClient : Type where
  MkPjrtClient : GCAnyPtr -> PjrtClient
-- Bindings for `PJRT_Client_Create`: allocate args, call, read the created client, free args.
-- NOTE(review): `prim__pjrtClientCreateArgsClient` is declared pure but presumably reads a
-- field written by `pjrt_client_create`; only evaluate it after that call.

%foreign (ffi "PJRT_Client_Create_Args_new")
prim__mkPjrtClientCreateArgs : PrimIO AnyPtr

%foreign (ffi "pjrt_client_create")
prim__pjrtClientCreate : AnyPtr -> AnyPtr -> PrimIO AnyPtr

%foreign (ffi "PJRT_Client_Create_Args_client")
prim__pjrtClientCreateArgsClient : AnyPtr -> AnyPtr

%foreign (ffi "PJRT_Client_Create_Args_delete")
prim__deletePjrtClientCreateArgs : AnyPtr -> PrimIO ()

-- Bindings for `PJRT_Client_Destroy`: allocate args, call, free args.

%foreign (ffi "PJRT_Client_Destroy_Args_new")
prim__mkPjrtClientDestroyArgs : AnyPtr -> PrimIO AnyPtr

%foreign (ffi "pjrt_client_destroy")
prim__pjrtClientDestroy : AnyPtr -> AnyPtr -> PrimIO AnyPtr

%foreign (ffi "PJRT_Client_Destroy_Args_delete")
prim__deletePjrtClientDestroyArgs : AnyPtr -> PrimIO ()
||| Report, then release, an error returned by a PJRT destroy call.
|||
||| Destructors cannot fail in the `Pjrt` sense, so a non-null `err` is only logged as a warning
||| naming `target`, together with its code (when it can be fetched) and message. Does nothing
||| when `err` is null.
handleErrOnDestroy : HasIO io => AnyPtr -> AnyPtr -> String -> io ()
handleErrOnDestroy api err target = unless (isNullPtr err) $ do
  msg <- pjrtErrorMessage api err
  args <- primIO $ prim__mkPjrtErrorGetCodeArgs err
  getCodeErr <- primIO $ prim__pjrtErrorGetCode api args
  if isNullPtr getCodeErr
    then do
      let code = pjrtErrorCodeFromCInt $ prim__pjrtErrorGetCodeArgsCode args
      -- putStrLn, not printLn: printLn would `show` the String, wrapping it in quotes and escaping it
      putStrLn "WARN: Failed to destroy \{target} with error code \{show code}; message: \{msg}"
    else do
      putStrLn "WARN: Failed to fetch error code"
      putStrLn "WARN: Failed to destroy \{target} with unknown error code; message: \{msg}"
      -- only a failed code lookup produces a secondary error to release
      destroyPjrtError api getCodeErr
  primIO $ prim__deletePjrtErrorGetCodeArgs args
  destroyPjrtError api err
||| Create a PJRT client.
|||
||| Fails with a `PjrtError` if `pjrt_client_create` reports an error. On success, the raw client
||| pointer is wrapped by `onCollectAny'` (defined elsewhere) with a `destroy` action that calls
||| `pjrt_client_destroy`. That action presumably runs when the pointer is garbage collected.
pjrtClientCreate : PjrtApi -> Pjrt PjrtClient
pjrtClientCreate (MkPjrtApi api) = do
  args <- primIO prim__mkPjrtClientCreateArgs
  err <- primIO $ prim__pjrtClientCreate api args
  let client = prim__pjrtClientCreateArgsClient args
  primIO $ prim__deletePjrtClientCreateArgs args
  -- Check for failure BEFORE attaching a finalizer, so we never schedule destruction of an
  -- invalid client pointer, and so the PJRT error object is released.
  try api err ()
  client <- onCollectAny' client destroy
  pure $ MkPjrtClient client

  where

  ||| Destroy the client via `pjrt_client_destroy`, logging (not failing on) any error.
  destroy : AnyPtr -> IO ()
  destroy client = do
    args <- primIO $ prim__mkPjrtClientDestroyArgs client
    err <- primIO $ prim__pjrtClientDestroy api args
    primIO $ prim__deletePjrtClientDestroyArgs args
    handleErrOnDestroy api err "PJRT_Client"
||| Opaque pointer to a PJRT program. It is not collectable; release it with `PjrtProgram.delete`.
data PjrtProgram : Type where
  MkPjrtProgram : AnyPtr -> PjrtProgram
-- Frees a program object (counterpart of `PJRT_Program_new`).
%foreign (ffi "PJRT_Program_delete")
prim__deletePjrtProgram : AnyPtr -> PrimIO ()
namespace PjrtProgram

  ||| Release the underlying program object via `PJRT_Program_delete`.
  delete : HasIO io => PjrtProgram -> io ()
  delete (MkPjrtProgram program) = primIO (prim__deletePjrtProgram program)
-- Builds a program object from a pointer to the code bytes and their length.
%foreign (ffi "PJRT_Program_new")
prim__mkPjrtProgram : Ptr Char -> Bits64 -> PrimIO AnyPtr
||| Wrap program code held in a `CppString` as a PJRT program.
|||
||| NOTE(review): the program is built from the string's data pointer and size. Whether
||| `PJRT_Program_new` copies those bytes is not visible here; confirm the `CppString` need not
||| outlive the returned program.
mkPjrtProgram : HasIO io => CppString -> io PjrtProgram
mkPjrtProgram (MkCppString code) = do
  program <- primIO $ prim__mkPjrtProgram (prim__stringData code) (prim__stringSize code)
  pure (MkPjrtProgram program)
-- Bindings for `PJRT_Client_Compile`: allocate args, call, read the executable, free args.
-- NOTE(review): `prim__pjrtClientCompileArgsExecutable` is declared pure but presumably reads a
-- field written by `pjrt_client_compile`; only evaluate it after that call.

%foreign (ffi "PJRT_Client_Compile_Args_new")
prim__mkPjrtClientCompileArgs : GCAnyPtr -> AnyPtr -> Ptr Char -> Bits64 -> PrimIO AnyPtr

%foreign (ffi "pjrt_client_compile")
prim__pjrtClientCompile : AnyPtr -> AnyPtr -> PrimIO AnyPtr

%foreign (ffi "PJRT_Client_Compile_Args_executable")
prim__pjrtClientCompileArgsExecutable : AnyPtr -> AnyPtr

%foreign (ffi "PJRT_Client_Compile_Args_delete")
prim__deletePjrtClientCompileArgs : AnyPtr -> PrimIO ()

-- Bindings for `PJRT_LoadedExecutable_Destroy`: allocate args, call, free args.

%foreign (ffi "PJRT_LoadedExecutable_Destroy_Args_new")
prim__mkPjrtLoadedExecutableDestroyArgs : AnyPtr -> PrimIO AnyPtr

%foreign (ffi "pjrt_loadedexecutable_destroy")
prim__pjrtLoadedExecutableDestroy : AnyPtr -> AnyPtr -> PrimIO AnyPtr

%foreign (ffi "PJRT_LoadedExecutable_Destroy_Args_delete")
prim__deletePjrtLoadedExecutableDestroyArgs : AnyPtr -> PrimIO ()
||| Opaque pointer to a compiled, loaded PJRT executable. It is not collectable; release it with
||| `pjrtLoadedExecutableDestroy`.
data PjrtLoadedExecutable : Type where
  MkPjrtLoadedExecutable : AnyPtr -> PjrtLoadedExecutable
||| Destroy a loaded executable via `pjrt_loadedexecutable_destroy`.
|||
||| Any error is logged as a warning and released rather than propagated.
pjrtLoadedExecutableDestroy : HasIO io => PjrtApi -> PjrtLoadedExecutable -> io ()
pjrtLoadedExecutableDestroy (MkPjrtApi api) (MkPjrtLoadedExecutable executable) = do
  destroyArgs <- primIO (prim__mkPjrtLoadedExecutableDestroyArgs executable)
  destroyErr <- primIO (prim__pjrtLoadedExecutableDestroy api destroyArgs)
  primIO (prim__deletePjrtLoadedExecutableDestroyArgs destroyArgs)
  handleErrOnDestroy api destroyErr "PJRT_LoadedExecutable"
||| Compile a PJRT program with the given client.
|||
||| The compile options are passed as the bytes (data pointer and size) of a `CppString`. Fails
||| with a `PjrtError` if `pjrt_client_compile` reports an error. The returned executable is not
||| collectable; release it with `pjrtLoadedExecutableDestroy`.
pjrtClientCompile : PjrtApi -> PjrtClient -> PjrtProgram -> CppString -> Pjrt PjrtLoadedExecutable
pjrtClientCompile
  (MkPjrtApi api) (MkPjrtClient client) (MkPjrtProgram program) (MkCppString options) = do
    args <- primIO $
      prim__mkPjrtClientCompileArgs
        client program (prim__stringData options) (prim__stringSize options)
    err <- primIO $ prim__pjrtClientCompile api args
    -- read the result before the args struct that holds it is freed
    let executable = prim__pjrtClientCompileArgsExecutable args
    primIO $ prim__deletePjrtClientCompileArgs args
    try api err $ MkPjrtLoadedExecutable executable
-- Execute options object: allocate, free.

%foreign (ffi "PJRT_ExecuteOptions_new")
prim__mkPjrtExecuteOptions : PrimIO AnyPtr

%foreign (ffi "PJRT_ExecuteOptions_delete")
prim__deletePjrtExecuteOptions : AnyPtr -> PrimIO ()

-- Bindings for `PJRT_LoadedExecutable_Execute`: allocate args, call, free args.

%foreign (ffi "PJRT_LoadedExecutable_Execute_Args_new")
prim__mkPjrtLoadedExecutableExecuteArgs : AnyPtr -> AnyPtr -> AnyPtr -> PrimIO AnyPtr

%foreign (ffi "pjrt_loadedexecutable_execute")
prim__pjrtLoadedExecutableExecute : AnyPtr -> AnyPtr -> PrimIO AnyPtr

%foreign (ffi "PJRT_LoadedExecutable_Execute_Args_delete")
prim__deletePjrtLoadedExecutableExecuteArgs : AnyPtr -> PrimIO ()

-- Bindings for `PJRT_Buffer_Destroy`: allocate args, call, free args.

%foreign (ffi "PJRT_Buffer_Destroy_Args_new")
prim__mkPjrtBufferDestroyArgs : AnyPtr -> PrimIO AnyPtr

%foreign (ffi "pjrt_buffer_destroy")
prim__pjrtBufferDestroy : AnyPtr -> AnyPtr -> PrimIO AnyPtr

%foreign (ffi "PJRT_Buffer_Destroy_Args_delete")
prim__deletePjrtBufferDestroyArgs : AnyPtr -> PrimIO ()
||| Opaque pointer to a PJRT buffer. It is not collectable; release it with `pjrtBufferDestroy`.
data PjrtBuffer : Type where
  MkPjrtBuffer : AnyPtr -> PjrtBuffer
||| Destroy a PJRT buffer via `pjrt_buffer_destroy`.
|||
||| Any error is logged as a warning and released rather than propagated.
pjrtBufferDestroy : HasIO io => PjrtApi -> PjrtBuffer -> io ()
pjrtBufferDestroy (MkPjrtApi api) (MkPjrtBuffer buffer) = do
  destroyArgs <- primIO (prim__mkPjrtBufferDestroyArgs buffer)
  destroyErr <- primIO (prim__pjrtBufferDestroy api destroyArgs)
  primIO (prim__deletePjrtBufferDestroyArgs destroyArgs)
  handleErrOnDestroy api destroyErr "PJRT_Buffer"
||| Execute a loaded executable, collecting `outputs` output buffers.
|||
||| A single output list (one array of `outputs` pointers) is passed to
||| `pjrt_loadedexecutable_execute`, and the buffer handles it fills in are returned. Fails with a
||| `PjrtError` if execution reports an error.
|||
||| NOTE(review): no finalizer is attached to the returned buffers here; presumably callers
||| release them with `pjrtBufferDestroy`.
pjrtLoadedExecutableExecute :
  PjrtApi -> PjrtLoadedExecutable -> (outputs : Nat) -> Pjrt (Vect outputs PjrtBuffer)
pjrtLoadedExecutableExecute (MkPjrtApi api) (MkPjrtLoadedExecutable executable) outputs = do
  outputListsInner <- malloc (cast outputs * cast sizeofPtr)
  outputLists <- malloc $ cast sizeofPtr
  primIO $ prim__setArrayPtr outputLists 0 outputListsInner
  options <- primIO prim__mkPjrtExecuteOptions
  args <- primIO $ prim__mkPjrtLoadedExecutableExecuteArgs executable options outputLists
  err <- primIO $ prim__pjrtLoadedExecutableExecute api args
  primIO $ prim__deletePjrtLoadedExecutableExecuteArgs args
  primIO $ prim__deletePjrtExecuteOptions options
  -- read the buffer handles out before releasing the arrays that hold them
  let buffers = map (\o => MkPjrtBuffer $ prim__index (cast o) outputListsInner) (range outputs)
  -- both host-side arrays were malloc'd above; previously `outputLists` was never freed
  free outputLists
  free outputListsInner
  try api err buffers
-- Bindings for `PJRT_Buffer_ToHostBuffer`: allocate args, call, read the event, free args.
-- NOTE(review): `prim__pjrtBufferToHostBufferArgsEvent` is declared pure but presumably reads a
-- field written by `pjrt_buffer_tohostbuffer`; only evaluate it after that call.

%foreign (ffi "PJRT_Buffer_ToHostBuffer_Args_new")
prim__mkPjrtBufferToHostBufferArgs : AnyPtr -> AnyPtr -> Int -> PrimIO AnyPtr

%foreign (ffi "pjrt_buffer_tohostbuffer")
prim__pjrtBufferToHostBuffer : AnyPtr -> AnyPtr -> PrimIO AnyPtr

%foreign (ffi "PJRT_Buffer_ToHostBuffer_Args_event")
prim__pjrtBufferToHostBufferArgsEvent : AnyPtr -> AnyPtr

%foreign (ffi "PJRT_Buffer_ToHostBuffer_Args_delete")
prim__deletePjrtBufferToHostBufferArgs : AnyPtr -> PrimIO ()
||| Destroy a PJRT event via `pjrt_event_destroy`.
|||
||| Any error is logged as a warning and released rather than propagated.
pjrtEventDestroy : HasIO io => PjrtApi -> PjrtEvent -> io ()
pjrtEventDestroy (MkPjrtApi api) (MkPjrtEvent event) = do
  destroyArgs <- primIO (prim__mkPjrtEventDestroyArgs event)
  destroyErr <- primIO (prim__pjrtEventDestroy api destroyArgs)
  primIO (prim__deletePjrtEventDestroyArgs destroyArgs)
  handleErrOnDestroy api destroyErr "PJRT_Event"
||| Request a copy of a PJRT buffer into the memory backing a `Literal`.
|||
||| The destination is the literal's untyped data pointer, with its size in bytes. Returns the
||| event produced by `pjrt_buffer_tohostbuffer`, or fails with a `PjrtError`.
|||
||| NOTE(review): presumably the copy is only complete once the returned event has been awaited
||| with `pjrtEventAwait`; confirm before reading the literal.
pjrtBufferToHostBuffer : PjrtApi -> PjrtBuffer -> Literal -> Pjrt PjrtEvent
pjrtBufferToHostBuffer (MkPjrtApi api) (MkPjrtBuffer buffer) (MkLiteral literal) = do
  toHostArgs <- primIO $ prim__mkPjrtBufferToHostBufferArgs
    buffer (prim__literalUntypedData literal) (prim__literalSizeBytes literal)
  err <- primIO (prim__pjrtBufferToHostBuffer api toHostArgs)
  let event = prim__pjrtBufferToHostBufferArgsEvent toHostArgs
  primIO (prim__deletePjrtBufferToHostBufferArgs toHostArgs)
  try api err (MkPjrtEvent event)