17 | module Compiler.LiteralRW
19 | import Compiler.Xla.XlaData
20 | import public Compiler.Xla.Literal
21 | import Compiler.Xla.Shape
22 | import Compiler.Xla.ShapeUtil
||| The rank-1 literal `[0, 1, ..., n - 1]`, assuming the elided body
||| (listing lines 27-28, not shown in this excerpt) is `range n = impl n []`
||| with `impl` defined in a `where` block.
26 | range : (n : Nat) -> Literal [n] Nat
-- `impl p xs` prepends `Scalar 0, ..., Scalar (p - 1)`, in ascending order,
-- to `xs`. Each step conses `Scalar p` and then recurses on the smaller `p`,
-- so smaller values end up nearer the front.
-- The `rewrite`s only repair the length index so the result type checks:
-- `q + 0 = q` in the base case, and `q + S p = S (q + p)` (the `sym` of
-- `plusSuccRightSucc q p`) in the step case.
29 | impl : (p : Nat) -> Literal [q] Nat -> Literal [q + p] Nat
30 | impl Z xs = rewrite plusZeroRightNeutral q in xs
31 | impl (S p) xs = rewrite sym $
plusSuccRightSucc q p in impl p (Scalar p :: xs)
||| A literal of the implicit `shape` in which every element is its own
||| multi-index. Assuming `range (S d)` yields `0, ..., d`, for shape `[2, 3]`
||| the element at position `[i, j]` is the list `[i, j]`.
|||
||| NOTE(review): several listing lines are elided from this excerpt,
||| including the `where` and, presumably, the empty-vector case of `concat`
||| (line 37) and a zero-length-dimension case of `go` (line 42).
33 | indexed : {shape : _} -> Literal shape (List Nat)
34 | indexed = go shape []
-- Flattens a vector of `d` literals, each of shape `ds`, into a single
-- literal of shape `d :: ds` by unwrapping each `Scalar` and consing it.
36 | concat : Literal [d] (Literal ds a) -> Literal (d :: ds) a
38 | concat (Scalar x :: xs) = x :: concat xs
-- `go shape idxs` walks the dimensions from left to right. `idxs`
-- accumulates, via `snoc`, the index chosen along each dimension visited so
-- far, and each leaf stores the accumulated list.
40 | go : (shape : Types.Shape) -> List Nat -> Literal shape (List Nat)
41 | go [] idxs = Scalar idxs
43 | go (S d :: ds) idxs = concat $
map (\i => go ds (snoc idxs i)) (range (S d))
||| Element types `ty` that can be stored in, and loaded from, an unshaped
||| `Literal` whose XLA primitive element type is `dtype`. Presumably this
||| `Literal` is the handle re-exported by `Compiler.Xla.Literal`.
|||
||| Both methods address a single element by its multi-index, in the form
||| produced by `indexed` (see how `write` and `read` use them).
46 | interface Primitive dtype => LiteralRW dtype ty where
47 | set : Literal -> List Nat -> ty -> IO ()
-- NOTE(review): unlike `set`, `get` is not in `IO`, even though it
-- presumably reads from the literal. Callers must ensure every write to the
-- literal has already happened.
48 | get : Literal -> List Nat -> ty
||| Copy a shaped literal into a newly allocated unshaped `Literal` with
||| element type `dtype`, calling `set` once per element at its multi-index.
|||
||| NOTE(review): the tail of this definition (listing lines 57-59) is elided
||| from this excerpt. Presumably it returns `literal`, matching the
||| `io Literal` result type.
51 | write : HasIO io => LiteralRW dtype a => {shape : _} -> Literal shape a -> io Literal
52 | write xs = liftIO $
do
-- `shape` is rebound on this line. From here on it names the value returned
-- by `mkShape` (presumably an XLA shape), not the implicit dimension list.
53 | shape <- mkShape {dtype} shape
54 | literal <- allocLiteral shape
-- Shadows the interface method with a version pinned to `dtype` and
-- partially applied to the freshly allocated `literal`.
55 | let set = set {dtype} literal
-- Idiom brackets: `set <$> indexed <*> xs`. Assuming `Literal`'s
-- Applicative is elementwise, this pairs each element of `xs` with its
-- multi-index, yielding one `IO ()` per element; `sequence_` runs them all.
56 | sequence_ [| set indexed xs |]
||| Build a shaped literal of `a`s from an unshaped `Literal` by calling `get`
||| at every multi-index of the implicit `shape`.
|||
||| NOTE(review): `get` is pure, so none of the element reads are sequenced
||| as IO effects; the `io` result is just `pure` of the mapped literal.
||| Callers must make sure `lit` has been fully written before calling this.
60 | read : LiteralRW dtype a => HasIO io => {shape : _} -> Literal -> io $
Literal shape a
61 | read lit = pure $
get {dtype} lit <$> indexed {shape}
64 | LiteralRW PRED Bool where
69 | LiteralRW F64 Double where
74 | LiteralRW S32 Int32 where
79 | LiteralRW U32 Nat where
84 | LiteralRW U64 Nat where