0 | module Data.NumIdr.Array.Coords
  1 |
  2 | import Data.Either
  3 | import Data.List
  4 | import Data.List1
  5 | import Data.Vect
  6 |
  7 | import public Data.Vect.Quantifiers
  8 |
  9 | %default total
 10 |
 11 |
 12 | -- A Nat-based range function with better semantics
 13 | public export
 14 | range : Nat -> Nat -> List Nat
 15 | range x y = if x < y then assert_total $ takeBefore (>= y) (countFrom x S)
 16 |                      else []
 17 |
 18 | -- helpful theorems for working with ranges
 19 |
 20 | export
 21 | rangeLen : (x,y : Nat) -> length (range x y) = minus y x
 22 | rangeLen x y = believe_me $ Refl {x = minus y x}
 23 |
 24 | export
 25 | rangeLenZ : (x : Nat) -> length (range 0 x) = x
 26 | rangeLenZ x = rangeLen 0 x `trans` minusZeroRight x
 27 |
 28 | export %unsafe
 29 | assertFin : Nat -> Fin n
 30 | assertFin n = natToFinLt n @{believe_me Oh}
 31 |
 32 | --------------------------------------------------------------------------------
 33 | -- Array coordinate types
 34 | --------------------------------------------------------------------------------
 35 |
 36 |
 37 | ||| A type-safe coordinate system for an array. The coordinates are
 38 | ||| values of `Fin dim`, where `dim` is the dimension of each axis.
 39 | public export
 40 | Coords : (s : Vect rk Nat) -> Type
 41 | Coords = All Fin
 42 |
 43 |
 44 | -- Occasionally necessary for reasons of Idris not being great at
 45 | -- resolving interface constraints
 46 | public export
 47 | [eqCoords] Eq (Coords s) where
 48 |   [] == [] = True
 49 |   (x :: xs) == (y :: ys) = x == y && xs == ys
 50 |
 51 |
 52 | ||| Forget the shape of the array by converting each index to type `Nat`.
 53 | export
 54 | toNB : Coords {rk} s -> Vect rk Nat
 55 | toNB [] = []
 56 | toNB (i :: is) = finToNat i :: toNB is
 57 |
 58 | export
 59 | validateCoords : (s : Vect rk Nat) -> Vect rk Nat -> Maybe (Coords s)
 60 | validateCoords [] [] = Just []
 61 | validateCoords (d :: s) (i :: is) = (::) <$> natToFin i d <*> validateCoords s is
 62 |
 63 |
 64 | namespace Strict
 65 |   public export
 66 |   data CRange : Nat -> Type where
 67 |     One : Fin n -> CRange n
 68 |     One' : Fin n -> CRange n
 69 |     All : CRange n
 70 |     StartBound : Fin (S n) -> CRange n
 71 |     EndBound : Fin (S n) -> CRange n
 72 |     Bounds : Fin (S n) -> Fin (S n) -> CRange n
 73 |     Indices : List (Fin n) -> CRange n
 74 |     Filter : (Fin n -> Bool) -> CRange n
 75 |
 76 |
 77 |   public export
 78 |   CoordsRange : (s : Vect rk Nat) -> Type
 79 |   CoordsRange = All CRange
 80 |
 81 |
 82 | namespace NB
 83 |   public export
 84 |   data CRangeNB : Type where
 85 |     One : Nat -> CRangeNB
 86 |     One' : Nat -> CRangeNB
 87 |     All : CRangeNB
 88 |     StartBound : Nat -> CRangeNB
 89 |     EndBound : Nat -> CRangeNB
 90 |     Bounds : Nat -> Nat -> CRangeNB
 91 |     Indices : List Nat -> CRangeNB
 92 |     Filter : (Nat -> Bool) -> CRangeNB
 93 |
 94 |
 95 | --------------------------------------------------------------------------------
 96 | -- Indexing helper functions
 97 | --------------------------------------------------------------------------------
 98 |
 99 |
100 | public export
101 | Vects : Vect rk Nat -> Type -> Type
102 | Vects []     a = a
103 | Vects (d::s) a = Vect d (Vects s a)
104 |
105 | export
106 | collapse : {s : _} -> Vects s a -> List a
107 | collapse {s=[]} = pure
108 | collapse {s=_::_} = concat . map collapse
109 |
110 |
111 | export
112 | mapWithIndex : {s : Vect rk Nat} -> (Vect rk Nat -> a -> b) -> Vects s a -> Vects s b
113 | mapWithIndex {s=[]}   f x = f [] x
114 | mapWithIndex {s=_::_} f v = mapWithIndex' (\i => mapWithIndex (\is => f (i::is))) v
115 |   where
116 |     mapWithIndex' : {0 a,b : Type} -> (Nat -> a -> b) -> Vect n a -> Vect n b
117 |     mapWithIndex' f [] = []
118 |     mapWithIndex' f (x::xs) = f Z x :: mapWithIndex' (f . S) xs
119 |
120 |
121 | export
122 | getLocation' : (sts : Vect rk Nat) -> (is : Vect rk Nat) -> Nat
123 | getLocation' = sum .: zipWith (*)
124 |
125 | ||| Compute the memory location of an array element
126 | ||| given its coordinate and the strides of the array.
127 | export
128 | getLocation : Vect rk Nat -> Coords {rk} s -> Nat
129 | getLocation sts is = getLocation' sts (toNB is)
130 |
131 |
132 | namespace Strict
133 |   public export
134 |   cRangeToList : {n : Nat} -> CRange n -> Either Nat (List Nat)
135 |   cRangeToList (One x) = Left (cast x)
136 |   cRangeToList (One' x) = Right [cast x]
137 |   cRangeToList All = Right $ range 0 n
138 |   cRangeToList (StartBound x) = Right $ range (cast x) n
139 |   cRangeToList (EndBound x) = Right $ range 0 (cast x)
140 |   cRangeToList (Bounds x y) = Right $ range (cast x) (cast y)
141 |   cRangeToList (Indices xs) = Right $ map cast $ nub xs
142 |   cRangeToList (Filter p) = Right $ map cast $ filter p $ toList Fin.range
143 |
144 |
145 |   public export
146 |   newRank : {s : _} -> CoordsRange s -> Nat
147 |   newRank [] = 0
148 |   newRank (r :: rs) = case cRangeToList r of
149 |                         Left _ => newRank rs
150 |                         Right _ => S (newRank rs)
151 |
152 |   ||| Calculate the new shape given by a coordinate range.
153 |   public export
154 |   newShape : {s : _} -> (rs : CoordsRange s) -> Vect (newRank rs) Nat
155 |   newShape [] = []
156 |   newShape (r :: rs) with (cRangeToList r)
157 |     _ | Left _ = newShape rs
158 |     _ | Right xs = length xs :: newShape rs
159 |
160 |
161 |   getNewPos : {s : _} -> (rs : CoordsRange {rk} s) -> Vect rk Nat -> Vect (newRank rs) Nat
162 |   getNewPos [] [] = []
163 |   getNewPos (r :: rs) (i :: is) with (cRangeToList r)
164 |     _ | Left _ = getNewPos rs is
165 |     _ | Right xs = cast (assert_total $ case findIndex (==i) xs of Just x => x)
166 |                     :: getNewPos rs is
167 |
168 |   export
169 |   getCoordsList : {s : Vect rk Nat} -> (rs : CoordsRange s) -> List (Vect rk Nat, Vect (newRank rs) Nat)
170 |   getCoordsList rs = map (\is => (is, getNewPos rs is)) $ go rs
171 |     where
172 |       go : {0 rk : _} -> {s : Vect rk Nat} -> CoordsRange s -> List (Vect rk Nat)
173 |       go [] = [[]]
174 |       go (r :: rs) = [| (either pure id (cRangeToList r)) :: go rs |]
175 |
176 |
177 | namespace NB
178 |   export
179 |   validateCRange : (s : Vect rk Nat) -> Vect rk CRangeNB -> Maybe (CoordsRange s)
180 |   validateCRange [] [] = Just []
181 |   validateCRange (d :: s) (r :: rs) = [| validate' d r :: validateCRange s rs |]
182 |     where
183 |       validate' : (n : Nat) -> CRangeNB -> Maybe (CRange n)
184 |       validate' n (One i) =
185 |         case isLT i n of
186 |           Yes _ => Just (One (natToFinLT i))
187 |           _ => Nothing
188 |       validate' n (One' i) =
189 |         case isLT i n of
190 |           Yes _ => Just (One' (natToFinLT i))
191 |           _ => Nothing
192 |       validate' n All = Just All
193 |       validate' n (StartBound x) =
194 |         case isLTE x n of
195 |           Yes _ => Just (StartBound (natToFinLT x))
196 |           _ => Nothing
197 |       validate' n (EndBound x) =
198 |         case isLTE x n of
199 |           Yes _ => Just (EndBound (natToFinLT x))
200 |           _ => Nothing
201 |       validate' n (Bounds x y) =
202 |         case (isLTE x n, isLTE y n) of
203 |           (Yes _, Yes _) => Just (Bounds (natToFinLT x) (natToFinLT y))
204 |           _ => Nothing
205 |       validate' n (Indices xs) = Indices <$> traverse
206 |         (\x => case isLT x n of
207 |                 Yes _ => Just (natToFinLT x)
208 |                 No _ => Nothing) xs
209 |       validate' n (Filter f) = Just (Filter (f . finToNat))
210 |
211 |   export %unsafe
212 |   assertCRange : (s : Vect rk Nat) -> Vect rk CRangeNB -> CoordsRange s
213 |   assertCRange [] [] = []
214 |   assertCRange (d :: s) (r :: rs) = assert' r :: assertCRange s rs
215 |     where
216 |       assert' : forall n. CRangeNB -> CRange n
217 |       assert' (One i) = One (assertFin i)
218 |       assert' (One' i) = One' (assertFin i)
219 |       assert' All = All
220 |       assert' (StartBound x) = StartBound (assertFin x)
221 |       assert' (EndBound x) = EndBound (assertFin x)
222 |       assert' (Bounds x y) = Bounds (assertFin x) (assertFin y)
223 |       assert' (Indices xs) = Indices (assertFin <$> xs)
224 |       assert' (Filter f) = Filter (f . finToNat)
225 |
226 |   public export
227 |   cRangeNBToList : Nat -> CRangeNB -> Either Nat (List Nat)
228 |   cRangeNBToList s (One i) = Left i
229 |   cRangeNBToList s (One' i) = Right [i]
230 |   cRangeNBToList s All = Right $ range 0 s
231 |   cRangeNBToList s (StartBound x) = Right $ range x s
232 |   cRangeNBToList s (EndBound x) = Right $ range 0 x
233 |   cRangeNBToList s (Bounds x y) = Right $ range x y
234 |   cRangeNBToList s (Indices xs) = Right $ nub xs
235 |   cRangeNBToList s (Filter p) = Right $ filter p $ range 0 s
236 |
237 |   public export
238 |   newRank : Vect rk Nat -> Vect rk CRangeNB -> Nat
239 |   newRank _ [] = 0
240 |   newRank (d :: s) (r :: rs) =
241 |     case cRangeNBToList d r of
242 |       Left _ => newRank s rs
243 |       Right _ => S (newRank s rs)
244 |
245 |   ||| Calculate the new shape given by a coordinate range.
246 |   public export
247 |   newShape : (s : Vect rk Nat) -> (is : Vect rk CRangeNB) -> Vect (newRank s is) Nat
248 |   newShape [] [] = []
249 |   newShape (d :: s) (r :: rs) with (cRangeNBToList d r)
250 |     _ | Left _ = newShape s rs
251 |     _ | Right xs = length xs :: newShape s rs
252 |
253 | export
254 | getAllCoords' : Vect rk Nat -> List (Vect rk Nat)
255 | getAllCoords' = traverse (\case Z => []S n => [0..n])
256 |
257 | export
258 | getAllCoords : (s : Vect rk Nat) -> List (Coords s)
259 | getAllCoords [] = [[]]
260 | getAllCoords (Z :: s) = []
261 | getAllCoords (S d :: s) = [| forget (allFins d) :: getAllCoords s |]
262 |