17 | module Compiler.Xla.ShapeUtil
20 | import Compiler.Xla.Shape
21 | import Compiler.Xla.XlaData
||| Build the foreign-function specifier for a symbol declared in
||| `c/xla/shape_util.h`.
|||
||| Forwards the fixed header path and the given symbol name to `libxla`;
||| the exact format of the resulting string is whatever `libxla` produces.
|||
||| @ symbol name of the C function to bind, e.g. "MakeShape"
ffi : String -> String
ffi symbol = libxla "c/xla/shape_util.h" symbol
||| Raw foreign binding for the `MakeShape` symbol resolved through `ffi`.
|||
||| Returns an untyped pointer inside `PrimIO`. Nothing here manages its
||| lifetime; `mkShape` wraps the result and registers `Shape.delete` for it.
%foreign (ffi "MakeShape")
prim__mkShape :
  -- element-type identifier (`mkShape` passes `xlaIdent dtype`)
  Int ->
  -- pointer to the Int64 array of dimension sizes
  GCPtr Int64 ->
  -- length accompanying that array (second field of `MkInt64Array`)
  Bits64 ->
  PrimIO AnyPtr
||| Create an XLA shape from a list of dimension sizes and an element type.
|||
||| The dimensions are cast element-wise and packed with `mkInt64Array`.
||| The native `MakeShape` binding is then called with the identifier from
||| `xlaIdent dtype`. The returned pointer is handed to `onCollectAny'` with
||| `Shape.delete` before being wrapped in `MkShape`.
|||
||| @ shape the dimension sizes of the shape to build
||| @ dtype the element type, translated with `xlaIdent`
mkShape : HasIO io => Types.Shape -> DType -> io Xla.Shape
mkShape shape dtype = do
  -- Distinct names keep the packed array apart from the `shape` argument.
  MkInt64Array dimsPtr dimsLen <- mkInt64Array (map cast shape)
  rawPtr <- primIO (prim__mkShape (xlaIdent dtype) dimsPtr dimsLen)
  -- NOTE(review): presumably this registers `Shape.delete` as a GC finaliser;
  -- confirm against the definition of `onCollectAny'`.
  managedPtr <- onCollectAny' rawPtr Shape.delete
  pure (MkShape managedPtr)