0 | module Data.NumIdr.Array.Array
5 | import Data.Permutation
6 | import Data.NumIdr.Interfaces
7 | import Data.NumIdr.PrimArray
8 | import Data.NumIdr.Array.Rep
9 | import Data.NumIdr.Array.Coords
-- The array type, indexed by its shape `s` (a vector of `rk` dimensions)
-- and its element type `a`.
-- `MkArray` packs a representation tag `rep`, the matching
-- `RepConstraint rep a` evidence, the shape, and the underlying `PrimArray`.
32 | data Array : (s : Vect rk Nat) -> (a : Type) -> Type where
38 | MkArray : (rep : Rep) -> (rc : RepConstraint rep a) => (s : Vect rk Nat) ->
39 | PrimArray rep s a @{rc} -> Array s a
-- Function alias for the `MkArray` constructor; it takes the same arguments
-- and performs no additional checks of its own.
45 | unsafeMkArray : (rep : Rep) -> (rc : RepConstraint rep a) => (s : Vect rk Nat) ->
46 | PrimArray rep s a @{rc} -> Array s a
47 | unsafeMkArray = MkArray
-- Returns the shape vector stored inside the array value.
57 | shape : Array {rk} s a -> Vect rk Nat
58 | shape (MkArray _ s _) = s
-- Total element count of an array: the product of every dimension in its shape.
62 | size : Array s a -> Nat
63 | size arr = product (shape arr)
-- Number of axes of an array, computed as the length of its shape vector.
67 | rank : Array s a -> Nat
68 | rank arr = length (shape arr)
-- Returns the representation tag stored in the array value.
72 | getRep : Array s a -> Rep
73 | getRep (MkArray rep _ _) = rep
-- Returns the `RepConstraint` evidence stored in the array, typed against
-- the array's own representation (`getRep arr`).
77 | getRepC : (arr : Array s a) -> RepConstraint (getRep arr) a
78 | getRepC (MkArray _ @{rc} _ _) = rc
-- Returns the underlying `PrimArray`; its type mentions the array's own
-- representation and constraint (`getRep arr`, `getRepC arr`).
82 | getPrim : (arr : Array s a) -> PrimArray (getRep arr) s a @{getRepC arr}
83 | getPrim (MkArray _ _ pr) = pr
-- Proof that the type-level shape index `s` equals the runtime `shape arr`.
-- Matching on the constructor makes both sides reduce to the same vector.
91 | shapeEq : (arr : Array s a) -> s = shape arr
92 | shapeEq (MkArray _ _ _) = Refl
-- A view over an array that exposes its shape `s` as a runtime value.
-- The array itself is an erased (quantity 0) implicit of the `Shape`
-- constructor.
97 | data ShapeView : Array s a -> Type where
98 | Shape : (s : Vect rk Nat) -> {0 arr : Array s a} -> ShapeView arr
-- Builds the `ShapeView` for an array, rewriting with `shapeEq arr` (and its
-- symmetric form) so the view's index lines up with `arr`.
-- NOTE(review): original line 105 is missing from this listing, so part of
-- the right-hand side is not visible here.
103 | viewShape : (arr : Array s a) -> ShapeView arr
104 | viewShape arr = rewrite shapeEq arr in
106 | {arr = rewrite sym (shapeEq arr) in arr}
-- Builds an array of shape `s` from a single value `x` by wrapping
-- `PrimArray.constant s x`.
-- The representation `rep` is an implicit argument that defaults to `B`.
118 | repeat : {default B rep : Rep} -> RepConstraint rep a =>
119 | (s : Vect rk Nat) -> a -> Array s a
120 | repeat s x = MkArray rep s (PrimArray.constant s x)
-- `repeat` applied to the `Num` literal 0, forwarding the chosen `rep`
-- (default `B`).
127 | zeros : {default B rep : Rep} -> RepConstraint rep a =>
128 | Num a => (s : Vect rk Nat) -> Array s a
129 | zeros {rep} s = repeat {rep} s 0
-- `repeat` applied to the `Num` literal 1, forwarding the chosen `rep`
-- (default `B`).
136 | ones : {default B rep : Rep} -> RepConstraint rep a =>
137 | Num a => (s : Vect rk Nat) -> Array s a
138 | ones {rep} s = repeat {rep} s 1
-- Builds an array of shape `s` from a vector holding exactly `product s`
-- elements. The vector is converted to a list and handed to
-- `PrimArray.fromList`. Requires `LinearRep rep`; `rep` defaults to `B`.
146 | fromVect : {default B rep : Rep} -> LinearRep rep => RepConstraint rep a =>
147 | (s : Vect rk Nat) -> Vect (product s) a -> Array s a
148 | fromVect {rep} s v = MkArray rep s (PrimArray.fromList {rep} s $
toList v)
-- Builds an array of shape `s` from the front of a stream.
-- The number of elements taken is left to inference (`take _`); it is fixed
-- by `fromVect`, which expects a vector of length `product s`.
156 | fromStream : {default B rep : Rep} -> LinearRep rep => RepConstraint rep a =>
157 | (s : Vect rk Nat) -> Stream a -> Array s a
158 | fromStream {rep} s str = fromVect {rep} s (take _ str)
-- Builds an array of shape `s` from a function over unchecked coordinates
-- (`Vect rk Nat`), delegating to `PrimArray.fromFunctionNB`.
-- `rep` defaults to `B`.
165 | fromFunctionNB : {default B rep : Rep} -> RepConstraint rep a =>
166 | (s : Vect rk Nat) -> (Vect rk Nat -> a) -> Array s a
167 | fromFunctionNB s f = MkArray rep s (PrimArray.fromFunctionNB s f)
-- Builds an array of shape `s` from a function over typed coordinates
-- (`Coords s`), delegating to `PrimArray.fromFunction`.
-- `rep` defaults to `B`.
174 | fromFunction : {default B rep : Rep} -> RepConstraint rep a =>
175 | (s : Vect rk Nat) -> (Coords s -> a) -> Array s a
176 | fromFunction s f = MkArray rep s (PrimArray.fromFunction s f)
-- Builds an array from a `Vects s a` value via `fromVects`.
-- The shape `s` is an implicit argument; `rep` defaults to `B`.
181 | array : {default B rep : Rep} -> RepConstraint rep a =>
182 | {s : Vect rk Nat} -> Vects s a -> Array s a
183 | array v = MkArray rep s (fromVects s v)
-- Reads the element at the typed coordinates `is` via `PrimArray.index`.
200 | index : Coords s -> Array s a -> a
201 | index is (MkArray _ _ arr) = PrimArray.index is arr
-- Operator form of `index`, with the array on the left and the coordinates
-- on the right.
207 | (!!) : Array s a -> Coords s -> a
208 | arr !! is = index is arr
-- Applies `f` to the element at coordinates `is`. The result keeps the
-- original representation, constraint and shape; the work is delegated to
-- the primitive-level `indexUpdate`.
212 | indexUpdate : Coords s -> (a -> a) -> Array s a -> Array s a
213 | indexUpdate is f (MkArray rep @{rc} s arr) =
214 | MkArray rep @{rc} s (indexUpdate @{rc} is f arr)
-- Overwrites the element at coordinates `is` with `x`, expressed as an
-- `indexUpdate` whose update function ignores the old value.
218 | indexSet : Coords s -> a -> Array s a -> Array s a
219 | indexSet is x = indexUpdate is (const x)
-- Extracts the sub-array selected by the coordinate range `rs`, via
-- `PrimArray.indexRange`. The result shape is `newShape rs` and the
-- representation is unchanged.
224 | indexRange : (rs : CoordsRange s) -> Array s a -> Array (newShape rs) a
225 | indexRange rs (MkArray rep @{rc} s arr) =
226 | MkArray rep @{rc} _ (PrimArray.indexRange @{rc} rs arr)
-- Operator form of `indexRange`, with the array on the left.
232 | (!!..) : Array s a -> (rs : CoordsRange s) -> Array (newShape rs) a
233 | arr !!.. rs = indexRange rs arr
-- Writes a replacement sub-array into the region `rs` of the target array.
-- The replacement's primitive goes through `convertRepPrim` before being
-- passed to `PrimArray.indexSetRange`; the result keeps the target's `rep`
-- and shape.
237 | indexSetRange : (rs : CoordsRange s) -> Array (newShape rs) a ->
238 | Array s a -> Array s a
239 | indexSetRange rs (MkArray _ _ rpl) (MkArray rep s arr) =
240 | MkArray rep s (PrimArray.indexSetRange {rep} rs (convertRepPrim rpl) arr)
-- Updates a region in place: extracts the sub-array at `rs`, applies `f` to
-- it, and writes the result back with `indexSetRange`.
246 | indexUpdateRange : (rs : CoordsRange s) ->
247 | (Array (newShape rs) a -> Array (newShape rs) a) ->
248 | Array s a -> Array s a
249 | indexUpdateRange rs f arr = indexSetRange rs (f $
arr !!.. rs) arr
-- Lookup with unchecked coordinates. `validateCoords s is` produces an
-- optional validated coordinate, and `index` is mapped over it, so the
-- result is `Nothing` whenever validation yields `Nothing`.
-- The lambda's `is` shadows the outer `is` and is the validated coordinate.
255 | indexNB : Vect rk Nat -> Array {rk} s a -> Maybe a
256 | indexNB is (MkArray rep @{rc} s arr) =
257 | map (\is => index is (MkArray rep @{rc} s arr)) (validateCoords s is)
-- Operator form of `indexNB`, with the array on the left.
264 | (!?) : Array {rk} s a -> Vect rk Nat -> Maybe a
265 | arr !? is = indexNB is arr
-- `indexUpdate` with unchecked coordinates: the coordinates go through
-- `validateCoords`, and the update is mapped over the optional result.
-- The lambda's `is` shadows the outer `is` and is the validated coordinate.
270 | indexUpdateNB : Vect rk Nat -> (a -> a) -> Array {rk} s a -> Maybe (Array s a)
271 | indexUpdateNB is f (MkArray rep @{rc} s arr) =
272 | map (\is => indexUpdate is f (MkArray rep @{rc} s arr)) (validateCoords s is)
-- Overwrites the element at unchecked coordinates `is` with `x`, expressed
-- as an `indexUpdateNB` whose update function ignores the old value.
277 | indexSetNB : Vect rk Nat -> a -> Array {rk} s a -> Maybe (Array s a)
278 | indexSetNB is x = indexUpdateNB is (const x)
-- Range extraction with unchecked ranges: `validateCRange s rs` yields an
-- optional validated range, over which `Array.indexRange` is mapped.
-- `believe_me` coerces the result to the advertised type
-- `Array (newShape s rs) a`.
-- NOTE(review): the compiler does not check that the two shape expressions
-- agree; that is presumably guaranteed by `validateCRange` — confirm.
284 | indexRangeNB : (rs : Vect rk CRangeNB) -> Array s a -> Maybe (Array (newShape s rs) a)
285 | indexRangeNB rs (MkArray rep @{rc} s arr) =
286 | map (\rs => believe_me $
Array.indexRange rs (MkArray rep @{rc} s arr)) (validateCRange s rs)
-- Operator form of `indexRangeNB`, with the array on the left.
293 | (!?..) : Array s a -> (rs : Vect rk CRangeNB) -> Maybe (Array (newShape s rs) a)
294 | arr !?.. rs = indexRangeNB rs arr
-- Lookup with unchecked coordinates via `PrimArray.indexUnsafe`.
-- This function does no validation of `is` itself.
301 | indexUnsafe : Vect rk Nat -> Array {rk} s a -> a
302 | indexUnsafe is (MkArray _ _ arr) = PrimArray.indexUnsafe is arr
-- Operator form of `indexUnsafe`, with the array on the left.
-- Marked `%inline` and `%unsafe`.
309 | export %inline %unsafe
310 | (!#) : Array {rk} s a -> Vect rk Nat -> a
311 | arr !# is = indexUnsafe is arr
-- Range extraction with unchecked ranges and no `Maybe` wrapper: the range
-- is converted with `assertCRange s rs` and passed to `Array.indexRange`.
-- `believe_me` coerces the result to `Array (newShape s rs) a`.
-- NOTE(review): neither the range validity nor the shape agreement is
-- checked here; the caller is presumably responsible — confirm.
318 | indexRangeUnsafe : (rs : Vect rk CRangeNB) -> Array s a -> Array (newShape s rs) a
319 | indexRangeUnsafe rs (MkArray rep @{rc} s arr) =
320 | believe_me $
Array.indexRange (assertCRange s rs) (MkArray rep @{rc} s arr)
-- Operator form of `indexRangeUnsafe`, with the array on the left.
-- Marked `%inline` and `%unsafe`.
327 | export %inline %unsafe
328 | (!#..) : Array s a -> (rs : Vect rk CRangeNB) -> Array (newShape s rs) a
329 | arr !#.. is = indexRangeUnsafe is arr
-- Maps an endofunction over the array with `mapPrim`, keeping the same
-- representation and element type.
342 | mapArray' : (a -> a) -> Array s a -> Array s a
343 | mapArray' f (MkArray rep _ arr) = MkArray rep _ (mapPrim f arr)
-- Maps a function over the array with `mapPrim`, keeping the representation.
-- The caller must supply a `RepConstraint` for the new element type `b`
-- under the array's own representation.
350 | mapArray : (a -> b) -> (arr : Array s a) -> RepConstraint (getRep arr) b => Array s b
351 | mapArray f (MkArray rep _ arr) @{rc} = MkArray rep @{rc} _ (mapPrim f arr)
-- Combines two arrays of the same shape and element type pointwise with `f`.
-- The result representation is `mergeRep` of the two inputs, with its
-- constraint built by `mergeRepConstraint`. Elements are read with the
-- unchecked `!#` inside `PrimArray.fromFunctionNB`.
358 | zipWithArray' : (a -> a -> a) -> Array s a -> Array s a -> Array s a
359 | zipWithArray' {s} f a b with (viewShape a)
360 | _ | Shape s = MkArray (mergeRep (getRep a) (getRep b))
361 | @{mergeRepConstraint (getRepC a) (getRepC b)} s
362 | $
PrimArray.fromFunctionNB @{mergeRepConstraint (getRepC a) (getRepC b)} _
363 | (\is => f (a !# is) (b !# is))
-- Combines two arrays of the same shape pointwise with `f`, allowing the
-- element types to differ. The result representation is `mergeRep` of the
-- two inputs, and the caller supplies the `RepConstraint` for the result
-- type `c`. Elements are read with the unchecked `!#`.
370 | zipWithArray : (a -> b -> c) -> (arr : Array s a) -> (arr' : Array s b) ->
371 | RepConstraint (mergeRep (getRep arr) (getRep arr')) c => Array s c
372 | zipWithArray {s} f a b @{rc} with (viewShape a)
373 | _ | Shape s = MkArray (mergeRep (getRep a) (getRep b)) @{rc} s
374 | $
PrimArray.fromFunctionNB _ (\is => f (a !# is) (b !# is))
-- Reinterprets the array under a new shape `s'` via `PrimArray.reshape`.
-- Requires a linear representation and an erased proof that both shapes
-- have the same element count (`product s = product s'`).
381 | reshape : (s' : Vect rk' Nat) -> (arr : Array {rk} s a) -> LinearRep (getRep arr) =>
382 | (0 ok : product s = product s') => Array s' a
383 | reshape s' (MkArray rep _ arr) = MkArray rep s' (PrimArray.reshape s' arr)
-- Re-wraps the array under the requested representation `rep`, converting
-- the underlying primitive with `convertRepPrim`. The shape is unchanged.
387 | convertRep : (rep : Rep) -> RepConstraint rep a => Array s a -> Array s a
388 | convertRep rep (MkArray _ s arr) = MkArray rep s (convertRepPrim arr)
-- Runs `f` on a `Delayed`-representation copy of the array, then converts
-- the result back to the input's original representation and constraint.
393 | delayed : (Array s a -> Array s' a) -> Array s a -> Array s' a
394 | delayed f arr = convertRep (getRep arr) @{getRepC arr} $
f $
convertRep Delayed arr
-- Builds an array of shape `s'` (same rank) with the input's representation.
-- Each coordinate is converted with `toNB` and looked up with `!?`;
-- `def` is used wherever that lookup returns `Nothing`.
402 | resize : (s' : Vect rk Nat) -> (def : a) -> Array {rk} s a -> Array s' a
403 | resize s' def arr = fromFunction {rep=getRep arr} @{getRepC arr} s' (fromMaybe def . (arr !?) . toNB)
-- `resize` for the case where an erased proof `ok` states that each
-- dimension of `s'` is `LTE` the matching dimension of `s`.
-- NOTE(review): the default element is a placeholder forged with
-- `believe_me ()`; presumably it is never read because `ok` rules out
-- out-of-range lookups — confirm.
413 | resizeLTE : (s' : Vect rk Nat) -> (0 ok : All Prelude.id (zipWith LTE s' s)) =>
414 | Array {rk} s a -> Array s' a
415 | resizeLTE s' arr = resize s' (believe_me ()) arr
-- Lists every (coordinate, element) pair, with coordinates as plain
-- `Vect rk Nat` values taken from `getAllCoords' s` and elements read via
-- `PrimArray.indexUnsafe`.
420 | enumerateNB : Array {rk} s a -> List (Vect rk Nat, a)
421 | enumerateNB (MkArray _ s arr) =
422 | map (\is => (is, PrimArray.indexUnsafe is arr)) (getAllCoords' s)
-- Lists every (coordinate, element) pair, with typed coordinates taken from
-- `getAllCoords s` and elements read via `index`.
426 | enumerate : Array s a -> List (Coords s, a)
427 | enumerate {s} arr with (viewShape arr)
428 | _ | Shape s = map (\is => (is, index is arr)) (getAllCoords s)
-- Collects all elements into a vector, in the order produced by
-- `getAllCoords' s`, reading each one with `PrimArray.indexUnsafe`.
-- `believe_me` coerces the vector built by `Vect.fromList` to the
-- advertised length `product s`.
-- NOTE(review): this assumes `getAllCoords' s` yields exactly `product s`
-- coordinates — confirm.
432 | elements : Array {rk} s a -> Vect (product s) a
433 | elements (MkArray _ s arr) =
434 | believe_me $
Vect.fromList $
435 | map (flip PrimArray.indexUnsafe arr) (getAllCoords' s)
-- Concatenates two arrays along `axis`. The second array's shape matches
-- the first except that dimension `axis` is `d`; the result dimension along
-- `axis` is the first array's dimension plus `d`.
-- A result coordinate whose `axis` component is below the first array's
-- extent (`limit`) reads from `a`. Otherwise it reads from `b`, with that
-- component reduced by `limit`.
-- The result representation is `mergeRep` of the inputs, and `believe_me`
-- coerces the constructed array to the advertised result type.
443 | concat : (axis : Fin rk) -> Array {rk} s a -> Array (replaceAt axis d s) a ->
444 | Array (updateAt axis (+d) s) a
445 | concat {s,d} axis a b with (viewShape a, viewShape b)
446 | _ | (Shape s, Shape (replaceAt axis d s)) =
447 | believe_me $
Array.fromFunctionNB {rep=mergeRep (getRep a) (getRep b)}
448 | @{mergeRepConstraint (getRepC a) (getRepC b)} (updateAt axis (+ index axis (shape b)) s)
449 | (\is => let limit = index axis s
450 | in if index axis is < limit then a !# is else b !# updateAt axis (`minus` limit) is)
-- Stacks `n` arrays of shape `s` into one array with a new axis of size `n`
-- inserted at position `axis`.
-- `getAxisInd` splits a coordinate of the stacked shape into the index along
-- the new axis and the remaining coordinate. These select the source array
-- from `arrs` and the element within it.
-- The `lengthCorrect` rewrites only reconcile the length of `arrs` with `n`.
456 | stack : {s : _} -> (axis : Fin (S rk)) -> Vect n (Array {rk} s a) -> Array (insertAt axis n s) a
457 | stack axis arrs = rewrite sym (lengthCorrect arrs) in
458 | fromFunction _ (\is => case getAxisInd axis (rewrite sym (lengthCorrect arrs) in is) of
459 | (i,is') => index is' (index i arrs))
461 | getAxisInd : {0 rk : _} -> {s : _} -> (ax : Fin (S rk)) -> Coords (insertAt ax n s) -> (Fin n, Coords s)
462 | getAxisInd FZ (i :: is) = (i, is)
463 | getAxisInd {s=_::_} (FS ax) (i :: is) = mapSnd (i::) (getAxisInd ax is)
-- Flattens an array of arrays into a single array of shape `s ++ s'`.
-- For each result coordinate, the leading components (as many as `s` has)
-- index the outer array and the remaining components index the inner one;
-- both lookups use the unchecked `!#`.
-- `takeUpTo` keeps as many leading components of its second argument as the
-- first argument has entries; `dropUpTo` returns what is left after them.
467 | joinAxes : {s' : _} -> Array s (Array s' a) -> Array (s ++ s') a
468 | joinAxes {s} arr with (viewShape arr)
469 | _ | Shape s = fromFunctionNB (s ++ s') (\is => arr !# takeUpTo s is !# dropUpTo s is)
471 | takeUpTo : Vect rk Nat -> Vect (rk + rk') Nat -> Vect rk Nat
472 | takeUpTo [] ys = []
473 | takeUpTo (x::xs) (y::ys) = y :: takeUpTo xs ys
475 | dropUpTo : Vect rk Nat -> Vect (rk + rk') Nat -> Vect rk' Nat
476 | dropUpTo [] ys = ys
477 | dropUpTo (x::xs) (y::ys) = dropUpTo xs ys
-- Splits an array of rank `rk + rk'` into an outer array over the first `rk`
-- axes whose elements are arrays over the remaining axes.
-- Each inner element is read from the original at the concatenated
-- coordinate `is ++ is'`, using the unchecked `!#`.
482 | splitAxes : (rk : Nat) -> {0 rk' : Nat} -> {s : _} ->
483 | Array {rk=rk+rk'} s a -> Array (take {m=rk'} rk s) (Array (drop {m=rk'} rk s) a)
484 | splitAxes _ {s} arr = fromFunctionNB _ (\is => fromFunctionNB _ (\is' => arr !# (is ++ is')))
-- Reverses the order of the axes: the result has shape `reverse s`, and each
-- element is read from the original at the reversed coordinate.
488 | transpose : Array s a -> Array (reverse s) a
489 | transpose {s} arr with (viewShape arr)
490 | _ | Shape s = fromFunctionNB _ (\is => arr !# reverse is)
-- Declaration with the same type as `transpose`.
-- NOTE(review): the defining equation is not visible in this listing;
-- presumably it is an alias for `transpose` — confirm.
496 | (.T) : Array s a -> Array (reverse s) a
-- Swaps axes `i` and `j`: the result has shape `swapElems i j s`, and each
-- element is read from the original at the coordinate with components `i`
-- and `j` swapped.
502 | swapAxes : (i,j : Fin rk) -> Array s a -> Array (swapElems i j s) a
503 | swapAxes {s} i j arr with (viewShape arr)
504 | _ | Shape s = fromFunctionNB _ (\is => arr !# swapElems i j is)
-- Reorders the axes of an array according to the permutation `p`; the result
-- has shape `permuteVect p s`.
-- Each result element is looked up in the original at the result coordinate
-- `is` mapped through `permuteVect p`, mirroring `swapAxes`. Indexing with
-- the shape `s` itself (as this previously did) ignores `is` and is out of
-- bounds on every axis, since a valid index is strictly below its dimension.
-- NOTE(review): if `permuteVect p` is not an involution, the lookup may need
-- the inverse permutation instead — confirm against Data.Permutation.
508 | permuteAxes : (p : Permutation rk) -> Array s a -> Array (permuteVect p s) a
509 | permuteAxes {s} p arr with (viewShape arr)
510 | _ | Shape s = fromFunctionNB _ (\is => arr !# permuteVect p is)
-- Swaps positions `i` and `j` along one axis; the shape is unchanged.
-- Each element is read from the original at the coordinate whose `axis`
-- component has been mapped through `swapValues i j`.
517 | swapInAxis : (axis : Fin rk) -> (i,j : Fin (index axis s)) -> Array s a -> Array s a
518 | swapInAxis {s} axis i j arr with (viewShape arr)
519 | _ | Shape s = fromFunctionNB _ (\is => arr !# updateAt axis (swapValues i j) is)
-- Permutes the positions along one axis; the shape is unchanged.
-- Each element is read from the original at the coordinate whose `axis`
-- component has been mapped through `permuteValues p`.
527 | permuteInAxis : (axis : Fin rk) -> Permutation (index axis s) -> Array s a -> Array s a
528 | permuteInAxis {s} axis p arr with (viewShape arr)
529 | _ | Shape s = fromFunctionNB _ (\is => arr !# updateAt axis (permuteValues p) is)
-- `Zippable` instance for arrays of a fixed shape.
-- `zipWith` and `zipWith3` build the result pointwise with
-- `PrimArray.fromFunctionNB`, reading inputs with the unchecked `!#`. The
-- result representation comes from `mergeRepNC`, with its constraint from
-- `mergeNCRepConstraint`.
-- `unzipWith` and `unzipWith3` build one result array per tuple component,
-- each applying `f` to the source element and projecting with `fst`/`snd`.
-- The representation is `forceRepNC (getRep arr)`, with its constraint from
-- `forceRepConstraint`.
-- NOTE(review): several original lines of the unzip methods (551-552, 554,
-- 558, 564-565, 567, 571, 575) are missing from this listing.
537 | Zippable (Array s) where
538 | zipWith {s} f a b with (viewShape a)
539 | _ | Shape s = MkArray (mergeRepNC (getRep a) (getRep b))
540 | @{mergeNCRepConstraint} s
541 | $
PrimArray.fromFunctionNB @{mergeNCRepConstraint} _
542 | (\is => f (a !# is) (b !# is))
544 | zipWith3 {s} f a b c with (viewShape a)
545 | _ | Shape s = MkArray (mergeRepNC (mergeRep (getRep a) (getRep b)) (getRep c))
546 | @{mergeNCRepConstraint} s
547 | $
PrimArray.fromFunctionNB @{mergeNCRepConstraint} _
548 | (\is => f (a !# is) (b !# is) (c !# is))
550 | unzipWith {s} f arr with (viewShape arr)
553 | rep = forceRepNC $
getRep arr
555 | @{forceRepConstraint} s
556 | $
PrimArray.fromFunctionNB @{forceRepConstraint} _
557 | (\is => fst $
f (arr !# is)),
559 | @{forceRepConstraint} s
560 | $
PrimArray.fromFunctionNB @{forceRepConstraint} _
561 | (\is => snd $
f (arr !# is)))
563 | unzipWith3 {s} f arr with (viewShape arr)
566 | rep = forceRepNC $
getRep arr
568 | @{forceRepConstraint} s
569 | $
PrimArray.fromFunctionNB @{forceRepConstraint} _
570 | (\is => fst $
f (arr !# is)),
572 | @{forceRepConstraint} s
573 | $
PrimArray.fromFunctionNB @{forceRepConstraint} _
574 | (\is => fst $
snd $
f (arr !# is)),
576 | @{forceRepConstraint} s
577 | $
PrimArray.fromFunctionNB @{forceRepConstraint} _
578 | (\is => snd $
snd $
f (arr !# is)))
-- `Functor` instance. `map` first converts the primitive to the
-- representation `forceRepNC rep` with `convertRepPrim`, then applies `f`
-- with `mapPrim`; all constraints for that representation come from
-- `forceRepConstraint`.
582 | Functor (Array s) where
583 | map f (MkArray rep @{rc} s arr) = MkArray (forceRepNC rep) @{forceRepConstraint} s
584 | (mapPrim @{forceRepConstraint} @{forceRepConstraint} f
585 | $
convertRepPrim @{rc} @{forceRepConstraint} arr)
-- `Applicative` instance; requires the shape `s` at runtime.
-- `<*>` applies an array of functions to an array of arguments pointwise,
-- via `zipWith apply`.
-- NOTE(review): the definition of `pure` (original line 589) is missing from
-- this listing.
588 | {s : _} -> Applicative (Array s) where
590 | (<*>) = zipWith apply
-- `Monad` instance; requires the shape `s` at runtime.
-- `join` takes the diagonal of an array of arrays: the element at `is` of
-- the result is the element at `is` of the inner array found at `is`.
593 | {s : _} -> Monad (Array s) where
594 | join arr = fromFunction s (\is => arr !! is !! is)
-- `Foldable` instance. `foldl` and `foldr` delegate to the `PrimArray`
-- folds; `null` holds exactly when the product of the shape is zero.
602 | Foldable (Array s) where
603 | foldl f z (MkArray _ _ arr) = PrimArray.foldl f z arr
604 | foldr f z (MkArray _ _ arr) = PrimArray.foldr f z arr
605 | null (MkArray _ s _) = isZero (product s)
-- `Traversable` instance. The primitive is converted to the representation
-- `forceRepNC rep`, traversed with `PrimArray.traverse`, and the resulting
-- primitive is re-wrapped with `MkArray` inside the applicative.
609 | Traversable (Array s) where
610 | traverse f (MkArray rep @{rc} s arr) =
611 | map (MkArray (forceRepNC rep) @{forceRepConstraint} s)
612 | (PrimArray.traverse {rep=forceRepNC rep}
613 | @{%search} @{forceRepConstraint} @{forceRepConstraint} f
614 | (convertRepPrim @{rc} @{forceRepConstraint} arr))
-- `Cast` instance between arrays of the same shape, given a `Cast a b` for
-- the element types.
-- NOTE(review): the method body is not visible in this listing.
618 | Cast a b => Cast (Array s a) (Array s b) where
-- `Eq` instance. Both arrays are converted to representation `D`, compared
-- pointwise with `==` (each comparison wrapped by `delay`), and the results
-- are combined with `and`.
622 | Eq a => Eq (Array s a) where
623 | a == b = and $
zipWith (delay .: (==)) (convertRep D a) (convertRep D b)
-- `Semigroup` instance: combines two arrays pointwise with the element
-- type's `<+>`.
626 | Semigroup a => Semigroup (Array s a) where
627 | (<+>) = zipWithArray' (<+>)
-- `Monoid` instance; requires the shape `s` at runtime.
-- The neutral array is `repeat s neutral`.
630 | {s : _} -> Monoid a => Monoid (Array s a) where
631 | neutral = repeat s neutral
-- `Num` instance; requires the shape `s` at runtime.
-- `+` and `*` act pointwise; an integer literal becomes `repeat s` of the
-- element type's `fromInteger`.
-- NOTE(review): original line 639 is not shown in this listing (presumably
-- blank) — confirm.
636 | {s : _} -> Num a => Num (Array s a) where
637 | (+) = zipWithArray' (+)
638 | (*) = zipWithArray' (*)
640 | fromInteger = repeat s . fromInteger
-- `Neg` instance: pointwise negation and pointwise subtraction.
643 | {s : _} -> Neg a => Neg (Array s a) where
644 | negate = mapArray' negate
645 | (-) = zipWithArray' (-)
-- `Fractional` instance: pointwise reciprocal and pointwise division.
648 | {s : _} -> Fractional a => Fractional (Array s a) where
649 | recip = mapArray' recip
650 | (/) = zipWithArray' (/)
-- Scalar-times-array multiplication: `x *. arr` maps `(*x)` over the array,
-- so each element `e` becomes `e * x`.
-- NOTE(review): the scalar ends up on the right of `*`; this only matters if
-- the element type's multiplication is not commutative — confirm intent.
654 | Num a => Mult a (Array {rk} s a) (Array s a) where
655 | (*.) x = mapArray' (*x)
-- Array-times-scalar multiplication instance.
-- NOTE(review): the method body is not visible in this listing.
658 | Num a => Mult (Array {rk} s a) a (Array s a) where
-- `Show` instance. The elements (in `elements` order) are shown
-- individually, punctuated into nested brackets by `insertPunct` according
-- to the shape, concatenated, and passed to `showCon` with the constructor
-- text "array ".
--
-- `splitWindow n xs` cuts `xs` into consecutive chunks using `splitAt n`.
-- NOTE(review): original line 669 (a case alternative, presumably the
-- terminating one) is missing from this listing.
--
-- `insertPunct` by shape:
--   * empty shape: returns the strings unchanged;
--   * single dimension: wraps the strings in "[" and "]", separated by ", ";
--   * leading dimension zero: yields "[" and "]" only;
--   * otherwise: splits the strings into windows of `length strs div d`,
--     punctuates each with the remaining shape, and joins the sections with
--     ", " inside brackets. When there are no strings at all, it emits `d`
--     copies of "[]" instead.
663 | Show a => Show (Array s a) where
664 | showPrec d arr = let orderedElems = toList $
elements arr
665 | in showCon d "array " $
concat $
insertPunct (shape arr) $
map show orderedElems
667 | splitWindow : Nat -> List String -> List (List String)
668 | splitWindow n xs = case splitAt n xs of
670 | (l1, l2) => l1 :: splitWindow n (assert_smaller xs l2)
672 | insertPunct : Vect rk Nat -> List String -> List String
673 | insertPunct [] strs = strs
674 | insertPunct [d] strs = "[" :: intersperse ", " strs `snoc` "]"
675 | insertPunct (Z :: s) strs = ["[","]"]
676 | insertPunct (d :: s) strs =
677 | let secs = if null strs
678 | then List.replicate d ("[]" :: Prelude.Nil)
679 | else map (insertPunct s) $
splitWindow (length strs `div` d) strs
680 | in "[" :: (concat $
intersperse [", "] secs) `snoc` "]"
-- Linear interpolation between two arrays: pointwise
-- `a * (1 - t) + b * t`, so `t = 0` weights only `a` and `t = 1` only `b`.
690 | lerp : Neg a => a -> Array s a -> Array s a -> Array s a
691 | lerp t a b = zipWithArray' (+) (a *. (1 - t)) (b *. t)
-- Sum of the squares of all elements: each element is multiplied by itself
-- pointwise and the products are summed.
696 | normSq : Num a => Array s a -> a
697 | normSq arr = sum $
zipWith (*) arr arr
-- Norm of a `Double` array: the square root of the sum of its squared
-- elements (`normSq`).
701 | norm : Array s Double -> Double
702 | norm arr = sqrt (normSq arr)
-- Divides every element by the array's `norm`.
-- An array whose elements are all zero is returned unchanged, which avoids
-- dividing by a zero norm.
708 | normalize : Array s Double -> Array s Double
709 | normalize arr = if all (==0) arr then arr else map (/ norm arr) arr
-- The p-norm of a `Double` array: (sum of |x|^p) ^ (1/p).
-- Each element goes through `abs` before `pow`, so a negative entry
-- contributes its magnitude. Without it, opposite-sign entries cancel for
-- odd `p` (e.g. p = 1 over [-1, 1] gives 0 instead of 2), and a fractional
-- `p` applied to a negative base yields NaN.
713 | pnorm : (p : Double) -> Array s Double -> Double
714 | pnorm p = (`pow` recip p) . sum . map ((`pow` p) . abs)