0 | {--
 1 | Copyright (C) 2022  Joel Berkeley
 2 |
 3 | This program is free software: you can redistribute it and/or modify
 4 | it under the terms of the GNU Affero General Public License as published
 5 | by the Free Software Foundation, either version 3 of the License, or
 6 | (at your option) any later version.
 7 |
 8 | This program is distributed in the hope that it will be useful,
 9 | but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11 | GNU Affero General Public License for more details.
12 |
13 | You should have received a copy of the GNU Affero General Public License
14 | along with this program.  If not, see <https://www.gnu.org/licenses/>.
15 | --}
16 | ||| For internal spidr use only.
17 | module Compiler.LiteralRW
18 |
19 | import Compiler.MLIR.IR.MLIRContext
20 | import Compiler.MLIR.IR.Types
21 | import Compiler.MLIR.IR.BuiltinTypes
22 | import Compiler.Xla.XlaData
23 | import public Compiler.Xla.Literal
24 | import Compiler.Xla.Shape
25 | import Compiler.Xla.ShapeUtil
26 | import Compiler.FFI
27 | import Literal
28 | import DType
29 | import Util
30 |
31 | range : (n : Nat) -> Literal [n] Nat
32 | range n = impl n []
33 |   where
34 |   impl : (p : Nat) -> Literal [q] Nat -> Literal [q + p] Nat
35 |   impl Z xs = rewrite plusZeroRightNeutral q in xs
36 |   impl (S p) xs = rewrite sym $ plusSuccRightSucc q p in impl p (Scalar p :: xs)
37 |
38 | indexed : {shape : _} -> Literal shape (List Nat)
39 | indexed = go shape []
40 |   where
41 |   concat : Literal [d] (Literal ds a) -> Literal (d :: ds) a
42 |   concat [] = []
43 |   concat (Scalar x :: xs) = x :: concat xs
44 |
45 |   go : (shape : Types.Shape) -> List Nat -> Literal shape (List Nat)
46 |   go [] idxs = Scalar idxs
47 |   go (0 :: _) _ = []
48 |   go (S d :: ds) idxs = concat $ map (\i => go ds (snoc idxs i)) (range (S d))
49 |
50 | export
51 | set : (dtype : DType) -> Literal -> Int64Array -> idrisType dtype -> IO ()
52 | set PRED lit idx x = set lit idx x
53 | set S32  lit idx x = set lit idx x
54 | set S64  lit idx x = set lit idx x
55 | set U32  lit idx x = set lit idx x
56 | set U64  lit idx x = set lit idx x
57 | set F64  lit idx x = set lit idx x
58 |
59 | export
60 | get : (dtype : DType) -> Literal -> Int64Array -> idrisType dtype
61 | get PRED lit idx = get lit idx
62 | get S32  lit idx = get lit idx
63 | get S64  lit idx = get lit idx
64 | get U32  lit idx = get lit idx
65 | get U64  lit idx = get lit idx
66 | get F64  lit idx = get lit idx
67 |
68 | export
69 | write : HasIO io => {shape : _} -> (dtype : DType) -> Literal shape (idrisType dtype) -> io Literal
70 | write dtype xs = liftIO $ do
71 |   shape <- mkShape shape dtype
72 |   literal <- allocLiteral shape
73 |   idxs <- traverse (mkInt64Array . map cast) indexed
74 |   let set = set dtype literal
75 |   sequence_ [| set idxs xs |]
76 |   pure literal
77 |
78 | export
79 | read : HasIO io => {shape : _} -> (dtype : DType) -> Literal -> io $ Literal shape (idrisType dtype)
80 | read dtype lit = do
81 |   idxs <- traverse (mkInt64Array . map cast) indexed
82 |   pure $ get dtype lit <$> idxs
83 |