0 | module Data.NumIdr.Array.Array
  1 |
  2 | import Data.List
  3 | import Data.Vect
  4 | import Data.Zippable
  5 | import Data.Permutation
  6 | import Data.NumIdr.Interfaces
  7 | import Data.NumIdr.PrimArray
  8 | import Data.NumIdr.Array.Rep
  9 | import Data.NumIdr.Array.Coords
 10 |
 11 | %default total
 12 |
 13 |
 14 | ||| The type of an array.
 15 | |||
 16 | ||| Arrays are the central data structure of NumIdr. They are an `n`-dimensional
 17 | ||| grid of values, where `n` is a value known as the *rank* of the array. Arrays
 18 | ||| of rank 0 are single values, arrays of rank 1 are vectors, and arrays of rank
 19 | ||| 2 are matrices.
 20 | |||
 21 | ||| Each array has a *shape*, which is a vector of values giving the dimensions
 22 | ||| of each axis of the array. The shape is also sometimes used to determine the
 23 | ||| array's total size.
 24 | |||
 25 | ||| Arrays are indexed by row first, as in the standard mathematical notation for
 26 | ||| matrices.
 27 | |||
 28 | ||| @ rk The rank of the array
 29 | ||| @ s The shape of the array
 30 | ||| @ a The type of the array's elements
 31 | export
 32 | data Array : (s : Vect rk Nat) -> (a : Type) -> Type where
 33 |   ||| Internally, arrays are stored via one of a handful of representations.
 34 |   |||
 35 |   ||| @ s   The shape of the array
 36 |   ||| @ rep The internal representation of the array
 37 |   ||| @ rc  A witness that the element type satisfies the representation constraint
 38 |   MkArray : (rep : Rep) -> (rc : RepConstraint rep a) => (s : Vect rk Nat) ->
 39 |                PrimArray rep s a @{rc} -> Array s a
 40 |
 41 | %name Array arr
 42 |
 43 |
 44 | export
 45 | unsafeMkArray : (rep : Rep) -> (rc : RepConstraint rep a) => (s : Vect rk Nat) ->
 46 |                   PrimArray rep s a @{rc} -> Array s a
 47 | unsafeMkArray = MkArray
 48 |
 49 |
 50 | --------------------------------------------------------------------------------
 51 | -- Properties of arrays
 52 | --------------------------------------------------------------------------------
 53 |
 54 |
 55 | ||| The shape of the array.
 56 | export
 57 | shape : Array {rk} s a -> Vect rk Nat
 58 | shape (MkArray _ s _) = s
 59 |
 60 | ||| The size of the array, i.e. the total number of elements.
 61 | export
 62 | size : Array s a -> Nat
 63 | size = product . shape
 64 |
 65 | ||| The rank of the array.
 66 | export
 67 | rank : Array s a -> Nat
 68 | rank = length . shape
 69 |
 70 | ||| The internal representation of the array.
 71 | export
 72 | getRep : Array s a -> Rep
 73 | getRep (MkArray rep _ _) = rep
 74 |
 75 | ||| The representation constraint of the array.
 76 | export
 77 | getRepC : (arr : Array s a) -> RepConstraint (getRep arr) a
 78 | getRepC (MkArray _ @{rc} _ _) = rc
 79 |
 80 | ||| Extract the primitive backend array.
 81 | export
 82 | getPrim : (arr : Array s a) -> PrimArray (getRep arr) s a @{getRepC arr}
 83 | getPrim (MkArray _ _ pr) = pr
 84 |
 85 |
 86 | --------------------------------------------------------------------------------
 87 | -- Shape view
 88 | --------------------------------------------------------------------------------
 89 |
 90 | export
 91 | shapeEq : (arr : Array s a) -> s = shape arr
 92 | shapeEq (MkArray _ _ _) = Refl
 93 |
 94 |
 95 | ||| A view for extracting the shape of an array.
 96 | public export
 97 | data ShapeView : Array s a -> Type where
 98 |   Shape : (s : Vect rk Nat) -> {0 arr : Array s a} -> ShapeView arr
 99 |
100 | ||| The covering function for the view `ShapeView`. This function takes an array
101 | ||| of type `Array s a` and returns `Shape s`.
102 | export
103 | viewShape : (arr : Array s a) -> ShapeView arr
104 | viewShape arr = rewrite shapeEq arr in
105 |                   Shape (shape arr)
106 |                   {arr = rewrite sym (shapeEq arr) in arr}
107 |
108 |
109 | --------------------------------------------------------------------------------
110 | -- Array constructors
111 | --------------------------------------------------------------------------------
112 |
113 | ||| Create an array by repeating a single value.
114 | |||
115 | ||| @ s    The shape of the constructed array
116 | ||| @ rep  The internal representation of the constructed array
117 | export
118 | repeat : {default B rep : Rep} -> RepConstraint rep a =>
119 |             (s : Vect rk Nat) -> a -> Array s a
120 | repeat s x = MkArray rep s (PrimArray.constant s x)
121 |
122 | ||| Create an array filled with zeros.
123 | |||
124 | ||| @ s The shape of the constructed array
125 | ||| @ rep  The internal representation of the constructed array
126 | export
127 | zeros : {default B rep : Rep} -> RepConstraint rep a =>
128 |             Num a => (s : Vect rk Nat) -> Array s a
129 | zeros {rep} s = repeat {rep} s 0
130 |
131 | ||| Create an array filled with ones.
132 | |||
133 | ||| @ s The shape of the constructed array
134 | ||| @ rep  The internal representation of the constructed array
135 | export
136 | ones : {default B rep : Rep} -> RepConstraint rep a =>
137 |             Num a => (s : Vect rk Nat) -> Array s a
138 | ones {rep} s = repeat {rep} s 1
139 |
140 | ||| Create an array given a vector of its elements. The elements of the vector
141 | ||| are arranged into the provided shape using the provided order.
142 | |||
143 | ||| @ s   The shape of the constructed array
144 | ||| @ rep  The internal representation of the constructed array
145 | export
146 | fromVect : {default B rep : Rep} -> LinearRep rep => RepConstraint rep a =>
147 |               (s : Vect rk Nat) -> Vect (product s) a -> Array s a
148 | fromVect {rep} s v = MkArray rep s (PrimArray.fromList {rep} s $ toList v)
149 |
150 |
151 | ||| Create an array by taking values from a stream.
152 | |||
153 | ||| @ s   The shape of the constructed array
154 | ||| @ rep  The internal representation of the constructed array
155 | export
156 | fromStream : {default B rep : Rep} -> LinearRep rep => RepConstraint rep a =>
157 |                 (s : Vect rk Nat) -> Stream a -> Array s a
158 | fromStream {rep} s str = fromVect {rep} s (take _ str)
159 |
160 | ||| Create an array given a function to generate its elements.
161 | |||
162 | ||| @ s   The shape of the constructed array
163 | ||| @ rep  The internal representation of the constructed array
164 | export
165 | fromFunctionNB : {default B rep : Rep} -> RepConstraint rep a =>
166 |                     (s : Vect rk Nat) -> (Vect rk Nat -> a) -> Array s a
167 | fromFunctionNB s f = MkArray rep s (PrimArray.fromFunctionNB s f)
168 |
169 | ||| Create an array given a function to generate its elements.
170 | |||
171 | ||| @ s   The shape of the constructed array
172 | ||| @ rep  The internal representation of the constructed array
173 | export
174 | fromFunction : {default B rep : Rep} -> RepConstraint rep a =>
175 |                   (s : Vect rk Nat) -> (Coords s -> a) -> Array s a
176 | fromFunction s f = MkArray rep s (PrimArray.fromFunction s f)
177 |
178 | ||| Construct an array using a structure of nested vectors.
179 | ||| To explicitly specify the shape and order of the array, use `array'`.
180 | export
181 | array : {default B rep : Rep} -> RepConstraint rep a =>
182 |           {s : Vect rk Nat} -> Vects s a -> Array s a
183 | array v = MkArray rep s (fromVects s v)
184 |
185 |
186 | --------------------------------------------------------------------------------
187 | -- Indexing
188 | --------------------------------------------------------------------------------
189 |
190 | infixl 10 !!
191 | infixl 10 !?
192 | infixl 10 !#
193 | infixl 11 !!..
194 | infixl 11 !?..
195 | infixl 11 !#..
196 |
197 |
198 | ||| Index the array using the given coordinates.
199 | export
200 | index : Coords s -> Array s a -> a
201 | index is (MkArray _ _ arr) = PrimArray.index is arr
202 |
203 | ||| Index the array using the given coordinates.
204 | |||
205 | ||| This is the operator form of `index`.
206 | export %inline
207 | (!!) : Array s a -> Coords s -> a
208 | arr !! is = index is arr
209 |
210 | ||| Update the entry at the given coordinates using the function.
211 | export
212 | indexUpdate : Coords s -> (a -> a) -> Array s a -> Array s a
213 | indexUpdate is f (MkArray rep @{rc} s arr) =
214 |   MkArray rep @{rc} s (indexUpdate @{rc} is f arr)
215 |
216 | ||| Set the entry at the given coordinates to the given value.
217 | export
218 | indexSet : Coords s -> a -> Array s a -> Array s a
219 | indexSet is = indexUpdate is . const
220 |
221 |
222 | ||| Index the array using the given range of coordinates, returning a new array.
223 | export
224 | indexRange : (rs : CoordsRange s) -> Array s a -> Array (newShape rs) a
225 | indexRange rs (MkArray rep @{rc} s arr) =
226 |   MkArray rep @{rc} _ (PrimArray.indexRange @{rc} rs arr)
227 |
228 | ||| Index the array using the given range of coordinates, returning a new array.
229 | |||
230 | ||| This is the operator form of `indexRange`.
231 | export %inline
232 | (!!..) : Array s a -> (rs : CoordsRange s) -> Array (newShape rs) a
233 | arr !!.. rs = indexRange rs arr
234 |
235 | ||| Set the sub-array at the given range of coordinates to the given array.
236 | export
237 | indexSetRange : (rs : CoordsRange s) -> Array (newShape rs) a ->
238 |                   Array s a -> Array s a
239 | indexSetRange rs (MkArray _ _ rpl) (MkArray rep s arr) =
240 |   MkArray rep s (PrimArray.indexSetRange {rep} rs (convertRepPrim rpl) arr)
241 |
242 |
243 | ||| Update the sub-array at the given range of coordinates by applying
244 | ||| a function to it.
245 | export
246 | indexUpdateRange : (rs : CoordsRange s) ->
247 |                     (Array (newShape rs) a -> Array (newShape rs) a) ->
248 |                     Array s a -> Array s a
249 | indexUpdateRange rs f arr = indexSetRange rs (f $ arr !!.. rs) arr
250 |
251 |
252 | ||| Index the array using the given coordinates, returning `Nothing` if the
253 | ||| coordinates are out of bounds.
254 | export
255 | indexNB : Vect rk Nat -> Array {rk} s a -> Maybe a
256 | indexNB is (MkArray rep @{rc} s arr) =
257 |   map (\is => index is (MkArray rep @{rc} s arr)) (validateCoords s is)
258 |
259 | ||| Index the array using the given coordinates, returning `Nothing` if the
260 | ||| coordinates are out of bounds.
261 | |||
262 | ||| This is the operator form of `indexNB`.
263 | export %inline
264 | (!?) : Array {rk} s a -> Vect rk Nat -> Maybe a
265 | arr !? is = indexNB is arr
266 |
267 | ||| Update the entry at the given coordinates using the function. `Nothing` is
268 | ||| returned if the coordinates are out of bounds.
269 | export
270 | indexUpdateNB : Vect rk Nat -> (a -> a) -> Array {rk} s a -> Maybe (Array s a)
271 | indexUpdateNB is f (MkArray rep @{rc} s arr) =
272 |   map (\is => indexUpdate is f (MkArray rep @{rc} s arr)) (validateCoords s is)
273 |
274 | ||| Set the entry at the given coordinates using the function. `Nothing` is
275 | ||| returned if the coordinates are out of bounds.
276 | export
277 | indexSetNB : Vect rk Nat -> a -> Array {rk} s a -> Maybe (Array s a)
278 | indexSetNB is = indexUpdateNB is . const
279 |
280 |
281 | ||| Index the array using the given range of coordinates, returning a new array.
282 | ||| If any of the given indices are out of bounds, then `Nothing` is returned.
283 | export
284 | indexRangeNB : (rs : Vect rk CRangeNB) -> Array s a -> Maybe (Array (newShape s rs) a)
285 | indexRangeNB rs (MkArray rep @{rc} s arr) =
286 |   map (\rs => believe_me $ Array.indexRange rs (MkArray rep @{rc} s arr)) (validateCRange s rs)
287 |
288 | ||| Index the array using the given range of coordinates, returning a new array.
289 | ||| If any of the given indices are out of bounds, then `Nothing` is returned.
290 | |||
291 | ||| This is the operator form of `indexRangeNB`.
292 | export %inline
293 | (!?..) : Array s a -> (rs : Vect rk CRangeNB) -> Maybe (Array (newShape s rs) a)
294 | arr !?.. rs = indexRangeNB rs arr
295 |
296 |
297 | ||| Index the array using the given coordinates.
298 | ||| WARNING: This function does not perform any bounds check on its inputs.
299 | ||| Misuse of this function can easily break memory safety.
300 | export %unsafe
301 | indexUnsafe : Vect rk Nat -> Array {rk} s a -> a
302 | indexUnsafe is (MkArray _ _ arr) = PrimArray.indexUnsafe is arr
303 |
304 | ||| Index the array using the given coordinates.
305 | ||| WARNING: This function does not perform any bounds check on its inputs.
306 | ||| Misuse of this function can easily break memory safety.
307 | |||
308 | ||| This is the operator form of `indexUnsafe`.
309 | export %inline %unsafe
310 | (!#) : Array {rk} s a -> Vect rk Nat -> a
311 | arr !# is = indexUnsafe is arr
312 |
313 |
314 | ||| Index the array using the given range of coordinates, returning a new array.
315 | ||| WARNING: This function does not perform any bounds check on its inputs.
316 | ||| Misuse of this function can easily break memory safety.
317 | export %unsafe
318 | indexRangeUnsafe : (rs : Vect rk CRangeNB) -> Array s a -> Array (newShape s rs) a
319 | indexRangeUnsafe rs (MkArray rep @{rc} s arr) =
320 |   believe_me $ Array.indexRange (assertCRange s rs) (MkArray rep @{rc} s arr)
321 |
322 | ||| Index the array using the given range of coordinates, returning a new array.
323 | ||| WARNING: This function does not perform any bounds check on its inputs.
324 | ||| Misuse of this function can easily break memory safety.
325 | |||
326 | ||| This is the operator form of `indexRangeUnsafe`.
327 | export %inline %unsafe
328 | (!#..) : Array s a -> (rs : Vect rk CRangeNB) -> Array (newShape s rs) a
329 | arr !#.. is = indexRangeUnsafe is arr
330 |
331 |
332 |
333 | --------------------------------------------------------------------------------
334 | -- Operations on arrays
335 | --------------------------------------------------------------------------------
336 |
337 | ||| Map a function over an array.
338 | |||
339 | ||| You should almost always use `map` instead; only use this function if you
340 | ||| know what you are doing!
341 | export
342 | mapArray' : (a -> a) -> Array s a -> Array s a
343 | mapArray' f (MkArray rep _ arr) = MkArray rep _ (mapPrim f arr)
344 |
345 | ||| Map a function over an array.
346 | |||
347 | ||| You should almost always use `map` instead; only use this function if you
348 | ||| know what you are doing!
349 | export
350 | mapArray : (a -> b) -> (arr : Array s a) -> RepConstraint (getRep arr) b => Array s b
351 | mapArray f (MkArray rep _ arr) @{rc} = MkArray rep @{rc} _ (mapPrim f arr)
352 |
353 | ||| Combine two arrays of the same element type using a binary function.
354 | |||
355 | ||| You should almost always use `zipWith` instead; only use this function if
356 | ||| you know what you are doing!
357 | export
358 | zipWithArray' : (a -> a -> a) -> Array s a -> Array s a -> Array s a
359 | zipWithArray' {s} f a b with (viewShape a)
360 |   _ | Shape s = MkArray (mergeRep (getRep a) (getRep b))
361 |                           @{mergeRepConstraint (getRepC a) (getRepC b)} s
362 |                           $ PrimArray.fromFunctionNB @{mergeRepConstraint (getRepC a) (getRepC b)} _
363 |                               (\is => f (a !# is) (b !# is))
364 |
365 | ||| Combine two arrays using a binary function.
366 | |||
367 | ||| You should almost always use `zipWith` instead; only use this function if
368 | ||| you know what you are doing!
369 | export
370 | zipWithArray : (a -> b -> c) -> (arr : Array s a) -> (arr' : Array s b) ->
371 |                 RepConstraint (mergeRep (getRep arr) (getRep arr')) c => Array s c
372 | zipWithArray {s} f a b @{rc} with (viewShape a)
373 |   _ | Shape s = MkArray (mergeRep (getRep a) (getRep b)) @{rc} s
374 |                           $ PrimArray.fromFunctionNB _ (\is => f (a !# is) (b !# is))
375 |
376 |
377 | ||| Reshape the array into the given shape.
378 | |||
379 | ||| @ s' The shape to convert the array to
380 | export
381 | reshape : (s' : Vect rk' Nat) -> (arr : Array {rk} s a) -> LinearRep (getRep arr) =>
382 |             (0 ok : product s = product s') => Array s' a
383 | reshape s' (MkArray rep _ arr) = MkArray rep s' (PrimArray.reshape s' arr)
384 |
385 | ||| Change the internal representation of the array's elements.
386 | export
387 | convertRep : (rep : Rep) -> RepConstraint rep a => Array s a -> Array s a
388 | convertRep rep (MkArray _ s arr) = MkArray rep s (convertRepPrim arr)
389 |
390 | ||| Temporarily convert an array to a delayed representation to make modifying
391 | ||| it more efficient, then convert it back to its original representation.
392 | export
393 | delayed : (Array s a -> Array s' a) -> Array s a -> Array s' a
394 | delayed f arr = convertRep (getRep arr) @{getRepC arr} $ f $ convertRep Delayed arr
395 |
396 | ||| Resize the array to a new shape, preserving the coordinates of the original
397 | ||| elements. New coordinates are filled with a default value.
398 | |||
399 | ||| @ s'  The shape to resize the array to
400 | ||| @ def The default value to fill the array with
401 | export
402 | resize : (s' : Vect rk Nat) -> (def : a) -> Array {rk} s a -> Array s' a
403 | resize s' def arr = fromFunction {rep=getRep arr} @{getRepC arr} s' (fromMaybe def . (arr !?) . toNB)
404 |
405 | ||| Resize the array to a new shape, preserving the coordinates of the original
406 | ||| elements. This function requires a proof that the new shape is strictly
407 | ||| smaller than the current shape of the array.
408 | |||
409 | ||| @ s' The shape to resize the array to
410 | export
411 | -- HACK: Come up with a solution that doesn't use `believe_me` or trip over some
412 | -- weird bug in the type-checker
413 | resizeLTE : (s' : Vect rk Nat) -> (0 ok : All Prelude.id (zipWith LTE s' s)) =>
414 |             Array {rk} s a -> Array s' a
415 | resizeLTE s' arr = resize s' (believe_me ()) arr
416 |
417 |
418 | ||| List all of the values in an array along with their coordinates.
419 | export
420 | enumerateNB : Array {rk} s a -> List (Vect rk Nat, a)
421 | enumerateNB (MkArray _ s arr) =
422 |   map (\is => (is, PrimArray.indexUnsafe is arr)) (getAllCoords' s)
423 |
424 | ||| List all of the values in an array along with their coordinates.
425 | export
426 | enumerate : Array s a -> List (Coords s, a)
427 | enumerate {s} arr with (viewShape arr)
428 |   _ | Shape s = map (\is => (is, index is arr)) (getAllCoords s)
429 |
430 | ||| List all of the values in an array in row-major order.
431 | export
432 | elements : Array {rk} s a -> Vect (product s) a
433 | elements (MkArray _ s arr) =
434 |   believe_me $ Vect.fromList $
435 |     map (flip PrimArray.indexUnsafe arr) (getAllCoords' s)
436 |
437 | ||| Join two arrays along a particular axis, e.g. combining two matrices
438 | ||| vertically or horizontally. All other axes of the arrays must have the
439 | ||| same dimensions.
440 | |||
441 | ||| @ axis The axis to join the arrays on
442 | export
443 | concat : (axis : Fin rk) -> Array {rk} s a -> Array (replaceAt axis d s) a ->
444 |           Array (updateAt axis (+d) s) a
445 | concat {s,d} axis a b with (viewShape a, viewShape b)
446 |   _ | (Shape s, Shape (replaceAt axis d s)) =
447 |     believe_me $ Array.fromFunctionNB {rep=mergeRep (getRep a) (getRep b)}
448 |                   @{mergeRepConstraint (getRepC a) (getRepC b)} (updateAt axis (+ index axis (shape b)) s)
449 |       (\is => let limit = index axis s
450 |               in if index axis is < limit then a !# is else b !# updateAt axis (`minus` limit) is)
451 |
452 | ||| Stack multiple arrays along a new axis, e.g. stacking vectors to form a matrix.
453 | |||
454 | ||| @ axis The axis to stack the arrays along
455 | export
456 | stack : {s : _} -> (axis : Fin (S rk)) -> Vect n (Array {rk} s a) -> Array (insertAt axis n s) a
457 | stack axis arrs = rewrite sym (lengthCorrect arrs) in
458 |   fromFunction _ (\is => case getAxisInd axis (rewrite sym (lengthCorrect arrs) in is) of
459 |                            (i,is') => index is' (index i arrs))
460 |   where
461 |     getAxisInd : {0 rk : _} -> {s : _} -> (ax : Fin (S rk)) -> Coords (insertAt ax n s) -> (Fin n, Coords s)
462 |     getAxisInd FZ (i :: is) = (i, is)
463 |     getAxisInd {s=_::_} (FS ax) (i :: is) = mapSnd (i::) (getAxisInd ax is)
464 |
465 | ||| Join the axes of a nested array structure to form a single array.
466 | export
467 | joinAxes : {s' : _} -> Array s (Array s' a) -> Array (s ++ s') a
468 | joinAxes {s} arr with (viewShape arr)
469 |   _ | Shape s = fromFunctionNB (s ++ s') (\is => arr !# takeUpTo s is !# dropUpTo s is)
470 |   where
471 |     takeUpTo : Vect rk Nat -> Vect (rk + rk') Nat -> Vect rk Nat
472 |     takeUpTo [] ys = []
473 |     takeUpTo (x::xs) (y::ys) = y :: takeUpTo xs ys
474 |
475 |     dropUpTo : Vect rk Nat -> Vect (rk + rk') Nat -> Vect rk' Nat
476 |     dropUpTo [] ys = ys
477 |     dropUpTo (x::xs) (y::ys) = dropUpTo xs ys
478 |
479 |
480 | ||| Split an array into a nested array structure along the specified axes.
481 | export
482 | splitAxes : (rk : Nat) -> {0 rk' : Nat} -> {s : _} ->
483 |             Array {rk=rk+rk'} s a -> Array (take {m=rk'} rk s) (Array (drop {m=rk'} rk s) a)
484 | splitAxes _ {s} arr = fromFunctionNB _ (\is => fromFunctionNB _ (\is' => arr !# (is ++ is')))
485 |
486 | ||| Construct the transpose of an array by reversing the order of its axes.
487 | export
488 | transpose : Array s a -> Array (reverse s) a
489 | transpose {s} arr with (viewShape arr)
490 |   _ | Shape s = fromFunctionNB _ (\is => arr !# reverse is)
491 |
492 | ||| Construct the transpose of an array by reversing the order of its axes.
493 | |||
494 | ||| This is the postfix form of `transpose`.
495 | export
496 | (.T) : Array s a -> Array (reverse s) a
497 | (.T) = transpose
498 |
499 |
500 | ||| Swap two axes in an array.
501 | export
502 | swapAxes : (i,j : Fin rk) -> Array s a -> Array (swapElems i j s) a
503 | swapAxes {s} i j arr with (viewShape arr)
504 |   _ | Shape s = fromFunctionNB _ (\is => arr !# swapElems i j is)
505 |
506 | ||| Apply a permutation to the axes of an array.
507 | export
508 | permuteAxes : (p : Permutation rk) -> Array s a -> Array (permuteVect p s) a
509 | permuteAxes {s} p arr with (viewShape arr)
510 |   _ | Shape s = fromFunctionNB _ (\is => arr !# permuteVect p s)
511 |
512 | ||| Swap two coordinates along a specific axis (e.g. swapping two rows in a matrix).
513 | |||
514 | ||| @ axis The axis to swap the coordinates along. Slices of the array
515 | ||| perpendicular to this axis are taken when swapping.
516 | export
517 | swapInAxis : (axis : Fin rk) -> (i,j : Fin (index axis s)) -> Array s a -> Array s a
518 | swapInAxis {s} axis i j arr with (viewShape arr)
519 |   _ | Shape s = fromFunctionNB _ (\is => arr !# updateAt axis (swapValues i j) is)
520 |
521 | ||| Permute the coordinates along a specific axis (e.g. permuting the rows in
522 | ||| a matrix).
523 | |||
524 | ||| @ axis The axis to permute the coordinates along. Slices of the array
525 | ||| perpendicular to this axis are taken when permuting.
526 | export
527 | permuteInAxis : (axis : Fin rk) -> Permutation (index axis s) -> Array s a -> Array s a
528 | permuteInAxis {s} axis p arr with (viewShape arr)
529 |   _ | Shape s = fromFunctionNB _ (\is => arr !# updateAt axis (permuteValues p) is)
530 |
531 |
532 | --------------------------------------------------------------------------------
533 | -- Implementations
534 | --------------------------------------------------------------------------------
535 |
536 | export
537 | Zippable (Array s) where
538 |   zipWith {s} f a b with (viewShape a)
539 |     _ | Shape s = MkArray (mergeRepNC (getRep a) (getRep b))
540 |                           @{mergeNCRepConstraint} s
541 |                           $ PrimArray.fromFunctionNB @{mergeNCRepConstraint} _
542 |                               (\is => f (a !# is) (b !# is))
543 |
544 |   zipWith3 {s} f a b c with (viewShape a)
545 |     _ | Shape s = MkArray (mergeRepNC (mergeRep (getRep a) (getRep b)) (getRep c))
546 |                           @{mergeNCRepConstraint} s
547 |                           $ PrimArray.fromFunctionNB @{mergeNCRepConstraint} _
548 |                               (\is => f (a !# is) (b !# is) (c !# is))
549 |
550 |   unzipWith {s} f arr with (viewShape arr)
551 |     _ | Shape s =
552 |       let rep : Rep
553 |           rep = forceRepNC $ getRep arr
554 |       in  (MkArray rep
555 |              @{forceRepConstraint} s
556 |              $ PrimArray.fromFunctionNB @{forceRepConstraint} _
557 |                  (\is => fst $ f (arr !# is)),
558 |            MkArray rep
559 |              @{forceRepConstraint} s
560 |              $ PrimArray.fromFunctionNB @{forceRepConstraint} _
561 |                  (\is => snd $ f (arr !# is)))
562 |
563 |   unzipWith3 {s} f arr with (viewShape arr)
564 |     _ | Shape s =
565 |       let rep : Rep
566 |           rep = forceRepNC $ getRep arr
567 |       in  (MkArray rep
568 |              @{forceRepConstraint} s
569 |              $ PrimArray.fromFunctionNB @{forceRepConstraint} _
570 |                  (\is => fst $ f (arr !# is)),
571 |            MkArray rep
572 |              @{forceRepConstraint} s
573 |              $ PrimArray.fromFunctionNB @{forceRepConstraint} _
574 |                  (\is => fst $ snd $ f (arr !# is)),
575 |            MkArray rep
576 |              @{forceRepConstraint} s
577 |              $ PrimArray.fromFunctionNB @{forceRepConstraint} _
578 |                  (\is => snd $ snd $ f (arr !# is)))
579 |
580 |
581 | export
582 | Functor (Array s) where
583 |   map f (MkArray rep @{rc} s arr) = MkArray (forceRepNC rep) @{forceRepConstraint} s
584 |                                 (mapPrim @{forceRepConstraint} @{forceRepConstraint} f
585 |                                   $ convertRepPrim @{rc} @{forceRepConstraint} arr)
586 |
587 | export
588 | {s : _} -> Applicative (Array s) where
589 |   pure = repeat s
590 |   (<*>) = zipWith apply
591 |
592 | export
593 | {s : _} -> Monad (Array s) where
594 |   join arr = fromFunction s (\is => arr !! is !! is)
595 |
596 |
597 | -- Foldable and Traversable operate on the primitive array directly. This means
598 | -- that their operation is dependent on the internal representation of the
599 | -- array.
600 |
601 | export
602 | Foldable (Array s) where
603 |   foldl f z (MkArray _ _ arr) = PrimArray.foldl f z arr
604 |   foldr f z (MkArray _ _ arr) = PrimArray.foldr f z arr
605 |   null (MkArray _ s _) = isZero (product s)
606 |
607 |
608 | export
609 | Traversable (Array s) where
610 |   traverse f (MkArray rep @{rc} s arr) =
611 |     map (MkArray (forceRepNC rep) @{forceRepConstraint} s)
612 |       (PrimArray.traverse {rep=forceRepNC rep}
613 |         @{%search} @{forceRepConstraint} @{forceRepConstraint} f
614 |         (convertRepPrim @{rc} @{forceRepConstraint} arr))
615 |
616 |
617 | export
618 | Cast a b => Cast (Array s a) (Array s b) where
619 |   cast = map cast
620 |
621 | export
622 | Eq a => Eq (Array s a) where
623 |   a == b = and $ zipWith (delay .: (==)) (convertRep D a) (convertRep D b)
624 |
625 | export
626 | Semigroup a => Semigroup (Array s a) where
627 |   (<+>) = zipWithArray' (<+>)
628 |
629 | export
630 | {s : _} -> Monoid a => Monoid (Array s a) where
631 |   neutral = repeat s neutral
632 |
633 | -- The shape must be known at runtime here due to `fromInteger`. If `fromInteger`
634 | -- were moved into its own interface, this constraint could be removed.
635 | export
636 | {s : _} -> Num a => Num (Array s a) where
637 |   (+) = zipWithArray' (+)
638 |   (*) = zipWithArray' (*)
639 |
640 |   fromInteger = repeat s . fromInteger
641 |
642 | export
643 | {s : _} -> Neg a => Neg (Array s a) where
644 |   negate = mapArray' negate
645 |   (-) = zipWithArray' (-)
646 |
647 | export
648 | {s : _} -> Fractional a => Fractional (Array s a) where
649 |   recip = mapArray' recip
650 |   (/) = zipWithArray' (/)
651 |
652 |
653 | export
654 | Num a => Mult a (Array {rk} s a) (Array s a) where
655 |   (*.) x = mapArray' (*x)
656 |
657 | export
658 | Num a => Mult (Array {rk} s a) a (Array s a) where
659 |   (*.) = flip (*.)
660 |
661 |
662 | export
663 | Show a => Show (Array s a) where
664 |   showPrec d arr = let orderedElems = toList $ elements arr
665 |                    in  showCon d "array " $ concat $ insertPunct (shape arr) $ map show orderedElems
666 |     where
667 |       splitWindow : Nat -> List String -> List (List String)
668 |       splitWindow n xs = case splitAt n xs of
669 |                            (xs, []) => [xs]
670 |                            (l1, l2) => l1 :: splitWindow n (assert_smaller xs l2)
671 |
672 |       insertPunct : Vect rk Nat -> List String -> List String
673 |       insertPunct [] strs = strs
674 |       insertPunct [d] strs = "[" :: intersperse ", " strs `snoc` "]"
675 |       insertPunct (Z :: s) strs = ["[","]"]
676 |       insertPunct (d :: s) strs =
677 |         let secs = if null strs
678 |                       then List.replicate d ("[]" :: Prelude.Nil)
679 |                       else map (insertPunct s) $ splitWindow (length strs `div` d) strs
680 |         in  "[" :: (concat $ intersperse [", "] secs) `snoc` "]"
681 |
682 |
683 | --------------------------------------------------------------------------------
684 | -- Numeric array operations
685 | --------------------------------------------------------------------------------
686 |
687 |
688 | ||| Linearly interpolate between two arrays.
689 | export
690 | lerp : Neg a => a -> Array s a -> Array s a -> Array s a
691 | lerp t a b = zipWithArray' (+) (a *. (1 - t)) (b *. t)
692 |
693 |
694 | ||| Calculate the square of an array's Eulidean norm.
695 | export
696 | normSq : Num a => Array s a -> a
697 | normSq arr = sum $ zipWith (*) arr arr
698 |
699 | ||| Calculate an array's Eucliean norm.
700 | export
701 | norm : Array s Double -> Double
702 | norm = sqrt . normSq
703 |
704 | ||| Normalize the array to a norm of 1.
705 | |||
706 | ||| If the array contains all zeros, then it is returned unchanged.
707 | export
708 | normalize : Array s Double -> Array s Double
709 | normalize arr = if all (==0) arr then arr else map (/ norm arr) arr
710 |
711 | ||| Calculate the Lp-norm of an array.
712 | export
713 | pnorm : (p : Double) -> Array s Double -> Double
714 | pnorm p = (`pow` recip p) . sum . map (`pow` p)
715 |