17 | module Compiler.Xla.ShapeUtil
20 | import Compiler.Xla.Shape
21 | import Compiler.Xla.XlaData
-- Maps a foreign function name to the specifier string passed to `%foreign`
-- in this module, by partially applying `libxla` to the header path
-- "c/xla/shape_util.h".
-- NOTE(review): `libxla` is defined outside this listing; presumably it
-- combines the header path and the symbol name into a C FFI specifier for the
-- XLA wrapper library -- confirm against its definition.
24 | ffi : String -> String
25 | ffi = libxla "c/xla/shape_util.h"
-- Raw foreign binding for the symbol named "MakeShape", resolved through
-- `ffi`. As called by `mkShape` in this file, the arguments are, in order:
--   1. the `Int` produced by `xlaIdentifier {dtype}`,
--   2. the `GCPtr Int` obtained from `mkIntArray shape`,
--   3. the length of `shape`, cast to `Int`.
-- The result is an untyped `AnyPtr` produced inside `PrimIO`.
-- NOTE(review): ownership and lifetime of the returned pointer are decided on
-- the C side, which is not visible here; `mkShape` attaches `Shape.delete` to it.
27 | %foreign (ffi "MakeShape")
28 | prim__mkShape : Int -> GCPtr Int -> Int -> PrimIO AnyPtr
-- Builds an `Xla.Shape` from a `Types.Shape` and an element type `dtype`.
-- `dtype` is supplied through the `Primitive dtype` constraint rather than as
-- an explicit argument; the action runs in any `io` with a `HasIO` instance.
--
-- NOTE(review): the gutter numbering jumps from 31 to 33, so the clause head
-- that binds `shape` and opens the `do` block is not visible in this listing.
-- The statements below are documented as they appear.
31 | mkShape : (HasIO io, Primitive dtype) => Types.Shape -> io Xla.Shape
-- Integer identifier for `dtype`, as given by `xlaIdentifier`.
33 | let dtypeEnum = xlaIdentifier {dtype}
-- Marshal `shape` into an int array and take the pointer out of `MkIntArray`.
34 | MkIntArray shapeArrayPtr <- mkIntArray shape
-- Call the foreign constructor with the dtype identifier, the array pointer
-- and the number of entries in `shape` (cast to `Int`).
35 | shapePtr <- primIO $
prim__mkShape dtypeEnum shapeArrayPtr (cast $
length shape)
-- Rebind `shapePtr` to the result of pairing the raw pointer with `Shape.delete`.
-- NOTE(review): presumably `onCollectAny'` registers `Shape.delete` as a
-- finaliser that runs when the pointer is garbage collected -- confirm against
-- its definition, which is outside this listing.
36 | shapePtr <- onCollectAny' shapePtr Shape.delete
-- Wrap the resulting pointer in the `MkShape` constructor.
37 | pure (MkShape shapePtr)