17 | module Compiler.LiteralRW
19 | import Compiler.MLIR.IR.MLIRContext
20 | import Compiler.MLIR.IR.Types
21 | import Compiler.MLIR.IR.BuiltinTypes
22 | import Compiler.Xla.XlaData
23 | import public Compiler.Xla.Literal
24 | import Compiler.Xla.Shape
25 | import Compiler.Xla.ShapeUtil
-- Builds a rank-1 literal of `n` natural numbers.
-- Only the signature and the `impl` helper appear in this listing; the clause
-- that seeds `impl` is not shown.
-- NOTE(review): presumably the result is 0 .. n-1 in ascending order, because
-- `impl` prepends `Scalar p` while counting `p` down -- confirm against the
-- missing clause.
31 | range : (n : Nat) -> Literal [n] Nat
-- `impl p xs` extends an accumulator `xs` of length `q` by `p` more elements.
34 | impl : (p : Nat) -> Literal [q] Nat -> Literal [q + p] Nat
-- Nothing left to add: the rewrite replaces `q + 0` with `q` in the goal so
-- that `xs` itself has the required type.
35 | impl Z xs = rewrite plusZeroRightNeutral q in xs
-- Prepend `Scalar p` and recurse on `p`. The rewrite turns the goal length
-- `q + S p` into `S (q + p)`, which is the length of the recursive result.
36 | impl (S p) xs = rewrite sym $
plusSuccRightSucc q p in impl p (Scalar p :: xs)
-- A literal of the (implicit) `shape` whose elements are lists of naturals,
-- produced by `go shape []`.
-- NOTE(review): presumably each element ends up holding its own
-- multi-dimensional index, since `go` appends one index per dimension --
-- confirm.
38 | indexed : {shape : _} -> Literal shape (List Nat)
39 | indexed = go shape []
-- Merges a rank-1 literal of sub-literals into one literal with an extra
-- leading dimension.
-- NOTE(review): only the non-empty clause is visible in this listing.
41 | concat : Literal [d] (Literal ds a) -> Literal (d :: ds) a
43 | concat (Scalar x :: xs) = x :: concat xs
-- `go shape idxs` builds the literal for `shape`; `idxs` is the index prefix
-- accumulated from the dimensions already traversed.
45 | go : (shape : Types.Shape) -> List Nat -> Literal shape (List Nat)
-- No dimensions left: the accumulated prefix is the element itself.
46 | go [] idxs = Scalar idxs
-- Non-zero leading dimension: for every `i` in `range (S d)`, recurse on the
-- remaining dimensions with `i` appended to the prefix, then merge the
-- resulting sub-literals with `concat`.
-- NOTE(review): no clause for a zero-sized leading dimension is visible here.
48 | go (S d :: ds) idxs = concat $
map (\i => go ds (snoc idxs i)) (range (S d))
-- Dispatches on `dtype` so that the value argument has the concrete type
-- `idrisType dtype` in each clause.
-- The `set` on each right-hand side is applied to three arguments, whereas
-- this function takes four, so it is a different `set` from this one.
-- NOTE(review): presumably that is a per-type element setter defined elsewhere
-- (likely reached via the imports) which stores `x` at index `idx` in `lit` --
-- confirm.
-- Clauses are listed for PRED, S32, S64, U32, U64 and F64 only.
51 | set : (dtype : DType) -> Literal -> Int64Array -> idrisType dtype -> IO ()
52 | set PRED lit idx x = set lit idx x
53 | set S32 lit idx x = set lit idx x
54 | set S64 lit idx x = set lit idx x
55 | set U32 lit idx x = set lit idx x
56 | set U64 lit idx x = set lit idx x
57 | set F64 lit idx x = set lit idx x
-- Dispatches on `dtype` so that the result has the concrete type
-- `idrisType dtype` in each clause.
-- The `get` on each right-hand side is applied to two arguments, whereas this
-- function takes three, so it is a different `get` from this one.
-- NOTE(review): presumably that is a per-type element getter defined elsewhere
-- which returns the element of `lit` at index `idx` -- confirm.
-- Note that the result type is not wrapped in `IO`.
-- Clauses are listed for PRED, S32, S64, U32, U64 and F64 only.
60 | get : (dtype : DType) -> Literal -> Int64Array -> idrisType dtype
61 | get PRED lit idx = get lit idx
62 | get S32 lit idx = get lit idx
63 | get S64 lit idx = get lit idx
64 | get U32 lit idx = get lit idx
65 | get U64 lit idx = get lit idx
66 | get F64 lit idx = get lit idx
-- Takes a shape-indexed literal `xs` of `idrisType dtype` values and yields a
-- bare (unindexed) `Literal` in any `HasIO` monad, by lifting an `IO` block.
-- NOTE(review): the final statement that produces the returned literal is not
-- visible in this listing; presumably it returns `literal` -- confirm.
69 | write : HasIO io => {shape : _} -> (dtype : DType) -> Literal shape (idrisType dtype) -> io Literal
70 | write dtype xs = liftIO $
do
-- From here on `shape` names the result of `mkShape`, shadowing the implicit
-- `shape` argument that was passed to it.
71 | shape <- mkShape shape dtype
72 | literal <- allocLiteral shape
-- Turn every index list in `indexed` into an `Int64Array`, casting each entry.
73 | idxs <- traverse (mkInt64Array . map cast) indexed
-- Local `set` shadows the top-level one, with `dtype` and `literal` applied.
74 | let set = set dtype literal
-- Idiom brackets combine `idxs` and `xs` through the literal's applicative
-- structure, and `sequence_` then runs the resulting `IO` actions.
-- NOTE(review): assumes that applicative pairs elements position by position
-- -- confirm.
75 | sequence_ [| set idxs xs |]
-- Takes a bare (unindexed) `Literal` and yields a literal of the implicit
-- `shape` holding `idrisType dtype` values, in any `HasIO` monad.
-- NOTE(review): the clause head is not visible in this listing; `lit` below is
-- presumably the `Literal` argument bound there -- confirm.
79 | read : HasIO io => {shape : _} -> (dtype : DType) -> Literal -> io $
Literal shape (idrisType dtype)
-- Turn every index list in `indexed` into an `Int64Array`, casting each entry.
81 | idxs <- traverse (mkInt64Array . map cast) indexed
-- Map `get dtype lit` over the index arrays to build the result literal.
82 | pure $
get dtype lit <$> idxs