0 | module Data.NumIdr.Matrix
  1 |
  2 | import Data.List
  3 | import Data.Vect
  4 | import Data.Permutation
  5 | import Data.NumIdr.Interfaces
  6 | import public Data.NumIdr.Array
  7 | import Data.NumIdr.PrimArray
  8 | import Data.NumIdr.Vector
  9 |
 10 | %default total
 11 |
 12 |
 13 | ||| A matrix is a rank-2 array.
 14 | public export
 15 | Matrix : Nat -> Nat -> Type -> Type
 16 | Matrix m n = Array [m,n]
 17 |
 18 | %name Matrix mat
 19 |
 20 | ||| A synonym for a square matrix with dimensions of length `n`.
 21 | public export
 22 | Matrix' : Nat -> Type -> Type
 23 | Matrix' n = Array [n,n]
 24 |
 25 |
 26 | --------------------------------------------------------------------------------
 27 | -- Matrix constructors
 28 | --------------------------------------------------------------------------------
 29 |
 30 |
 31 | ||| Construct a matrix with the given order and elements.
 32 | export
 33 | matrix : {default B rep : Rep} -> RepConstraint rep a => {m, n : _} ->
 34 |           Vect m (Vect n a) -> Matrix m n a
 35 | matrix x = array {rep, s=[m,n]} x
 36 |
 37 |
 38 | ||| Construct a matrix with a specific value along the diagonal.
 39 | |||
 40 | ||| @ diag  The value to repeat along the diagonal
 41 | ||| @ other The value to repeat elsewhere
 42 | export
 43 | repeatDiag : {default B rep : Rep} -> RepConstraint rep a => {m, n : _} ->
 44 |               (diag, other : a) -> Matrix m n a
 45 | repeatDiag d o = fromFunctionNB {rep} [m,n]
 46 |    (\[i,j] => if i == j then d else o)
 47 |
 48 | ||| Construct a matrix given its diagonal elements.
 49 | |||
 50 | ||| @ diag  The elements of the matrix's diagonal
 51 | ||| @ other The value to repeat elsewhere
 52 | export
 53 | fromDiag : {default B rep : Rep} -> RepConstraint rep a => {m, n : _} ->
 54 |             (diag : Vect (minimum m n) a) -> (other : a) -> Matrix m n a
 55 | fromDiag ds o = fromFunction {rep} [m,n] (\[i,j] => maybe o (`index` ds) $ i `eq` j)
 56 |   where
 57 |     eq : {0 m,n : Nat} -> Fin m -> Fin n -> Maybe (Fin (minimum m n))
 58 |     eq FZ FZ = Just FZ
 59 |     eq (FS x) (FS y) = map FS (eq x y)
 60 |     eq FZ (FS _) = Nothing
 61 |     eq (FS _) FZ = Nothing
 62 |
 63 |
 64 | ||| Construct a permutation matrix based on the given permutation.
 65 | export
 66 | permuteM : {default B rep : Rep} -> RepConstraint rep a => {n : _} -> Num a =>
 67 |             Permutation n -> Matrix' n a
 68 | permuteM p = permuteInAxis 0 p (repeatDiag {rep} 1 0)
 69 |
 70 |
 71 | ||| Construct the matrix that scales a vector by the given value.
 72 | export
 73 | scale : {default B rep : Rep} -> RepConstraint rep a => {n : _} -> Num a =>
 74 |           a -> Matrix' n a
 75 | scale x = repeatDiag {rep} x 0
 76 |
 77 | ||| Construct a 2D rotation matrix that rotates by the given angle (in radians).
 78 | export
 79 | rotate2D : {default B rep : Rep} -> RepConstraint rep Double =>
 80 |             Double -> Matrix' 2 Double
 81 | rotate2D a = matrix {rep} [[cos a, - sin a], [sin a, cos a]]
 82 |
 83 |
 84 | ||| Construct a 3D rotation matrix around the x-axis.
 85 | export
 86 | rotate3DX : {default B rep : Rep} -> RepConstraint rep Double =>
 87 |             Double -> Matrix' 3 Double
 88 | rotate3DX a = matrix {rep} [[1,0,0], [0, cos a, - sin a], [0, sin a, cos a]]
 89 |
 90 | ||| Construct a 3D rotation matrix around the y-axis.
 91 | export
 92 | rotate3DY : {default B rep : Rep} -> RepConstraint rep Double =>
 93 |             Double -> Matrix' 3 Double
 94 | rotate3DY a = matrix {rep} [[cos a, 0, sin a], [0,1,0], [- sin a, 0, cos a]]
 95 |
 96 | ||| Construct a 3D rotation matrix around the z-axis.
 97 | export
 98 | rotate3DZ : {default B rep : Rep} -> RepConstraint rep Double =>
 99 |             Double -> Matrix' 3 Double
100 | rotate3DZ a = matrix {rep} [[cos a, - sin a, 0], [sin a, cos a, 0], [0,0,1]]
101 |
102 |
103 | export
104 | reflect : {default B rep : Rep} -> RepConstraint rep a =>
105 |           {n : _} -> Neg a => Fin n -> Matrix' n a
106 | reflect i = indexSet [i, i] (-1) (repeatDiag {rep} 1 0)
107 |
108 | export
109 | reflectX : {default B rep : Rep} -> RepConstraint rep a =>
110 |             {n : _} -> Neg a => Matrix' (1 + n) a
111 | reflectX = reflect {rep} 0
112 |
113 | export
114 | reflectY : {default B rep : Rep} -> RepConstraint rep a =>
115 |             {n : _} -> Neg a => Matrix' (2 + n) a
116 | reflectY = reflect {rep} 1
117 |
118 | export
119 | reflectZ : {default B rep : Rep} -> RepConstraint rep a =>
120 |             {n : _} -> Neg a => Matrix' (3 + n) a
121 | reflectZ = reflect {rep} 2
122 |
123 |
124 | --------------------------------------------------------------------------------
125 | -- Indexing
126 | --------------------------------------------------------------------------------
127 |
128 |
129 | ||| Index the matrix at the given coordinates.
130 | export
131 | index : Fin m -> Fin n -> Matrix m n a -> a
132 | index m n = index [m,n]
133 |
134 | ||| Index the matrix at the given coordinates, returning `Nothing` if the
135 | ||| coordinates are out of bounds.
136 | export
137 | indexNB : Nat -> Nat -> Matrix m n a -> Maybe a
138 | indexNB m n = indexNB [m,n]
139 |
140 |
141 | ||| Return a row of the matrix as a vector.
142 | export
143 | getRow : Fin m -> Matrix m n a -> Vector n a
144 | getRow r mat = rewrite sym (rangeLenZ n) in mat!!..[One r, All]
145 |
146 | ||| Return a column of the matrix as a vector.
147 | export
148 | getColumn : Fin n -> Matrix m n a -> Vector m a
149 | getColumn c mat = rewrite sym (rangeLenZ m) in mat!!..[All, One c]
150 |
151 |
152 | ||| Return the diagonal elements of the matrix as a vector.
153 | export
154 | diagonal' : Matrix m n a -> Vector (minimum m n) a
155 | diagonal' {m,n} mat with (viewShape mat)
156 |   _ | Shape [m,n] = fromFunctionNB {rep=_} @{getRepC mat} _ (\[i] => mat!#[i,i])
157 |
158 | ||| Return the diagonal elements of the matrix as a vector.
159 | export
160 | diagonal : Matrix' n a -> Vector n a
161 | diagonal {n} mat with (viewShape mat)
162 |   _ | Shape [n,n] = fromFunctionNB {rep=_} @{getRepC mat} [n] (\[i] => mat!#[i,i])
163 |
164 |
165 | ||| Return a minor of the matrix, i.e. the matrix formed by removing a
166 | ||| single row and column.
167 | export
168 | -- TODO: throw an actual proof in here to avoid the unsafety
169 | minor : Fin (S m) -> Fin (S n) -> Matrix (S m) (S n) a -> Matrix m n a
170 | minor i j mat = replace {p = flip Array a} (believe_me $ Refl {x = ()})
171 |   $ mat!!..[Filter (/=i), Filter (/=j)]
172 |
173 |
174 | filterInd : Num a => (Nat -> Nat -> Bool) -> Matrix m n a -> Matrix m n a
175 | filterInd {m,n} p mat with (viewShape mat)
176 |   _ | Shape [m,n] = fromFunctionNB {rep=_} @{getRepC mat}
177 |                       [m,n] (\[i,j] => if p i j then mat!#[i,j] else 0)
178 |
179 | export
180 | upperTriangle : Num a => Matrix m n a -> Matrix m n a
181 | upperTriangle = filterInd (<=)
182 |
183 | export
184 | lowerTriangle : Num a => Matrix m n a -> Matrix m n a
185 | lowerTriangle = filterInd (>=)
186 |
187 | export
188 | upperTriangleStrict : Num a => Matrix m n a -> Matrix m n a
189 | upperTriangleStrict = filterInd (<)
190 |
191 | export
192 | lowerTriangleStrict : Num a => Matrix m n a -> Matrix m n a
193 | lowerTriangleStrict = filterInd (>)
194 |
195 |
196 | --------------------------------------------------------------------------------
197 | -- Basic operations
198 | --------------------------------------------------------------------------------
199 |
200 | ||| Concatenate two matrices vertically.
201 | export
202 | vconcat : Matrix m n a -> Matrix m' n a -> Matrix (m + m') n a
203 | vconcat = concat 0
204 |
205 | ||| Concatenate two matrices horizontally.
206 | export
207 | hconcat : Matrix m n a -> Matrix m n' a -> Matrix m (n + n') a
208 | hconcat = concat 1
209 |
210 | ||| Stack row vectors to form a matrix.
211 | export
212 | vstack : {n : _} -> Vect m (Vector n a) -> Matrix m n a
213 | vstack = stack 0
214 |
215 | ||| Stack column vectors to form a matrix.
216 | export
217 | hstack : {m : _} -> Vect n (Vector m a) -> Matrix m n a
218 | hstack = stack 1
219 |
220 |
221 | ||| Swap two rows of a matrix.
222 | export
223 | swapRows : (i,j : Fin m) -> Matrix m n a -> Matrix m n a
224 | swapRows = swapInAxis 0
225 |
226 | ||| Swap two columns of a matrix.
227 | export
228 | swapColumns : (i,j : Fin n) -> Matrix m n a -> Matrix m n a
229 | swapColumns = swapInAxis 1
230 |
231 | ||| Permute the rows of a matrix.
232 | export
233 | permuteRows : Permutation m -> Matrix m n a -> Matrix m n a
234 | permuteRows = permuteInAxis 0
235 |
236 | ||| Permute the columns of a matrix.
237 | export
238 | permuteColumns : Permutation n -> Matrix m n a -> Matrix m n a
239 | permuteColumns = permuteInAxis 1
240 |
241 |
242 | ||| Calculate the outer product of two vectors as a matrix.
243 | export
244 | outer : Num a => Vector m a -> Vector n a -> Matrix m n a
245 | outer {m,n} a b with (viewShape a, viewShape b)
246 |   _ | (Shape [m], Shape [n]) = fromFunction [m,n] (\[i,j] => a!!i * b!!j)
247 |
248 |
249 | ||| Calculate the trace of a matrix, i.e. the sum of its diagonal elements.
250 | export
251 | trace : Num a => Matrix m n a -> a
252 | trace = sum . diagonal'
253 |
254 |
255 | ||| Construct a matrix that reflects a vector along a hyperplane of the
256 | ||| given normal vector. The input does not have to be a unit vector.
257 | export
258 | reflectNormal : (Neg a, Fractional a) => Vector n a -> Matrix' n a
259 | reflectNormal {n} v with (viewShape v)
260 |   _ | Shape [n] = repeatDiag 1 0 - (2 / normSq v) *. outer v v
261 |
262 |
263 | --------------------------------------------------------------------------------
264 | -- Matrix multiplication
265 | --------------------------------------------------------------------------------
266 |
267 |
268 | export
269 | Num a => Mult (Matrix m n a) (Vector n a) (Vector m a) where
270 |   (*.) {m,n} mat v with (viewShape mat)
271 |     _ | Shape [m,n] = fromFunction {rep=_}
272 |       @{mergeRepConstraint (getRepC mat) (getRepC v)} [m]
273 |       (\[i] => sum $ map (\j => mat!![i,j] * v!!j) range)
274 |
275 | export
276 | Num a => Mult (Matrix m n a) (Matrix n p a) (Matrix m p a) where
277 |   (*.) {m,n,p} m1 m2 with (viewShape m1, viewShape m2)
278 |     _ | (Shape [m,n], Shape [n,p]) = fromFunction {rep=_}
279 |       @{mergeRepConstraint (getRepC m1) (getRepC m2)} [m,p]
280 |       (\[i,j] => sum $ map (\k => m1!![i,k] * m2!![k,j]) range)
281 |
282 | export
283 | {n : _} -> Num a => MultMonoid (Matrix' n a) where
284 |   identity = repeatDiag 1 0
285 |
286 |
287 | --------------------------------------------------------------------------------
288 | -- Matrix decomposition
289 | --------------------------------------------------------------------------------
290 |
291 |
292 | ||| The LU decomposition of a matrix.
293 | |||
294 | ||| LU decomposition factors a matrix A into two matrices: a lower triangular
295 | ||| matrix L, and an upper triangular matrix U, such that A = LU.
296 | export
297 | record DecompLU {0 m,n,a : _} (mat : Matrix m n a) where
298 |   constructor MkLU
299 |   -- The lower and upper triangular matrix elements are stored
300 |   -- together for efficiency reasons
301 |   lu : Matrix m n a
302 |
303 |
304 | namespace DecompLU
305 |   ||| The lower triangular matrix L of the LU decomposition.
306 |   export
307 |   lower : Num a => DecompLU {m,n,a} mat -> Matrix m (minimum m n) a
308 |   lower {m,n} (MkLU lu) with (viewShape lu)
309 |     _ | Shape [m,n] = fromFunctionNB {rep=_} @{getRepC lu} _ (\[i,j] =>
310 |       case compare i j of
311 |         LT => 0
312 |         EQ => 1
313 |         GT => lu!#[i,j])
314 |
315 |   ||| The lower triangular matrix L of the LU decomposition.
316 |   export %inline
317 |   (.lower) : Num a => DecompLU {m,n,a} mat -> Matrix m (minimum m n) a
318 |   (.lower) = lower
319 |
320 |   ||| The lower triangular matrix L of the LU decomposition.
321 |   |||
322 |   ||| This accessor is intended to be used for square matrix decompositions.
323 |   export
324 |   lower' : Num a => {0 mat : Matrix' n a} -> DecompLU mat -> Matrix' n a
325 |   lower' lu = rewrite cong (\i => Matrix n i a) $ sym (minimumIdempotent n)
326 |               in lower lu
327 |
328 |   ||| The lower triangular matrix L of the LU decomposition.
329 |   |||
330 |   ||| This accessor is intended to be used for square matrix decompositions.
331 |   export %inline
332 |   (.lower') : Num a => {0 mat : Matrix' n a} -> DecompLU mat -> Matrix' n a
333 |   (.lower') = lower'
334 |
335 |   ||| The upper triangular matrix U of the LU decomposition.
336 |   export
337 |   upper : Num a => DecompLU {m,n,a} mat -> Matrix (minimum m n) n a
338 |   upper {m,n} (MkLU lu) with (viewShape lu)
339 |     _ | Shape [m,n] = fromFunctionNB {rep=_} @{getRepC lu} _ (\[i,j] =>
340 |       if i <= j then lu!#[i,j] else 0)
341 |
342 |   ||| The upper triangular matrix U of the LU decomposition.
343 |   export %inline
344 |   (.upper) : Num a => DecompLU {m,n,a} mat -> Matrix (minimum m n) n a
345 |   (.upper) = upper
346 |
347 |   ||| The upper triangular matrix U of the LU decomposition.
348 |   |||
349 |   ||| This accessor is intended to be used for square matrix decompositions.
350 |   export
351 |   upper' : Num a => {0 mat : Matrix' n a} -> DecompLU mat -> Matrix' n a
352 |   upper' lu = rewrite cong (\i => Matrix i n a) $ sym (minimumIdempotent n)
353 |               in upper lu
354 |
355 |   ||| The upper triangular matrix U of the LU decomposition.
356 |   |||
357 |   ||| This accessor is intended to be used for square matrix decompositions.
358 |   export %inline
359 |   (.upper') : Num a => {0 mat : Matrix' n a} -> DecompLU mat -> Matrix' n a
360 |   (.upper') = upper'
361 |
362 |
363 | minWeakenLeft : {m,n : _} -> Fin (minimum m n) -> Fin m
364 | minWeakenLeft x = weakenLTE x $ minLTE m n
365 |   where
366 |     minLTE : (m,n : _) -> minimum m n `LTE` m
367 |     minLTE Z n = LTEZero
368 |     minLTE (S m) Z = LTEZero
369 |     minLTE (S m) (S n) = LTESucc (minLTE m n)
370 |
371 | minWeakenRight : {m,n : _} -> Fin (minimum m n) -> Fin n
372 | minWeakenRight x = weakenLTE x $ minLTE m n
373 |   where
374 |     minLTE : (m,n : _) -> minimum m n `LTE` n
375 |     minLTE Z n = LTEZero
376 |     minLTE (S m) Z = LTEZero
377 |     minLTE (S m) (S n) = LTESucc (minLTE m n)
378 |
379 |
380 | iterateN : (n : Nat) -> (Fin n -> a -> a) -> a -> a
381 | iterateN 0 f x = x
382 | iterateN 1 f x = f FZ x
383 | iterateN (S n@(S _)) f x = iterateN n (f . FS) $ f FZ x
384 |
385 |
386 | ||| Perform a single step of Gaussian elimination on the `i`-th row and column.
387 | gaussStep : {m,n : _} -> Field a =>
388 |             Fin (minimum m n) -> Matrix m n a -> Matrix m n a
389 | gaussStep i lu =
390 |     if all (==0) $ getColumn (minWeakenRight i) lu then lu else
391 |       let ir = minWeakenLeft i
392 |           ic = minWeakenRight i
393 |           diag = lu!![ir,ic]
394 |           coeffs = map (/diag) $ lu!!..[StartBound (FS ir), One ic]
395 |           lu' = indexSetRange [StartBound (FS ir), One ic] coeffs lu
396 |           pivot = lu!!..[One ir, StartBound (FS ic)]
397 |           offsets = outer coeffs pivot
398 |       in  indexUpdateRange [StartBound (FS ir), StartBound (FS ic)]
399 |             (flip (-) offsets) lu'
400 |
401 | ||| Calculate the LU decomposition of a matrix, returning `Nothing` if one
402 | ||| does not exist.
403 | export
404 | decompLU : Field a => (mat : Matrix m n a) -> Maybe (DecompLU mat)
405 | decompLU {m,n} mat with (viewShape mat)
406 |   _ | Shape [m,n] = map MkLU
407 |     $ iterateN (minimum m n) (\i => (>>= gaussStepMaybe i)) (Just mat)
408 |   where
409 |     gaussStepMaybe : Fin (minimum m n) -> Matrix m n a -> Maybe (Matrix m n a)
410 |     gaussStepMaybe i mat = if mat!#[cast i,cast i] == 0 then Nothing
411 |                            else Just $ gaussStep i mat
412 |
413 |
414 | ||| The LUP decomposition of a matrix.
415 | |||
416 | ||| LUP decomposition is similar to LU decomposition, but the matrix may have
417 | ||| its rows permuted before being factored. More formally, an LUP decomposition
418 | ||| of a matrix A consists of a lower triangular matrix L, an upper triangular
419 | ||| matrix U, and a permutation matrix P, such that PA = LU.
420 | export
421 | record DecompLUP {0 m,n,a : _} (mat : Matrix m n a) where
422 |   constructor MkLUP
423 |   lu : Matrix m n a
424 |   p : Permutation m
425 |   sw : Nat
426 |
427 | namespace DecompLUP
428 |   ||| The lower triangular matrix L of the LUP decomposition.
429 |   export
430 |   lower : Num a => DecompLUP {m,n,a} mat -> Matrix m (minimum m n) a
431 |   lower {m,n} (MkLUP lu _ _) with (viewShape lu)
432 |     _ | Shape [m,n] = fromFunctionNB {rep=_} @{getRepC lu} _ (\[i,j] =>
433 |       case compare i j of
434 |         LT => 0
435 |         EQ => 1
436 |         GT => lu!#[i,j])
437 |
438 |   ||| The lower triangular matrix L of the LUP decomposition.
439 |   export %inline
440 |   (.lower) : Num a => DecompLUP {m,n,a} mat -> Matrix m (minimum m n) a
441 |   (.lower) = lower
442 |
443 |   ||| The lower triangular matrix L of the LUP decomposition.
444 |   |||
445 |   ||| This accessor is intended to be used for square matrix decompositions.
446 |   export
447 |   lower' : Num a => {0 mat : Matrix' n a} -> DecompLUP mat -> Matrix' n a
448 |   lower' lu = rewrite cong (\i => Matrix n i a) $ sym (minimumIdempotent n)
449 |               in lower lu
450 |
451 |   ||| The lower triangular matrix L of the LUP decomposition.
452 |   |||
453 |   ||| This accessor is intended to be used for square matrix decompositions.
454 |   export %inline
455 |   (.lower') : Num a => {0 mat : Matrix' n a} -> DecompLUP mat -> Matrix' n a
456 |   (.lower') = lower'
457 |
458 |   ||| The upper triangular matrix U of the LUP decomposition.
459 |   export
460 |   upper : Num a => DecompLUP {m,n,a} mat -> Matrix (minimum m n) n a
461 |   upper {m,n} (MkLUP lu _ _) with (viewShape lu)
462 |     _ | Shape [m,n] = fromFunctionNB {rep=_} @{getRepC lu} _ (\[i,j] =>
463 |       if i <= j then lu!#[i,j] else 0)
464 |
465 |   ||| The upper triangular matrix U of the LUP decomposition.
466 |   export %inline
467 |   (.upper) : Num a => DecompLUP {m,n,a} mat -> Matrix (minimum m n) n a
468 |   (.upper) = upper
469 |
470 |   ||| The upper triangular matrix U of the LUP decomposition.
471 |   |||
472 |   ||| This accessor is intended to be used for square matrix decompositions.
473 |   export
474 |   upper' : Num a => {0 mat : Matrix' n a} -> DecompLUP mat -> Matrix' n a
475 |   upper' lu = rewrite cong (\i => Matrix i n a) $ sym (minimumIdempotent n)
476 |               in upper lu
477 |
478 |   ||| The upper triangular matrix U of the LUP decomposition.
479 |   |||
480 |   ||| This accessor is intended to be used for square matrix decompositions.
481 |   export %inline
482 |   (.upper') : Num a => {0 mat : Matrix' n a} -> DecompLUP mat -> Matrix' n a
483 |   (.upper') = upper'
484 |
485 |   ||| The row permutation of the LUP decomposition.
486 |   export
487 |   permute : DecompLUP {m} mat -> Permutation m
488 |   permute (MkLUP lu p sw) = p
489 |
490 |   ||| The row permutation of the LUP decomposition.
491 |   export %inline
492 |   (.permute) : DecompLUP {m} mat -> Permutation m
493 |   (.permute) = permute
494 |
495 |   ||| The number of swaps in the permutation of the LUP decomposition.
496 |   |||
497 |   ||| This is stored along with the permutation in order to increase the
498 |   ||| efficiency of certain algorithms.
499 |   export
500 |   numSwaps : DecompLUP mat -> Nat
501 |   numSwaps (MkLUP lu p sw) = sw
502 |
503 |
504 | ||| Convert an LU decomposition into an LUP decomposition.
505 | export
506 | fromLU : DecompLU mat -> DecompLUP mat
507 | fromLU (MkLU lu) = MkLUP lu identity 0
508 |
509 | ||| Calculate the LUP decomposition of a matrix.
510 | export
511 | decompLUP : FieldCmp a => (mat : Matrix m n a) -> DecompLUP mat
512 | decompLUP {m,n} mat with (viewShape mat)
513 |   decompLUP {m=0,n} mat | Shape [0,n] = MkLUP mat identity 0
514 |   decompLUP {m=S m,n=0} mat | Shape [S m,0] = MkLUP mat identity 0
515 |   decompLUP {m=S m,n=S n} mat | Shape [S m,S n] =
516 |     iterateN (S $ minimum m n) gaussStepSwap (MkLUP mat identity 0)
517 |   where
518 |     maxIndex : (s,a) -> List (s,a) -> (s,a)
519 |     maxIndex x [] = x
520 |     maxIndex _ [x] = x
521 |     maxIndex x ((a,b)::(c,d)::xs) =
522 |       if abslt b d then maxIndex x ((c,d)::xs)
523 |                    else assert_total $ maxIndex x ((a,b)::xs)
524 |
525 |     gaussStepSwap : Fin (S $ minimum m n) -> DecompLUP mat -> DecompLUP mat
526 |     gaussStepSwap i (MkLUP lu p sw) =
527 |       let ir = minWeakenLeft {n=S n} i
528 |           ic = minWeakenRight {m=S m} i
529 |           maxi = head $ fst (maxIndex ([0],0) $ drop (cast i) $ enumerate $
530 |                            indexSetRange [EndBound (weaken ir)] 0 $ getColumn ic lu)
531 |       in  if maxi == ir then MkLUP (gaussStep i lu) p sw
532 |           else MkLUP (gaussStep i $ swapRows ir maxi lu) (appendSwap maxi ir p) (S sw)
533 |
534 |
535 | --------------------------------------------------------------------------------
536 | -- Determinant
537 | --------------------------------------------------------------------------------
538 |
539 |
540 | ||| Calculate the determinant of a matrix given its LUP decomposition.
541 | export
542 | detWithLUP : Num a => (mat : Matrix' n a) -> DecompLUP mat -> a
543 | detWithLUP mat lup =
544 |   (if numSwaps lup `mod` 2 == 0 then 1 else -1)
545 |     * product (diagonal lup.lu)
546 |
547 | ||| Calculate the determinant of a matrix.
548 | export
549 | det : FieldCmp a => Matrix' n a -> a
550 | det {n} mat with (viewShape mat)
551 |   det {n=0} mat | Shape [0,0] = 1
552 |   det {n=1} mat | Shape [1,1] = mat!![0,0]
553 |   det {n=2} mat | Shape [2,2] = let [a,b,c,d] = elements mat in a*d - b*c
554 |   _ | Shape [n,n] = detWithLUP mat (decompLUP mat)
555 |
556 |
557 | --------------------------------------------------------------------------------
558 | -- Solving matrix equations
559 | --------------------------------------------------------------------------------
560 |
561 |
562 | solveLowerTri' : Field a => Matrix' n a -> Vector n a -> Vector n a
563 | solveLowerTri' {n} mat b with (viewShape b)
564 |   _ | Shape [n] = vector $ reverse $ construct $ reverse $ toVect b
565 |   where
566 |     construct : {i : _} -> Vect i a -> Vect i a
567 |     construct [] = []
568 |     construct {i=S i} (b :: bs) =
569 |       let xs = construct bs
570 |       in (b - sum (zipWith (*) xs (reverse $ toVect $ replace {p = flip Array a} (believe_me $ Refl {x=()}) $
571 |                   mat !#.. [One i, EndBound i]))) / mat!#[i,i] :: xs
572 |
573 |
574 | solveUpperTri' : Field a => Matrix' n a -> Vector n a -> Vector n a
575 | solveUpperTri' {n} mat b with (viewShape b)
576 |   _ | Shape [n] = vector $ construct Z $ toVect b
577 |   where
578 |     construct : Nat -> Vect i a -> Vect i a
579 |     construct _ [] = []
580 |     construct i (b :: bs) =
581 |       let xs = construct (S i) bs
582 |       in (b - sum (zipWith (*) xs (toVect $ replace {p = flip Array a} (believe_me $ Refl {x=()}) $
583 |                   mat !#.. [One i, StartBound (S i)]))) / mat!#[i,i] :: xs
584 |
585 |
586 | ||| Solve a linear equation, assuming the matrix is lower triangular.
587 | ||| Any entries other than those below the diagonal are ignored.
588 | export
589 | solveLowerTri : Field a => Matrix' n a -> Vector n a -> Maybe (Vector n a)
590 | solveLowerTri mat b = if all (/=0) (diagonal mat)
591 |                       then Just $ solveLowerTri' mat b
592 |                       else Nothing
593 |
594 | ||| Solve a linear equation, assuming the matrix is upper triangular.
595 | ||| Any entries other than those above the diagonal are ignored.
596 | export
597 | solveUpperTri : Field a => Matrix' n a -> Vector n a -> Maybe (Vector n a)
598 | solveUpperTri mat b = if all (/=0) (diagonal mat)
599 |                       then Just $ solveUpperTri' mat b
600 |                       else Nothing
601 |
602 |
603 | solveWithLUP' : Field a => (mat : Matrix' n a) -> DecompLUP mat ->
604 |                 Vector n a -> Vector n a
605 | solveWithLUP' mat lup b =
606 |   let b' = permuteCoords (inverse lup.permute) b
607 |   in solveUpperTri' lup.upper' $ solveLowerTri' lup.lower' b'
608 |
609 | ||| Solve a linear equation, given a matrix and its LUP decomposition.
610 | export
611 | solveWithLUP : Field a => (mat : Matrix' n a) -> DecompLUP mat ->
612 |                 Vector n a -> Maybe (Vector n a)
613 | solveWithLUP mat lup b =
614 |   let b' = permuteCoords (inverse lup.permute) b
615 |   in solveUpperTri lup.upper' $ solveLowerTri' lup.lower' b'
616 |
617 | ||| Solve a linear equation given a matrix.
618 | export
619 | solve : FieldCmp a => Matrix' n a -> Vector n a -> Maybe (Vector n a)
620 | solve mat = solveWithLUP mat (decompLUP mat)
621 |
622 |
623 | --------------------------------------------------------------------------------
624 | -- Matrix inversion
625 | --------------------------------------------------------------------------------
626 |
627 |
628 | ||| Determine whether a matrix has an inverse.
629 | export
630 | invertible : FieldCmp a => Matrix' n a -> Bool
631 | invertible {n} mat with (viewShape mat)
632 |   _ | Shape [n,n] = let lup = decompLUP mat in all (/=0) (diagonal lup.lu)
633 |
634 | ||| Try to invert a square matrix, returning `Nothing` if an inverse
635 | ||| does not exist.
636 | export
637 | tryInverse : FieldCmp a => Matrix' n a -> Maybe (Matrix' n a)
638 | tryInverse {n} mat with (viewShape mat)
639 |   _ | Shape [n,n] =
640 |     let lup = decompLUP mat
641 |     in map hstack $ traverse (solveWithLUP mat lup) $ map basis range
642 |
643 |
644 | export
645 | {n : _} -> FieldCmp a => MultGroup (Matrix' n a) where
646 |   inverse mat = let lup = decompLUP mat in
647 |         hstack $ map (solveWithLUP' mat lup . basis) range
648 |