0 | module Misc
  1 |
  2 | import Data.Nat
  3 | import Data.List.Elem
  4 | import Data.Vect
  5 | import Data.Vect.Elem
  6 | import System.Random
  7 | import Data.Fin
  8 | import Data.List1
  9 | import Data.Fin.Arith
 10 | import Data.List.Quantifiers
 11 | import Data.Vect.Quantifiers
 12 | import Decidable.Equality
 13 | import Decidable.Equality.Core
 14 | import Data.List
 15 |
 16 | %hide Builtin.infixr.(#)
 17 | %hide Data.Vect.Quantifiers.All.index
 18 |
 19 | {-------------------------------------------------------------------------------
 20 | {-------------------------------------------------------------------------------
 21 | Various utilities necessary for TensorType, but that don't fit anywhere else
 22 | Does not depend on any other file within this project
 23 |
 24 | Some of these feel like they should be in the Idris standard library
 25 |
 26 | -------------------------------------------------------------------------------}
 27 | -------------------------------------------------------------------------------}
 28 |
 29 | namespace IsNo
 30 |   ||| The proof that a decidable property leads to a contradiction
 31 |   ||| `IsNo` is a type Idris can automatically synthesise, unlike `Not`
 32 |   ||| See example below
 33 |   public export
 34 |   data IsNo : Dec a -> Type where
 35 |     ItIsNo : {prop : Type} -> 
 36 |       {contra : Not prop} ->
 37 |       IsNo (No {prop=prop} contra)
 38 |
 39 |   failing 
 40 |     thisOneFails : Not ("i" = "j")
 41 |     thisOneFails = %search
 42 |   
 43 |   thisOneDoesnt : IsNo (decEq "i" "j")
 44 |   thisOneDoesnt = %search
 45 |
 46 |   public export
 47 |   [UninhabitedIsNoRefl] {x : a} -> DecEq a =>
 48 |     Uninhabited (IsNo (decEq x x)) where
 49 |     uninhabited y with (decEq x x)
 50 |       _ | (Yes _) with (y)
 51 |         _ | ItIsNo impossible
 52 |       _ | (No contra) = contra Refl
 53 |   
 54 |   public export 
 55 |   isNoSym : DecEq a => {x, y : a} -> IsNo (decEq x y) -> IsNo (decEq y x)
 56 |   isNoSym z with (decEq x y) | (decEq y x)
 57 |     _ | (No contra1) | (Yes prf) = absurd (contra1 (sym prf))
 58 |     _ | _           | (No contra) = ItIsNo 
 59 |   
 60 |   ||| Proof of inequality yields IsNo
 61 |   public export
 62 |   proofIneqIsNo : {x, y : a} -> DecEq a =>
 63 |     Not (x = y) -> IsNo (decEq x y)
 64 |   proofIneqIsNo f with (decEq x y)
 65 |     _ | (Yes prf) = absurd (f prf)
 66 |     _ | (No contra) = ItIsNo
 67 |
 68 | namespace Maybe
 69 |   public export
 70 |   data IsNothing : Maybe a -> Type where
 71 |     ItIsNothing : IsNothing Nothing
 72 |
 73 |   public export
 74 |   maybeVoidIsNothing : (x : Maybe Void) -> IsNothing x
 75 |   maybeVoidIsNothing Nothing = ItIsNothing
 76 |   maybeVoidIsNothing (Just v) = absurd v
 77 |
 78 |   public export
 79 |   Uninhabited (IsNothing (Just x)) where
 80 |     uninhabited ItIsNothing impossible
 81 |
 82 |
 83 | namespace NotElem
 84 |   public export
 85 |   data NotElem : DecEq a => (x : a) -> (xs : Vect n a) -> Type where
 86 |     NotInEmptyVect : DecEq a => {0 x : a} -> NotElem x []
 87 |     NotInNonEmptyVect : DecEq a => {0 x, y : a} ->
 88 |       (xs : Vect n a) ->
 89 |       IsNo (decEq x y) ->
 90 |       (ne : NotElem x xs) =>
 91 |       NotElem x (y :: xs)
 92 |   
 93 |   public export
 94 |   notEqualNotElem : DecEq a =>
 95 |     {0 x, y : a} ->
 96 |     (neq : IsNo (decEq x y)) ->
 97 |     NotElem x [y]
 98 |   notEqualNotElem neq = NotInNonEmptyVect [] neq
 99 |   
100 |   ||| If an element `i` is not in the singleton list `[j]`, then `j` is not in
101 |   ||| the singleton list `[i]`
102 |   public export
103 |   notElemSym : DecEq a => {i, j : a} -> NotElem i [j] -> NotElem j [i]
104 |   notElemSym (NotInNonEmptyVect [] isNo) = notEqualNotElem (isNoSym isNo)
105 |   
106 |   ||| If an element `i` is in the singleton list `[j]`, then `j` is in the 
107 |   ||| singleton list `[i]`
108 |   public export
109 |   elemSym : DecEq a => {i, j : a} -> Vect.Elem.Elem i [j] ->
110 |     Vect.Elem.Elem j [i]
111 |   elemSym Here = Here
112 |
113 |
114 | namespace Applicative
115 |   ||| Definition of liftA2 in terms of (<*>)
116 |   public export
117 |   liftA2 : Applicative f => f a -> f b -> f (a, b)
118 |   liftA2 fa fb = ((,) <$> fa) <*> fb
119 |   
120 |   ||| Tensorial strength
121 |   public export
122 |   strength : Applicative f => a -> f b -> f (a, b)
123 |   strength a fb = liftA2 (pure a) fb
124 |   
125 |   ||| Pointwise Num structure for Applicative functors
126 |   public export
127 |   [applicativeNum] Num a => Applicative f => Num (f a) where
128 |     xs + ys = uncurry (+) <$> liftA2 xs ys
129 |     xs * ys = uncurry (*) <$> liftA2 xs ys
130 |     fromInteger = pure . fromInteger
131 |
132 |
133 | namespace VectFoldable
134 |   ||| Implementation of Foldable for Vect that is denotationally equivalent to
135 |   ||| one in Data.Vect, but which does not use `foldrImpl` and therefore
136 |   ||| reduces in the typechecker
137 |   public export
138 |   [straightforward] Foldable (Vect n) where
139 |     foldr f z [] = z
140 |     foldr f z (x :: xs) = f x (foldr f z xs)
141 |
142 |   ||| toList with a different foldable implementation
143 |   public export
144 |   toList' : Vect n a -> List a
145 |   toList' = foldr @{straightforward} (::) []
146 |
147 |   public export
148 |   fromList' : (xs : List a) -> Vect (length xs) a
149 |   fromList' [] = []
150 |   fromList' (x :: xs) = x :: fromList' xs
151 |
152 | ||| Duplicate of utilities for Data.Vect in their Naperian form
153 | namespace Vect
154 |   public export
155 |   sum : Num a => Vect n a -> a
156 |   sum xs = foldr @{straightforward} (+) (fromInteger 0) xs
157 |   
158 |   -- Because of the way foldr for Vect is implemented in Idris 
159 |   -- we have to use this approach below, otherwise allSuccThenProdSucc breaks
160 |   public export 
161 |   prod : Num a => Vect n a -> a
162 |   prod xs = foldr @{straightforward} (*) (fromInteger 1) xs
163 |   -- prod [] = fromInteger 1
164 |   -- prod (x :: xs) = x * prod xs
165 |
166 |   public export
167 |   max : Ord a => Vect n a -> Maybe a
168 |   max [] = Nothing
169 |   max (x :: xs) = case max xs of
170 |     Nothing => Just x
171 |     Just y => Just (max x y)
172 |
173 |   public export
174 |   argmax : Ord a => IsSucc n => Vect n a -> Fin n 
175 |   argmax [x] = FZ
176 |   argmax (x :: x' :: xs) =
177 |     let maxRest = argmax (x' :: xs)
178 |     in case x > index maxRest (x' :: xs) of 
179 |       True => FZ
180 |       False => FS maxRest
181 |   
182 |   public export
183 |   argmin : Ord a => IsSucc n => Vect n a -> Fin n
184 |   argmin = argmax @{Reverse} 
185 |   
186 |   ||| Dual to concat from Data.Vect
187 |   public export
188 |   unConcat : {n, m : Nat} -> Vect (n * m) a -> Vect n (Vect m a)
189 |   unConcat {n = 0} _ = []
190 |   unConcat {n = (S k)} xs = let (f, s) = splitAt m xs
191 |                             in f :: unConcat s
192 |
193 |   ||| Trim a specified trailing value
194 |   public export
195 |   dropFromEnd : Eq a => a -> Vect n a -> List a
196 |   dropFromEnd c row = reverse (dropWhile (== c) (reverse (toList row)))
197 |   
198 |   ||| Combination of `cons` and `snoc`: adds an element in front, and at the end
199 |   public export
200 |   consSnoc : Vect n a -> a -> a -> Vect (2 + n) a
201 |   consSnoc xs a b = a :: snoc xs b
202 |   
203 |   ||| Pad a vector with a specified element to exactly `targetSize`
204 |   public export
205 |   padToSize : Vect size a -> (targetSize : Nat) -> a ->
206 |     LTE size targetSize => 
207 |     Vect targetSize a
208 |   padToSize [] Z c = []
209 |   padToSize [] (S k) c = c :: padToSize [] k c
210 |   padToSize (x :: xs) (S k) c = x :: padToSize xs k c @{fromLteSucc %search}
211 |
212 |   ||| Drop the first i elements of a vector
213 |   ||| Analogous to Data.Vect.drop, except the index is Fin n instead of Nat
214 |   public export
215 |   drop : (i : Fin (S n)) -> Vect n a -> Vect (minus n (finToNat i)) a
216 |   drop FZ xs = rewrite minusZeroRight n in xs
217 |   drop (FS i) (x :: xs) = drop i xs
218 |   
219 |   namespace DropElem
220 |     ||| Drop all the elements up and until the element `x` from a vector
221 |     public export
222 |     drop : DecEq a =>
223 |       (xs : Vect n a) ->
224 |       (elem : Elem x xs) ->
225 |       Vect (n `minus` (finToNat (FS (elemToFin elem)))) a
226 |     drop {n=S k} (_ :: xs) Here = rewrite minusZeroRight k in xs
227 |     drop (_ :: xs) (There later) = drop xs later
228 |
229 |
230 | namespace List
231 |   public export
232 |   sum : Num a => List a -> a
233 |   sum = foldr (+) (fromInteger 0) 
234 |
235 |   public export
236 |   prod : Num a => List a -> a
237 |   prod = foldr (*) (fromInteger 1)
238 |
239 |   public export
240 |   listZip : List a -> List b -> List (a, b)
241 |   listZip (x :: xs) (y :: ys) = (x, y) :: listZip xs ys
242 |   listZip _ _ = []
243 |
244 |   ||| Map each element along with its zero-based position in the list.
245 |   public export
246 |   mapWithIndex : (Nat -> a -> b) -> List a -> List b
247 |   mapWithIndex f = go 0
248 |     where
249 |       go : Nat -> List a -> List b
250 |       go _ []        = []
251 |       go i (x :: xs) = f i x :: go (S i) xs
252 |
253 |   ||| Split a list into consecutive chunks of size `n` (clamped to at least 1).
254 |   ||| The final chunk may be shorter than `n`, this is why the length of the
255 |   ||| list is needed as upper bound.
256 |   public export
257 |   chunksOf : (n : Nat) -> List a -> List (List a)
258 |   chunksOf n xs = go (max 1 n) xs (length xs)
259 |     where
260 |       go : Nat -> List a -> (len : Nat) -> List (List a)
261 |       go _  []          _     = []
262 |       go _  ys@(_ :: _) Z     = [ys]
263 |       go sz ys@(_ :: _) (S f) = case splitAt sz ys of
264 |                                   (h, t) => h :: go sz t f
265 |   
266 |   public export
267 |   max : Ord a => List a -> Maybe a
268 |   max [] = Nothing
269 |   max (x :: xs) = case max xs of
270 |     Nothing => Just x
271 |     Just y => Just (max x y)
272 |
273 |   namespace NonEmpty
274 |     public export
275 |     max : Ord a => (xs : List a) -> (ne : NonEmpty xs) => a
276 |     max [x] {ne=IsNonEmpty} = x
277 |     max (x :: y :: xs) {ne=IsNonEmpty} = max x (max (y :: xs))
278 |
279 |   ||| Trim a specified trailing value
280 |   public export
281 |   dropFromEnd : Eq a => a -> List a -> List a
282 |   dropFromEnd c row = reverse (dropWhile (== c) (reverse row))
283 |
284 |   ||| Combination of `cons` and `snoc`: adds an element in front, and at the end
285 |   public export
286 |   consSnoc : List a -> a -> a -> List a
287 |   consSnoc xs x y = x :: snoc xs y
288 |
289 |   ||| Pad a list with a specified element to at least `targetSize`
290 |   public export
291 |   padToSize : Nat -> a -> List a -> List a
292 |   padToSize targetSize padValue xs =
293 |     xs ++ replicate (minus targetSize (length xs)) padValue
294 |
295 |
296 |   ||| Drop all the elements after the element `x` from a list
297 |   public export
298 |   dropAfterElem : (xs : List a) -> (elem : Elem x xs) -> List a
299 |   dropAfterElem (x :: _) Here = [x]
300 |   dropAfterElem (y :: xs) (There p) = y :: dropAfterElem xs p
301 |
302 | namespace VectNaperianUtils
303 |   ||| Analogue of `(::)`
304 |   public export
305 |   cons : x -> (Fin l -> x) -> (Fin (S l) -> x)
306 |   cons x _ FZ = x
307 |   cons _ f (FS k') = f k'
308 |
309 |   -- dcons : x -> ((i : Fin k) -> i' i) -> ((i : Fin (S k)) -> i' (cons x i'))
310 |   
311 |   public export
312 |   head : (Fin (S l) -> x) -> x
313 |   head f = f FZ
314 |   
315 |   public export
316 |   tail : (Fin (S l) -> x) -> (Fin l -> x)
317 |   tail f = f . FS
318 |   
319 |   ||| All but the last element
320 |   public export
321 |   init : (Fin (S n) -> a) -> Fin n -> a
322 |   init f x = f (weaken x)
323 |   
324 |   ||| Analogus to `Data.Vect.take`
325 |   public export 
326 |   takeFin : (s : Fin (S n)) -> Vect n a -> Vect (finToNat s) a
327 |   takeFin FZ _ = []
328 |   takeFin (FS s) (x :: xs) = x :: takeFin s xs
329 |
330 |   public export
331 |   sum : Num a => {n : Nat} -> (Fin n -> a) -> a
332 |   sum {n = 0} _ = 0
333 |   sum {n = (S k)} content = content FZ + sum (content . FS)
334 |
335 |   public export
336 |   prod : Num a => {n : Nat} -> (Fin n -> a) -> a
337 |   prod = prod . tabulate
338 |
339 |   public export
340 |   toList : {n : Nat} -> (Fin n -> a) -> List a
341 |   toList = toList' . tabulate
342 |
343 | namespace FinArithmetic
344 |   public export
345 |   minusSuccLTE : {n, m : Nat} -> LTE n m ->
346 |     minus (S m) n = S (minus m n)
347 |   minusSuccLTE {m = 0, n = 0} LTEZero = Refl
348 |   minusSuccLTE {m = (S k), n = 0} LTEZero = Refl
349 |   minusSuccLTE {m = (S k), n = (S left)} (LTESucc x) = minusSuccLTE x
350 |
351 |   ||| Like weakenN from Data.Fin, but where n is on the other side of +
352 |   public export
353 |   weakenN' : (0 n : Nat) -> Fin m -> Fin (n + m)
354 |   weakenN' n x = rewrite plusCommutative n m in weakenN n x
355 |
356 |   ||| Like weakenN, but with mutliplication
357 |   ||| Like shiftMul, but without changing the value of the index
358 |   public export
359 |   weakenMultN : {n : Nat} ->
360 |     (m : Nat) -> {auto prf : IsSucc m} ->
361 |     (i : Fin n) -> Fin (m * n)
362 |   weakenMultN (S 0) {prf = ItIsSucc} i = rewrite multOneLeftNeutral n in i
363 |   weakenMultN (S (S k)) {prf = ItIsSucc} i = weakenN' n (weakenMultN (S k) i)
364 |
365 |   multRightUnit : (m : Nat) -> m * 1 = m
366 |   multRightUnit 0 = Refl
367 |   multRightUnit (S k) = cong S (multRightUnit k)
368 |
369 |   multRightZeroCancel : (m : Nat) -> m * 0 = 0
370 |   multRightZeroCancel 0 = Refl
371 |   multRightZeroCancel (S k) = multRightZeroCancel k
372 |
373 |   ||| Variant of `shift` from Data.Fin, but with multiplication
374 |   ||| Given an index i : Fin n, it recasts it as one where steps are stride sized
375 |   ||| That is, returns stride * i : Fin (stride * n)
376 |   ||| Implemented by recursing on i, adding stride each time
377 |   public export
378 |   shiftMul : {n : Nat} ->
379 |     (stride : Nat) -> {auto prf : IsSucc stride} ->
380 |     (i : Fin n) -> Fin (n * stride)
381 |   shiftMul (S s) {prf = ItIsSucc} FZ = FZ
382 |   shiftMul stride (FS i) = shift stride (shiftMul stride i)
383 |
384 |   shiftMulTest : shiftMul {n=3} 5 1 = 5
385 |   shiftMulTest = Refl
386 |
387 |   ||| Analogous to strengthen from Data.Fin
388 |   ||| Attempts to strengthen the bound on Fin (m + n) to Fin m
389 |   ||| If it doesn't succeed, then returns the remainder in Fin n
390 |   public export
391 |   strengthenN : {m, n : Nat} -> Fin (m + n) -> Either (Fin m) (Fin n)
392 |   strengthenN {m = 0} x = Right x
393 |   strengthenN {m = (S k)} FZ = Left FZ
394 |   strengthenN {m = (S k)} (FS x) with (strengthenN x)
395 |     _ | (Left p) = Left $ FS p
396 |     _ | (Right q) = Right q
397 |   -- strengthenN {m = 0} x = Nothing
398 |   -- strengthenN {m = (S k)} FZ = Just FZ
399 |   -- strengthenN {m = (S k)} (FS x) with (strengthenN x)
400 |   --   _ | Nothing = Nothing
401 |   --   _ | (Just p) = Just $ FS p
402 |     --= let t = strengthenN x
403 |     --  in ?strengthenN_rhs_3
404 |
405 |   -- strengthenN {n = 0} x = Just x
406 |   -- strengthenN {m = 0} {n = (S k)} FZ = Nothing
407 |   -- strengthenN {m = (S j)} {n = (S k)} FZ = Just FZ
408 |   -- strengthenN {m} {n = (S k)} (FS x)
409 |   --   = let t = strengthenN x
410 |   --         v = Fin.FS
411 |   --     in ?what -- strengthenN x
412 |
413 |
414 |   --         restCount = indexCount is -- fpn = 13 : Fin (20)
415 |   -- iCTest1 : indexCount {shape = [3, 4, 5]} [1, 2, 3] = 33
416 |   -- iCTest1 = ?iCTest_rhs
417 |   
418 |   ||| Like finS, but without wrapping
419 |   ||| finS' last = last
420 |   public export
421 |   finS' : {n : Nat} -> Fin n -> Fin n
422 |   finS' {n = 1} x = x
423 |   finS' {n = (S (S k))} FZ = FS FZ
424 |   finS' {n = (S (S k))} (FS x) = FS $ finS' x
425 |   --finS' {n = S _} x = case strengthen x of
426 |   --    Nothing => x
427 |   --    Just y => FS y
428 |
429 |   
430 |   ||| Adds two Fin n, and bounds the result
431 |   ||| Meaning (93:Fin 5) + (4 : Fin 5) = 4
432 |   public export
433 |   addFinsBounded : {n : Nat} -> Fin n -> Fin n -> Fin n
434 |   addFinsBounded x FZ = x
435 |   addFinsBounded x (FS y) = addFinsBounded (finS' x) (weaken y)
436 |
437 |   finSTest : finS' {n = 5} 3 = 4
438 |   finSTest = Refl
439 |
440 |   finSTest2 : finS' {n = 5} 4 = 4
441 |   finSTest2 = Refl
442 |
443 |   ||| Divides a Fin by 2, rounding down
444 |   public export
445 |   half : (k : Fin n) -> Fin n
446 |   half FZ = FZ
447 |   half (FS FZ) = FZ
448 |   half (FS (FS x)) = weakenN' 2 x
449 |
450 |   ||| Computes the midway index between two bounds
451 |   public export
452 |   mid : (low : Fin n) -> (high : Fin n) ->
453 |     {auto prf : So (high >= low)} ->
454 |     Fin n
455 |   mid FZ high = half high
456 |   mid (FS x) (FS y) {prf} = FS (mid x y)
457 |
458 | ||| Given a non-empty sorted vector `xs`, finds the "right bin", i.e. the index 
459 | ||| of the smallest element between the bounds that `x` is not bigger than.
460 | ||| If `x` is bigger than the highest element, returns `Nothing`
461 | ||| `findBinBetween [2,7,10] 1 0 2 = Just 0`
462 | ||| `findBinBetween [2,7,10] 3 0 2 = Just 1`
463 | ||| `findBinBetween [2,7,10] 9 0 2 = Just 2`
464 | ||| `findBinBetween [2,7,10] 7 0 2 = Just 2`
465 | ||| `findBinBetween [1,2,3,4,5] 6 0 4 = Nothing`
466 | public export
467 | findBinBetween : Ord a => {n : Nat} -> (xs : Vect (S n) a) ->
468 |   (x : a) ->
469 |   (lowInd : Fin (S n)) -> (highInd : Fin (S n)) ->
470 |   {auto prf : So (highInd >= lowInd)} ->
471 |   Maybe (Fin (S n))
472 | findBinBetween xs x lowInd highInd = case x > index highInd xs of
473 |   True => Nothing -- rule out the case where x is bigger
474 |   False => case x <= index lowInd xs of
475 |     True => Just lowInd
476 |     False => let midInd = mid lowInd highInd
477 |              in case compare x (index midInd xs) of
478 |                LT => findBinBetween xs x lowInd midInd {prf = believe_me ()}
479 |                EQ => Just midInd
480 |                GT => findBinBetween xs x (finS' midInd) highInd {prf = believe_me ()}
481 |
482 | ||| Todo can this eventually be generalised to non-cubical tensors?
483 | ||| Given a non-empty sorted vector `xs`, finds the "right bin", i.e. the index 
484 | ||| of the smallest element between the bounds that `x` is not bigger than.
485 | ||| If `x` is bigger than the highest element, returns `Nothing`
486 | ||| `findBin [2,7,10] 1 = Just 0`
487 | ||| `findBin [2,7,10] 3 = Just 1`
488 | ||| `findBin [2,4,6,8] 7 = Just 3`
489 | public export
490 | findBin : Ord a => {n : Nat} ->
491 |   (xs : Vect (S n) a) -> (x : a) -> Maybe (Fin (S n))
492 | findBin {n = 0} (x' :: []) x = case x' <= x of
493 |   True => Just FZ
494 |   False => Nothing
495 | findBin {n = (S k)} xs x = findBinBetween xs x 0 last
496 |
497 |
498 | -- t : Double -> Type
499 | -- t 4 = Double
500 | -- t _ = String
501 | -- 
502 | -- th : (x : Double ** t x)
503 | -- th = (4 ** 5)
504 | -- 
505 | -- thh : (x : Double) -> Show (t x)
506 | -- thh x = ?thh_rhs
507 |
508 | public export
509 | mkDepPairShow : Show a => (ss : (x : a) -> Show (b x)) => (DPair a b -> String)
510 | mkDepPairShow = \(x ** y=> "\{show x} ** \{show (y)}"
511 |
512 | public export
513 | Show a => ((x : a) -> Show (b x)) => Show (DPair a b) where
514 |    show = mkDepPairShow
515 |
516 | -- public export
517 | -- Num Unit where
518 | --   fromInteger _ = ()
519 | --   () * () = ()
520 | --   () + () = ()
521 |
522 | public export
523 | runIf: HasIO io => Bool -> io () -> io ()
524 | runIf True action = action
525 | runIf False action = pure ()
526 |
527 | public export
528 | pairIO : HasIO io => io a -> io b -> io (a, b)
529 | pairIO a b = do
530 |   a <- a
531 |   b <- b
532 |   pure (a, b)
533 |
534 |
535 | namespace RandomUtils
536 | -- Probably there's a faster way to do this
537 | -- public export
538 | -- {n : Nat} -> Random a => Random (Vect n a) where
539 | --   randomIO = sequence $ replicate n randomIO
540 | --   randomRIO (lo, hi) = sequence $ zipWith (\l, h => randomRIO (l, h)) lo hi
541 |
542 |   public export
543 |   Random Unit where
544 |     randomIO = pure ()
545 |     randomRIO _ = pure ()
546 |
547 |   public export
548 |   Random a => Random b => Random (a, b) where
549 |     randomIO = pairIO randomIO randomIO
550 |     randomRIO ((loA, loB), (hiA, hiB))
551 |       = pairIO (randomRIO (loA, hiA)) (randomRIO (loB, hiB))
552 |
553 |
554 |
555 | -- for reshaping a tensor
556 | rrrrr : {n, x, y : Nat}
557 |   -> Fin (S n)
558 |   -> {auto prf : n = x * y}
559 |   -> (Fin (S x), Fin (S y))
560 |   -- -> Data.Fin.Arith.(*) (Fin (S x)) (Fin (S y))
561 |
562 |
563 | ||| There is a similar function in Data.Fin.Arith, which has the smallest
564 | ||| possible bound. This one does not, but has a simpler type signature.
565 | public export
566 | multFin : {m, n : Nat} -> Fin m -> Fin n -> Fin (m * n)
567 | multFin {n = (S _)} FZ y = FZ
568 | multFin {n = (S _)} (FS x) y = y + weaken (multFin x y)
569 |
570 | ||| Splits xs at each occurence of delimeter (general version for lists)
571 | public export
572 | splitList : Eq a =>
573 |   (xs : List a) -> (delimeter : List a) -> (n : Nat ** Vect n (List a))
574 | splitList xs delimeter = 
575 |   if delimeter == []
576 |     then (1 ** [xs]-- Empty delimiter returns original list
577 |     else case isInfixOfList delimeter xs of
578 |       False => (1 ** [xs]-- Delimiter not found, return original list
579 |       True => 
580 |         let (before, after) = breakOnList delimeter xs
581 |         in case after of
582 |           [] => (1 ** [before]-- No more occurrences
583 |           _  => let (restCount ** restVect= splitList (drop (length delimeter) after) delimeter
584 |                 in (S restCount ** before :: restVect)
585 |   where
586 |     -- Check if list starts with delimiter
587 |     isPrefixOfList : List a -> List a -> Bool
588 |     isPrefixOfList [] _ = True
589 |     isPrefixOfList _ [] = False
590 |     isPrefixOfList (d :: ds) (x :: xs) = d == x && isPrefixOfList ds xs
591 |     
592 |     -- Check if delimiter occurs anywhere in the list
593 |     isInfixOfList : List a -> List a -> Bool
594 |     isInfixOfList del [] = del == []
595 |     isInfixOfList del xs@(_ :: xs') = 
596 |       isPrefixOfList del xs || isInfixOfList del xs'
597 |     
598 |     -- Break list at first occurrence of delimiter
599 |     breakOnList : List a -> List a -> (List a, List a)
600 |     breakOnList del xs = breakOnListAcc del xs []
601 |       where
602 |         breakOnListAcc : List a -> List a -> List a -> (List a, List a)
603 |         breakOnListAcc del remaining acc = 
604 |           case isPrefixOfList del remaining of
605 |             True => (reverse acc, remaining)
606 |             False => case remaining of
607 |               [] => (reverse acc, [])
608 |               (c :: cs) => breakOnListAcc del cs (c :: acc)
609 |
610 | ||| Splits xs at each occurence of delimeter (string version)
611 | public export
612 | splitString : (xs : String) -> (delimeter : String) -> (n : Nat ** Vect n String)
613 | splitString xs delimeter = 
614 |   let (n ** result= splitList (unpack xs) (unpack delimeter)
615 |   in (n ** pack <$> result)
616 |
617 | ||| Simple string replacement function
618 | public export
619 | replaceString : String -> String -> String -> String
620 | replaceString old new str = 
621 |   let chars = unpack str
622 |       oldChars = unpack old
623 |       newChars = unpack new
624 |   in pack (replaceInList oldChars newChars chars)
625 |   where
626 |     replaceInList : List Char -> List Char -> List Char -> List Char
627 |     replaceInList [] _ xs = xs
628 |     replaceInList old new [] = []
629 |     replaceInList old new xs@(x :: rest) =
630 |       if isPrefixOf old xs
631 |         then new ++ replaceInList old new (drop (length old) xs)
632 |         else x :: replaceInList old new rest
633 |
634 |
635 | public export
636 | constUnit : a -> Unit
637 | constUnit _ = ()
638 |
639 | public export
640 | const2Unit : a -> b -> Unit
641 | const2Unit _ _ = ()
642 |
643 | public export
644 | fromBool : Num a => Bool -> a
645 | fromBool False = fromInteger 0
646 | fromBool True = fromInteger 1
647 |
648 | public export
649 | applyWhen : Bool -> (a -> a) -> a -> a
650 | applyWhen False f a = a
651 | applyWhen True f a = f a
652 |
653 |
654 | namespace All
655 |   namespace Vect
656 |     public export
657 |     rewriteAllMap : {xs : Vect n a} ->
658 |       All p (f <$> xs) ->
659 |       All (p . f) xs
660 |     rewriteAllMap {xs = []} [] = []
661 |     rewriteAllMap {xs = (x :: xs)} (a :: as) = a :: rewriteAllMap as
662 |
663 |     public export
664 |     rewriteAllMap' : {xs : Vect n a} ->
665 |       All (p . f) xs ->
666 |       All p (f <$> xs)
667 |     rewriteAllMap' {xs = []} [] = []
668 |     rewriteAllMap' {xs = (x :: xs)} (a :: as) = a :: rewriteAllMap' as
669 |   
670 |   namespace List
671 |     public export
672 |     rewriteAllMap : {xs : List a} ->
673 |       All p (f <$> xs) ->
674 |       All (p . f) xs
675 |     rewriteAllMap {xs = []} [] = []
676 |     rewriteAllMap {xs = (x :: xs)} (a :: as) = a :: rewriteAllMap as
677 |
678 |   ||| Cnvert an all to a vector if it's made out of replicated things
679 |   public export
680 |   allToVect : Vect.Quantifiers.All.All p (replicate n a) -> Vect n (p a)
681 |   allToVect [] = []
682 |   allToVect (aa :: aaps) = aa :: allToVect aaps
683 |
684 |   public export
685 |   constantToVect : {xs : Vect n a} ->
686 |     Vect.Quantifiers.All.All (const b) xs -> Vect n b
687 |   constantToVect [] = []
688 |   constantToVect (bb :: bbs) = bb :: constantToVect bbs
689 |
690 |
691 | ||| Dependent parametric traverse
692 | public export
693 | dTraverse : Applicative f =>
694 |   ((p : pType) -> f (q p)) ->
695 |   (xs : Vect n pType) ->
696 |   f (All q xs)
697 | dTraverse f [] = pure []
698 | dTraverse f (p :: ps) = [| f p :: dTraverse f ps |]
699 |
700 |
701 | public export
702 | record Iso (a, b : Type) where
703 |   constructor MkIso
704 |   forward : a -> b
705 |   backward : b -> a
706 |   forwardBackward : (: a) -> backward (forward x) = x
707 |   backwardForward : (: b) -> forward (backward y) = y
708 |
709 | public export
710 | multSucc : {m, n : Nat} -> IsSucc m -> IsSucc n -> IsSucc (m * n)
711 | multSucc {m = S m'} {n = S n'} ItIsSucc ItIsSucc = ItIsSucc
712 |
713 | public export
714 | allSuccThenProdSucc : (xs : List Nat) -> {auto ps : All IsSucc xs} -> IsSucc (prod xs)
715 | allSuccThenProdSucc [] {ps = []} = ItIsSucc
716 | allSuccThenProdSucc (_ :: xs') {ps = p :: _} = multSucc p (allSuccThenProdSucc xs')
717 |
718 |
719 | public export
720 | updateAt : Eq a => (a -> b) -> (a, b) -> (a -> b)
721 | updateAt f (i, val) i' = if i == i' then val else f i'
722 |
723 | ||| Graph of a dependent function
724 | public export
725 | graph : {t : a -> Type} ->
726 |   (g : (x : a) -> t x) ->
727 |   a -> (x : a ** t x)
728 | graph g x = (x ** g x)
729 |
730 | ||| Version of `map` for dependent function
731 | ||| Note that here `x : a` is identity in some sense, it comes from `f a`
732 | public export
733 | dependentMap : Functor f => {t : a -> Type} ->
734 |   (g : (x : a) -> t x) ->
735 |   f a -> f (x : a ** t x)
736 | dependentMap g fa = map (graph g) fa
737 |
738 |
739 | -- ||| Duplicate of `index` from Data.Vect.Quantifiers.All, but with an
740 | -- ||| additional `public` export modifier
741 | public export
742 | index : (i : Fin k) -> Vect.Quantifiers.All.All p ts -> p (Vect.index i ts)
743 | index FZ (x :: xs) = x
744 | index (FS j) (x :: xs) = index j xs
745 |
746 |
747 | {-
748 |
749 | interface Comult (f : Type -> Type) a where
750 |   comult : f a -> f (f a)
751 |
752 | {shape : Vect n Nat} -> Num a => Comult (TensorA shape) a where
753 |   comult t = ?eir
754 |
755 | gg : TensorA [3] Double -> TensorA [3, 3] Double
756 | gg (TS xs) = TS $ map ?fn ?gg_rhs_0
757 |
758 | -- [1, 2, 3]
759 | -- can we even do outer product?
760 | -- we wouldn't need reduce, but something like multiply?
761 | outer : {f : Type -> Type} -> {a : Type}
762 |   -> (Num a, Applicative f, Algebra f a)
763 |   => f a -> f a -> f (f a)
764 | outer xs ys = let t = liftA2 xs ys
765 |               in ?outer_rhs 
766 |   
767 |  -}
768 |
769 | |||| filter' works without `with`?
770 | filter' : (a -> Bool) -> Vect n a -> (p ** Vect p a)
771 | filter' p [] = (0 ** [])
772 | filter' p (x :: xs) = case filter' p xs of 
773 |   (_ ** xs'=> if p x then (_ ** x :: xs'else (_ ** xs')
774 |
775 | ||| filter'' implemented with `with`
776 | filter'' : (a -> Bool) -> Vect n a -> (p ** Vect p a)
777 | filter'' p [] = (0 ** [])
778 | filter'' p (x :: xs) with (filter' p xs)
779 |   _ | (_ ** xs'= if p x then (_ ** x :: xs'else (_ ** xs')
780 |
781 | {-
782 | Prelude.absurd : Uninhabited t => t -> a
783 | believe_me : a -> b
784 |
785 | -}
786 |
787 |
788 |
789 |
790 |
791 | namespace Linearity
792 |   ll1 : {n : Nat} -> Vect n a -> Nat
793 |   ll1 {n} _ = n
794 |   
795 |   -- Should this be detected as `using` the variable `n`?
796 |   -- in pattern matching, we'd have to unify type of `xs` which has in itself `len`
797 |   -- and `n` which in this case is computed to be `S len`?
798 |   -- this step of `ll2` is decomposing `n` only one level down, but the entire recursion ends up using the entire `n`
799 |   ll2 : {0 n : Nat} -> Vect n a -> Nat
800 |   ll2 [] = 0
801 |   ll2 {n=S t} (x :: xs) = 1 + ll2 xs
802 |
803 |
804 |
805 | public export
806 | testFun : Nat -> (m : Nat ** Vect m Nat)
807 |
808 | testFun2 : Nat -> Vect m Nat
809 |
810 | consume : Vect m a -> Type
811 |
812 | composed : (p : a -> Bool) ->
813 |   (xs : Vect n a) ->
814 |   consume (snd (filter p xs))
815 | composed p xs = ?composed_rhs
816 |
817 | -- public export
818 | -- filter : (elem -> Bool) -> Vect len elem -> (p ** Vect p elem)
819 | -- filter p []      = ( _ ** [] )
820 | -- filter p (x::xs) =
821 | --   let (_ ** tail) = filter p xs
822 | --    in if p x then
823 | --         (_ ** x::tail)
824 | --       else
825 | --         (_ ** tail)
826 |
827 | public export
828 | filter2 : (a -> Bool) -> Vect len a -> Vect p a
829 | filter2 f xs = ?filter2_rhs
830 |