0 | {--
  1 | Copyright (C) 2022  Joel Berkeley
  2 |
  3 | This program is free software: you can redistribute it and/or modify
  4 | it under the terms of the GNU Affero General Public License as published
  5 | by the Free Software Foundation, either version 3 of the License, or
  6 | (at your option) any later version.
  7 |
  8 | This program is distributed in the hope that it will be useful,
  9 | but WITHOUT ANY WARRANTY; without even the implied warranty of
 10 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 11 | GNU Affero General Public License for more details.
 12 |
 13 | You should have received a copy of the GNU Affero General Public License
 14 | along with this program.  If not, see <https://www.gnu.org/licenses/>.
 15 | --}
 16 | ||| Defines `Literal`, a single value or array of values with a specified shape.
 17 | ||| `Literal` is similar to `Tensor`, but differs in a number of important ways:
 18 | |||
 19 | ||| * `Literal` offers a convenient syntax for constructing `Literal`s with boolean and numeric
 20 | |||   contents. For example, `True`, `1` and `[1, 2, 3]` are all valid `Literal`s. This makes it
 21 | |||   useful for constructing `Tensor`s.
 22 | ||| * Operations on `Literal` are *not* accelerated by a graph compiler, so operations on large
 23 | |||   `Literal`s, and large sequences of operations on any `Literal`, can be expected to be slower
 24 | |||   than they would on an equivalent `Tensor`.
 25 | ||| * `Literal` is implemented in pure Idris. As such, values can contain elements of any type, and
 26 | |||   implements a number of standard Idris interfaces. This, along with its convenient syntax,
 27 | |||   makes it particularly useful for testing operations on `Tensor`s.
 28 | module Literal
 29 |
 30 | import public Types
 31 |
 32 | ||| A scalar or array of values.
 33 | public export
 34 | data Literal : Shape -> Type -> Type where
 35 |   Scalar : a -> Literal [] a
 36 |   Nil : Literal (0 :: ds) a
 37 |   (::) : Literal ds a -> Literal (d :: ds) a -> Literal (S d :: ds) a
 38 |
 39 | export
 40 | fromInteger : Integer -> Literal [] Int32
 41 | fromInteger = Scalar . cast {to = Int32}
 42 |
 43 | export
 44 | fromDouble : Double -> Literal [] Double
 45 | fromDouble = Scalar
 46 |
 47 | ||| Convenience aliases for scalar boolean literals.
 48 | export
 49 | True, False : Literal [] Bool
 50 | True = Scalar True
 51 | False = Scalar False
 52 |
 53 | export
 54 | Functor (Literal shape) where
 55 |   map f (Scalar x) = Scalar (f x)
 56 |   map _ [] = []
 57 |   map f (x :: xs) = map f x :: map f xs
 58 |
 59 | functorIdentity : (xs : Literal shape a) -> map Prelude.id xs = xs
 60 | functorIdentity (Scalar _) = Refl
 61 | functorIdentity [] = Refl
 62 | functorIdentity (x :: xs) = cong2 (::) (functorIdentity x) (functorIdentity xs)
 63 |
 64 | functorComposition :
 65 |   (xs : Literal shape a) -> (f : a -> b) -> (g : b -> c) -> map (g . f) xs = map g (map f xs)
 66 | functorComposition (Scalar _) _ _ = Refl
 67 | functorComposition [] _ _ = Refl
 68 | functorComposition (x :: xs) f g =
 69 |   cong2 (::) (functorComposition x f g) (functorComposition xs f g)
 70 |
 71 | export
 72 | {shape : _} -> Applicative (Literal shape) where
 73 |   pure x = case shape of
 74 |     [] => Scalar x
 75 |     (0 :: _) => []
 76 |     (S d :: ds) => pure x :: assert_total (pure x)
 77 |
 78 |   (Scalar f) <*> (Scalar x) = Scalar (f x)
 79 |   [] <*> [] = []
 80 |   (f :: fs) <*> (x :: xs) = (f <*> x) :: (fs <*> xs)
 81 |
 82 | applicativeIdentity : (xs : Literal shape a) -> pure Prelude.id <*> xs = xs
 83 | applicativeIdentity (Scalar _) = Refl
 84 | applicativeIdentity [] = Refl
 85 | applicativeIdentity (x :: xs) = cong2 (::) (applicativeIdentity x) (applicativeIdentity xs)
 86 |
 87 | applicativeHomomorphism :
 88 |   (shape : Shape) -> (f : a -> b) -> (x : a) -> pure f <*> pure x = pure {f = Literal shape} (f x)
 89 | applicativeHomomorphism [] _ _ = Refl
 90 | applicativeHomomorphism (d :: ds) f x = forCons d ds f x
 91 |   where
 92 |   forCons :
 93 |     (d : Nat) ->
 94 |     (ds : Shape) ->
 95 |     (f : a -> b) ->
 96 |     (x : a) ->
 97 |     pure f <*> pure x = pure {f = Literal (d :: ds)} (f x)
 98 |   forCons 0 _ _ _ = Refl
 99 |   forCons (S d) ds f x = cong2 (::) (applicativeHomomorphism ds f x) (forCons d ds f x)
100 |
101 | applicativeInterchange :
102 |   (fs : Literal shape (a -> b)) -> (x : a) -> fs <*> pure x = pure (x) <*> fs
103 | applicativeInterchange (Scalar _) _ = Refl
104 | applicativeInterchange [] _ = Refl
105 | applicativeInterchange (f :: fs) x =
106 |   cong2 (::) (applicativeInterchange f x) (applicativeInterchange fs x)
107 |
108 | applicativeComposition :
109 |   (xs : Literal shape a) ->
110 |   (fs : Literal shape (a -> b)) ->
111 |   (gs : Literal shape (b -> c)) ->
112 |   pure (.) <*> gs <*> fs <*> xs = gs <*> (fs <*> xs)
113 | applicativeComposition (Scalar _) (Scalar _) (Scalar _) = Refl
114 | applicativeComposition [] [] [] = Refl
115 | applicativeComposition (x :: xs) (f :: fs) (g :: gs) =
116 |   cong2 (::) (applicativeComposition x f g) (applicativeComposition xs fs gs)
117 |
118 | export
119 | Foldable (Literal shape) where
120 |   foldr f acc (Scalar x) = f x acc
121 |   foldr _ acc [] = acc
122 |   foldr f acc (x :: y) = foldr f (foldr f acc y) x
123 |
124 | export
125 | Traversable (Literal shape) where
126 |   traverse f (Scalar x) = [| Scalar (f x) |]
127 |   traverse f [] = pure []
128 |   traverse f (x :: xs) = [| traverse f x :: traverse f xs |]
129 |
130 | export
131 | Zippable (Literal shape) where
132 |   zipWith f (Scalar x) (Scalar y) = Scalar (f x y)
133 |   zipWith _ [] [] = []
134 |   zipWith f (x :: xs) (y :: ys) = zipWith f x y :: zipWith f xs ys
135 |
136 |   zipWith3 f (Scalar x) (Scalar y) (Scalar z) = Scalar (f x y z)
137 |   zipWith3 _ [] [] [] = []
138 |   zipWith3 f (x :: xs) (y :: ys) (z :: zs) = zipWith3 f x y z :: zipWith3 f xs ys zs
139 |
140 |   unzipWith f (Scalar x) = let (x, y) = f x in (Scalar x, Scalar y)
141 |   unzipWith _ [] = ([], [])
142 |   unzipWith f (x :: xs) =
143 |     let (x, y) = unzipWith f x
144 |         (xs, ys) = unzipWith f xs
145 |      in (x :: xs, y :: ys)
146 |
147 |   unzipWith3 f (Scalar x) = let (x, y, z) = f x in (Scalar x, Scalar y, Scalar z)
148 |   unzipWith3 _ [] = ([], [], [])
149 |   unzipWith3 f (x :: xs) =
150 |     let (x, y, z) = unzipWith3 f x
151 |         (xs, ys, zs) = unzipWith3 f xs
152 |      in (x :: xs, y :: ys, z :: zs)
153 |
154 | ||| `True` if no elements are `False`. `all []` is `True`.
155 | export
156 | all : Literal shape Bool -> Bool
157 | all xs = foldr (\x, y => x && y) True xs
158 |
159 | export
160 | Num a => Num (Literal [] a) where
161 |   x + y = [| x + y |]
162 |   x * y = [| x * y |]
163 |   fromInteger = Scalar . fromInteger
164 |
165 | export
166 | negate : Neg a => Literal shape a -> Literal shape a
167 | negate = map negate
168 |
169 | export
170 | Eq a => Eq (Literal shape a) where
171 |   x == y = all (zipWith (==) x y)
172 |
173 | toVect : Literal (d :: ds) a -> Vect d (Literal ds a)
174 | toVect [] = []
175 | toVect (x :: y) = x :: toVect y
176 |
177 | ||| Show the `Literal`. The `Scalar` constructor is omitted for brevity.
178 | export
179 | {shape : _} -> Show a => Show (Literal shape a) where
180 |   show = showWithIndent "" where
181 |     showWithIndent : {shape : _} -> String -> Literal shape a -> String
182 |     showWithIndent _ (Scalar x) = show x
183 |     showWithIndent _ [] = "[]"
184 |     showWithIndent {shape = [S _]} _ x = show (toList x)
185 |     showWithIndent {shape = (S d :: dd :: ddd)} indent (x :: xs) =
186 |       let indent = " " ++ indent
187 |           first = showWithIndent indent x
188 |           rest = foldMap (\e => ",\n" ++ indent ++ showWithIndent indent e) (toVect xs)
189 |        in "[" ++ first ++ rest ++ "]"
190 |
191 | export
192 | {shape : _} -> Cast (Array shape a) (Literal shape a) where
193 |   cast x with (shape)
194 |     cast x | [] = Scalar x
195 |     cast _ | (0 :: _) = []
196 |     cast (x :: xs) | (S d :: ds) = cast x :: cast xs
197 |
198 | export
199 | [toArray] Cast (Literal shape a) (Array shape a) where
200 |   cast (Scalar x) = x
201 |   cast [] = []
202 |   cast (x :: y) = cast @{toArray} x :: cast @{toArray} y
203 |
204 | namespace All
205 |   ||| An `All p xs` is an array (or scalar) of proofs about each element in `xs`.
206 |   |||
207 |   ||| For example, an `All IsSucc xs` proves that every element in `xs` is non-zero.
208 |   public export
209 |   data All : (0 p : a -> Type) -> Literal shape a -> Type where
210 |     Scalar : forall x . p x -> All p (Scalar x)
211 |     Nil  : All p []
212 |     (::) : All p x -> All p xs -> All p (x :: xs)
213 |
214 | namespace All2
215 |   ||| An `All2 p xs ys` is an array (or scalar) of pairwise proofs about elements in `xs` and `ys`.
216 |   |||
217 |   ||| For example, an `All2 LT xs ys` proves that each number in `xs` is less than the number in
218 |   ||| `ys` at the same position.
219 |   public export
220 |   data All2 : (p : a -> b -> Type) -> Literal shape a -> Literal shape b -> Type where
221 |     Scalar : forall a, b . p a b -> All2 p (Scalar a) (Scalar b)
222 |     Nil : All2 p [] []
223 |     (::) : All2 p a b -> All2 p as bs -> All2 p (a :: as) (b :: bs)
224 |