0 | {--
 1 | Copyright (C) 2022  Joel Berkeley
 2 |
 3 | This program is free software: you can redistribute it and/or modify
 4 | it under the terms of the GNU Affero General Public License as published
 5 | by the Free Software Foundation, either version 3 of the License, or
 6 | (at your option) any later version.
 7 |
 8 | This program is distributed in the hope that it will be useful,
 9 | but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11 | GNU Affero General Public License for more details.
12 |
13 | You should have received a copy of the GNU Affero General Public License
14 | along with this program.  If not, see <https://www.gnu.org/licenses/>.
15 | --}
16 | ||| For internal spidr use only.
17 | module Compiler.LiteralRW
18 |
19 | import Compiler.Xla.XlaData
20 | import public Compiler.Xla.Literal
21 | import Compiler.Xla.Shape
22 | import Compiler.Xla.ShapeUtil
23 | import Literal
24 | import Util
25 |
26 | range : (n : Nat) -> Literal [n] Nat
27 | range n = impl n []
28 |   where
29 |   impl : (p : Nat) -> Literal [q] Nat -> Literal [q + p] Nat
30 |   impl Z xs = rewrite plusZeroRightNeutral q in xs
31 |   impl (S p) xs = rewrite sym $ plusSuccRightSucc q p in impl p (Scalar p :: xs)
32 |
33 | indexed : {shape : _} -> Literal shape (List Nat)
34 | indexed = go shape []
35 |   where
36 |   concat : Literal [d] (Literal ds a) -> Literal (d :: ds) a
37 |   concat [] = []
38 |   concat (Scalar x :: xs) = x :: concat xs
39 |
40 |   go : (shape : Types.Shape) -> List Nat -> Literal shape (List Nat)
41 |   go [] idxs = Scalar idxs
42 |   go (0 :: _) _ = []
43 |   go (S d :: ds) idxs = concat $ map (\i => go ds (snoc idxs i)) (range (S d))
44 |
45 | public export
46 | interface Primitive dtype => LiteralRW dtype ty where
47 |   set : Literal -> List Nat -> ty -> IO ()
48 |   get : Literal -> List Nat -> ty
49 |
50 | export
51 | write : HasIO io => LiteralRW dtype a => {shape : _} -> Literal shape a -> io Literal
52 | write xs = liftIO $ do
53 |   shape <- mkShape {dtype} shape
54 |   literal <- allocLiteral shape
55 |   let set = set {dtype} literal
56 |   sequence_ [| set indexed xs |]
57 |   pure literal
58 |
59 | export
60 | read : LiteralRW dtype a => HasIO io => {shape : _} -> Literal -> io $ Literal shape a
61 | read lit = pure $ get {dtype} lit <$> indexed {shape}
62 |
63 | export
64 | LiteralRW PRED Bool where
65 |   set = set
66 |   get = get
67 |
68 | export
69 | LiteralRW F64 Double where
70 |   set = set
71 |   get = get
72 |
73 | export
74 | LiteralRW S32 Int32 where
75 |   set = set
76 |   get = get
77 |
78 | export
79 | LiteralRW U32 Nat where
80 |   set = UInt32t.set
81 |   get = UInt32t.get
82 |
83 | export
84 | LiteralRW U64 Nat where
85 |   set = UInt64t.set
86 |   get = UInt64t.get
87 |