0 | module Data.NumIdr.Matrix
4 | import Data.Permutation
5 | import Data.NumIdr.Interfaces
6 | import public Data.NumIdr.Array
7 | import Data.NumIdr.PrimArray
8 | import Data.NumIdr.Vector
||| A matrix is a rank-2 array. In `Matrix rows cols a` the first axis has
||| length `rows` and the second axis has length `cols`.
Matrix : Nat -> Nat -> Type -> Type
Matrix rows cols = Array [rows, cols]
||| A square matrix: `Matrix' dim a` is an array of shape `[dim, dim]`.
Matrix' : Nat -> Type -> Type
Matrix' dim = Array [dim, dim]
||| Build a matrix from a vector of `m` rows, each holding `n` elements.
|||
||| @ rep The array representation to use (`B` unless given explicitly)
matrix : {default B rep : Rep} -> RepConstraint rep a => {m, n : _} ->
         Vect m (Vect n a) -> Matrix m n a
matrix rows = array {rep, s = [m, n]} rows
||| Build an `m` by `n` matrix holding `diag` wherever the row index equals
||| the column index, and `other` everywhere else.
|||
||| @ diag  The value placed on the diagonal
||| @ other The value placed off the diagonal
repeatDiag : {default B rep : Rep} -> RepConstraint rep a => {m, n : _} ->
             (diag, other : a) -> Matrix m n a
repeatDiag d o = fromFunctionNB {rep} [m, n] $ \[i, j] =>
  if i /= j then o else d
||| Build an `m` by `n` matrix from its diagonal entries. A position whose
||| row and column indices coincide takes the corresponding element of
||| `diag`; every other position holds `other`.
|||
||| @ diag  The diagonal entries, one for each index below `minimum m n`
||| @ other The value used off the diagonal
53 | fromDiag : {default B rep : Rep} -> RepConstraint rep a => {m, n : _} ->
54 | (diag : Vect (minimum m n) a) -> (other : a) -> Matrix m n a
55 | fromDiag ds o = fromFunction {rep} [m,n] (\[i,j] => maybe o (`index` ds) $
i `eq` j)
-- `eq i j` compares a row index with a column index. A mismatch yields
-- `Nothing`; a match yields `Just` the shared index, typed as an index into
-- the diagonal vector.
-- NOTE(review): this listing shows neither the `where` line that scopes
-- `eq` nor an `eq FZ FZ` clause (presumably `Just FZ`). Without that clause
-- the visible clauses are not covering, so confirm it exists.
57 | eq : {0 m,n : Nat} -> Fin m -> Fin n -> Maybe (Fin (minimum m n))
59 | eq (FS x) (FS y) = map FS (eq x y)
60 | eq FZ (FS _) = Nothing
61 | eq (FS _) FZ = Nothing
||| Build the `n` by `n` permutation matrix for `p`: the identity matrix
||| (`1` on the diagonal, `0` elsewhere) with its rows (axis 0) permuted by `p`.
permuteM : {default B rep : Rep} -> RepConstraint rep a => {n : _} -> Num a =>
           Permutation n -> Matrix' n a
permuteM p = permuteInAxis 0 p $ repeatDiag {rep} 1 0
||| Build the square scaling matrix for factor `x`: `x` on the diagonal and
||| `0` everywhere else.
-- NOTE(review): the continuation of the type signature is not visible in
-- this listing; from the body it is presumably `a -> Matrix' n a`.
73 | scale : {default B rep : Rep} -> RepConstraint rep a => {n : _} -> Num a =>
75 | scale x = repeatDiag {rep} x 0
||| Build the 2D rotation matrix `[[cos a, -sin a], [sin a, cos a]]`.
||| Applied to a column vector with `*.`, it rotates the vector
||| counter-clockwise by `a` radians.
rotate2D : {default B rep : Rep} -> RepConstraint rep Double =>
           Double -> Matrix' 2 Double
rotate2D a =
  let c = cos a
      s = sin a
  in matrix {rep} [[c, negate s], [s, c]]
||| Build the 3D rotation matrix for angle `a` (radians) about the x axis.
rotate3DX : {default B rep : Rep} -> RepConstraint rep Double =>
            Double -> Matrix' 3 Double
rotate3DX a =
  let c = cos a
      s = sin a
  in matrix {rep} [[1, 0, 0], [0, c, negate s], [0, s, c]]

||| Build the 3D rotation matrix for angle `a` (radians) about the y axis.
rotate3DY : {default B rep : Rep} -> RepConstraint rep Double =>
            Double -> Matrix' 3 Double
rotate3DY a =
  let c = cos a
      s = sin a
  in matrix {rep} [[c, 0, s], [0, 1, 0], [negate s, 0, c]]

||| Build the 3D rotation matrix for angle `a` (radians) about the z axis.
rotate3DZ : {default B rep : Rep} -> RepConstraint rep Double =>
            Double -> Matrix' 3 Double
rotate3DZ a =
  let c = cos a
      s = sin a
  in matrix {rep} [[c, negate s, 0], [s, c, 0], [0, 0, 1]]
||| Build the `n` by `n` reflection that negates coordinate `i`: the identity
||| matrix with the entry at `[i, i]` set to `-1`.
reflect : {default B rep : Rep} -> RepConstraint rep a =>
          {n : _} -> Neg a => Fin n -> Matrix' n a
reflect i = indexSet [i, i] (-1) $ repeatDiag {rep} 1 0
||| Reflection negating the first coordinate (index `0`).
reflectX : {default B rep : Rep} -> RepConstraint rep a =>
           {n : _} -> Neg a => Matrix' (1 + n) a
reflectX = reflect {rep} FZ

||| Reflection negating the second coordinate (index `1`).
reflectY : {default B rep : Rep} -> RepConstraint rep a =>
           {n : _} -> Neg a => Matrix' (2 + n) a
reflectY = reflect {rep} (FS FZ)

||| Reflection negating the third coordinate (index `2`).
reflectZ : {default B rep : Rep} -> RepConstraint rep a =>
           {n : _} -> Neg a => Matrix' (3 + n) a
reflectZ = reflect {rep} (FS (FS FZ))
||| Return the element at row `r` and column `c` of a matrix.
index : Fin m -> Fin n -> Matrix m n a -> a
index r c = index [r, c]

||| Bounds-unchecked variant of `index` taking `Nat` coordinates. It returns
||| whatever the array-level `indexNB` returns (presumably `Nothing` when a
||| coordinate is out of range).
indexNB : Nat -> Nat -> Matrix m n a -> Maybe a
indexNB r c = indexNB [r, c]
||| Extract row `row` (an index along axis 0) as a vector of length `n`.
getRow : Fin m -> Matrix m n a -> Vector n a
-- The slice's length is expressed via `rangeLen`; `rangeLenZ n` rewrites it
-- to `n` so the result type lines up.
getRow row arr = rewrite sym (rangeLenZ n) in arr !!.. [One row, All]

||| Extract column `col` (an index along axis 1) as a vector of length `m`.
getColumn : Fin n -> Matrix m n a -> Vector m a
getColumn col arr = rewrite sym (rangeLenZ m) in arr !!.. [All, One col]
||| Extract the diagonal of a (possibly non-square) matrix: the entries at
||| `[k, k]` for every `k` below `minimum m n`. The result is built with the
||| input's representation constraint (`getRepC`).
diagonal' : Matrix m n a -> Vector (minimum m n) a
diagonal' {m,n} arr with (viewShape arr)
  _ | Shape [m,n] =
      fromFunctionNB {rep=_} @{getRepC arr} _ $ \[k] => arr!#[k,k]

||| Extract the diagonal of a square `n` by `n` matrix.
diagonal : Matrix' n a -> Vector n a
diagonal {n} arr with (viewShape arr)
  _ | Shape [n,n] =
      fromFunctionNB {rep=_} @{getRepC arr} [n] $ \[k] => arr!#[k,k]
||| Remove row `i` and column `j` from a matrix.
minor : Fin (S m) -> Fin (S n) -> Matrix (S m) (S n) a -> Matrix m n a
minor i j mat =
  -- Filtering out a single index leaves `m` rows and `n` columns, but the
  -- type checker cannot see this, so the shape equality is asserted.
  let sub = mat!!..[Filter (/= i), Filter (/= j)]
  in replace {p = flip Array a} (believe_me $ Refl {x = ()}) sub
||| Keep every entry whose coordinates `[i, j]` satisfy the predicate and
||| replace every other entry with `0`.
filterInd : Num a => (Nat -> Nat -> Bool) -> Matrix m n a -> Matrix m n a
filterInd {m,n} keep mat with (viewShape mat)
  _ | Shape [m,n] =
      fromFunctionNB {rep=_} @{getRepC mat} [m,n] $ \[i,j] =>
        if not (keep i j) then 0 else mat!#[i,j]
||| Keep the entries on and above the diagonal, zeroing the rest.
upperTriangle : Num a => Matrix m n a -> Matrix m n a
upperTriangle = filterInd (\i, j => i <= j)

||| Keep the entries on and below the diagonal, zeroing the rest.
lowerTriangle : Num a => Matrix m n a -> Matrix m n a
lowerTriangle = filterInd (\i, j => i >= j)

||| Keep the entries strictly above the diagonal, zeroing the rest.
upperTriangleStrict : Num a => Matrix m n a -> Matrix m n a
upperTriangleStrict = filterInd (\i, j => i < j)

||| Keep the entries strictly below the diagonal, zeroing the rest.
lowerTriangleStrict : Num a => Matrix m n a -> Matrix m n a
lowerTriangleStrict = filterInd (\i, j => i > j)
-- NOTE(review): only the type signatures of these four functions are
-- visible in this listing; their bodies are not shown.
||| Concatenate two matrices with the same number of columns into one with
||| `m + m'` rows (presumably the first matrix's rows come first).
202 | vconcat : Matrix m n a -> Matrix m' n a -> Matrix (m + m') n a
||| Concatenate two matrices with the same number of rows into one with
||| `n + n'` columns (presumably the first matrix's columns come first).
207 | hconcat : Matrix m n a -> Matrix m n' a -> Matrix m (n + n') a
||| Assemble `m` vectors of length `n` into an `m` by `n` matrix
||| (presumably each vector becomes a row).
212 | vstack : {n : _} -> Vect m (Vector n a) -> Matrix m n a
||| Assemble `n` vectors of length `m` into an `m` by `n` matrix. Matrix
||| inversion in this file builds its result by `hstack`-ing the solutions
||| for each basis vector, which suggests each vector becomes a column.
217 | hstack : {m : _} -> Vect n (Vector m a) -> Matrix m n a
||| Swap rows `i` and `j` (indices along axis 0).
swapRows : (i,j : Fin m) -> Matrix m n a -> Matrix m n a
swapRows i j mat = swapInAxis 0 i j mat

||| Swap columns `i` and `j` (indices along axis 1).
swapColumns : (i,j : Fin n) -> Matrix m n a -> Matrix m n a
swapColumns i j mat = swapInAxis 1 i j mat

||| Reorder the rows of a matrix by the given permutation.
permuteRows : Permutation m -> Matrix m n a -> Matrix m n a
permuteRows perm mat = permuteInAxis 0 perm mat

||| Reorder the columns of a matrix by the given permutation.
permuteColumns : Permutation n -> Matrix m n a -> Matrix m n a
permuteColumns perm mat = permuteInAxis 1 perm mat
||| Outer product of two vectors: the `m` by `n` matrix whose entry at
||| `[i, j]` is `u!!i * v!!j`.
|||
||| The result is built with the representation obtained by merging the two
||| operands' representations (`mergeRepConstraint`), as the matrix `Mult`
||| instances do, instead of always falling back to the default `B`.
outer : Num a => Vector m a -> Vector n a -> Matrix m n a
-- The operands are named `u` and `v` so they do not shadow the element
-- type variable `a` from the signature.
outer {m,n} u v with (viewShape u, viewShape v)
  _ | (Shape [m], Shape [n]) =
      fromFunction {rep=_} @{mergeRepConstraint (getRepC u) (getRepC v)} [m,n]
        (\[i,j] => u!!i * v!!j)
||| Sum of the diagonal entries. Non-square matrices are accepted; their
||| first `minimum m n` diagonal entries are summed.
trace : Num a => Matrix m n a -> a
trace mat = sum (diagonal' mat)
||| Build the reflection across the hyperplane with normal vector `v`
||| (a Householder matrix): `I - (2 / normSq v) * outer v v`, where `normSq`
||| is presumably the squared Euclidean norm.
reflectNormal : (Neg a, Fractional a) => Vector n a -> Matrix' n a
reflectNormal {n} normal with (viewShape normal)
  _ | Shape [n] =
      let factor = 2 / normSq normal
      in repeatDiag 1 0 - factor *. outer normal normal
-- Matrix-vector product: entry `i` of the result is the sum over `k` of
-- `lhs[i, k] * vec[k]`. The result's representation merges those of the
-- two operands.
Num a => Mult (Matrix m n a) (Vector n a) (Vector m a) where
  (*.) {m,n} lhs vec with (viewShape lhs)
    _ | Shape [m,n] =
        fromFunction {rep=_}
          @{mergeRepConstraint (getRepC lhs) (getRepC vec)} [m]
          (\[i] => sum (map (\k => lhs!![i,k] * vec!!k) range))
-- Matrix-matrix product: entry `[i, j]` of the result is the sum over `k`
-- of `lhs[i, k] * rhs[k, j]`. The result's representation merges those of
-- the two operands.
Num a => Mult (Matrix m n a) (Matrix n p a) (Matrix m p a) where
  (*.) {m,n,p} lhs rhs with (viewShape lhs, viewShape rhs)
    _ | (Shape [m,n], Shape [n,p]) =
        fromFunction {rep=_}
          @{mergeRepConstraint (getRepC lhs) (getRepC rhs)} [m,p]
          (\[i,j] => sum (map (\k => lhs!![i,k] * rhs!![k,j]) range))
-- Square matrices form a multiplicative monoid whose identity is the
-- identity matrix (ones on the diagonal, zeros elsewhere), built with the
-- default representation `B`.
{n : _} -> Num a => MultMonoid (Matrix' n a) where
  identity = repeatDiag {rep = B} 1 0
||| The LU decomposition of `mat`.
|||
||| The accessors in this file match on a single-field constructor
||| `MkLU lu`, where `lu` is a matrix of the same shape as `mat`. The factors
||| are stored packed in it: `upper` reads the entries on and above the
||| diagonal, while `gaussStep` stores its elimination multipliers below the
||| pivot.
-- NOTE(review): the constructor and field declarations are not visible in
-- this listing.
297 | record DecompLU {0 m,n,a : _} (mat : Matrix m n a) where
||| The lower triangular factor L of an LU decomposition, of shape `m` by
||| `minimum m n`. It is built entry by entry from the packed `lu` matrix,
||| with the same representation constraint as `lu`.
-- NOTE(review): the alternatives of `case compare i j of` are not visible
-- in this listing. Presumably the `GT` case reads `lu!#[i,j]`, `EQ` gives
-- `1` and `LT` gives `0`; confirm.
307 | lower : Num a => DecompLU {m,n,a} mat -> Matrix m (minimum m n) a
308 | lower {m,n} (MkLU lu) with (viewShape lu)
309 | _ | Shape [m,n] = fromFunctionNB {rep=_} @{getRepC lu} _ (\[i,j] =>
310 | case compare i j of
||| Postfix form of `lower` (its body is not visible in this listing).
317 | (.lower) : Num a => DecompLU {m,n,a} mat -> Matrix m (minimum m n) a
||| The lower factor of a square matrix's LU decomposition, typed as an
||| `n` by `n` matrix. The proof `minimumIdempotent n` is used to rewrite
||| `minimum n n` to `n` in the column count.
-- NOTE(review): the `in ...` continuation of the rewrite is not visible in
-- this listing; presumably it is `lower lu`.
324 | lower' : Num a => {0 mat : Matrix' n a} -> DecompLU mat -> Matrix' n a
325 | lower' lu = rewrite cong (\i => Matrix n i a) $
sym (minimumIdempotent n)
||| Postfix form of `lower'` (its body is not visible in this listing).
332 | (.lower') : Num a => {0 mat : Matrix' n a} -> DecompLU mat -> Matrix' n a
||| The upper triangular factor U of an LU decomposition, of shape
||| `minimum m n` by `n`. Entries on or above the diagonal are read from the
||| packed `lu` matrix; entries below it are `0`.
upper : Num a => DecompLU {m,n,a} mat -> Matrix (minimum m n) n a
upper {m,n} (MkLU packed) with (viewShape packed)
  _ | Shape [m,n] =
      fromFunctionNB {rep=_} @{getRepC packed} _ $ \[i,j] =>
        if j < i then 0 else packed!#[i,j]
||| Postfix form of `upper` (its body is not visible in this listing).
344 | (.upper) : Num a => DecompLU {m,n,a} mat -> Matrix (minimum m n) n a
||| The upper factor of a square matrix's LU decomposition, typed as an
||| `n` by `n` matrix. The proof `minimumIdempotent n` is used to rewrite
||| `minimum n n` to `n` in the row count.
-- NOTE(review): the `in ...` continuation of the rewrite is not visible in
-- this listing; presumably it is `upper lu`.
351 | upper' : Num a => {0 mat : Matrix' n a} -> DecompLU mat -> Matrix' n a
352 | upper' lu = rewrite cong (\i => Matrix i n a) $
sym (minimumIdempotent n)
||| Postfix form of `upper'` (its body is not visible in this listing).
359 | (.upper') : Num a => {0 mat : Matrix' n a} -> DecompLU mat -> Matrix' n a
||| Reinterpret an index below `minimum m n` as an index below `m`, using a
||| proof that `minimum m n <= m`.
-- NOTE(review): the `where` line that scopes the local `minLTE` is not
-- visible in this listing.
363 | minWeakenLeft : {m,n : _} -> Fin (minimum m n) -> Fin m
364 | minWeakenLeft x = weakenLTE x $
minLTE m n
-- Proof by induction on both arguments that `minimum m n <= m`.
366 | minLTE : (m,n : _) -> minimum m n `LTE` m
367 | minLTE Z n = LTEZero
368 | minLTE (S m) Z = LTEZero
369 | minLTE (S m) (S n) = LTESucc (minLTE m n)
||| Reinterpret an index below `minimum m n` as an index below `n`, using a
||| proof that `minimum m n <= n`.
-- NOTE(review): as above, the `where` line for this local `minLTE` is not
-- visible in this listing.
371 | minWeakenRight : {m,n : _} -> Fin (minimum m n) -> Fin n
372 | minWeakenRight x = weakenLTE x $
minLTE m n
-- Proof by induction on both arguments that `minimum m n <= n`.
374 | minLTE : (m,n : _) -> minimum m n `LTE` n
375 | minLTE Z n = LTEZero
376 | minLTE (S m) Z = LTEZero
377 | minLTE (S m) (S n) = LTESucc (minLTE m n)
||| Thread `x` through `f 0`, `f 1`, ..., `f (n-1)`, in that order.
-- NOTE(review): the clause for `n = 0` is not visible in this listing;
-- presumably it returns `x` unchanged.
380 | iterateN : (n : Nat) -> (Fin n -> a -> a) -> a -> a
382 | iterateN 1 f x = f FZ x
-- Apply `f FZ` first, then recurse with the remaining indices shifted by
-- `FS`.
383 | iterateN (S n@(S _)) f x = iterateN n (f . FS) $
f FZ x
||| One step of Gaussian elimination on a packed LU matrix, at pivot row
||| `minWeakenLeft i` and pivot column `minWeakenRight i`.
|||
||| The entries below the pivot are divided by `diag` and stored in place as
||| multipliers. The outer product of those multipliers with the pivot row's
||| entries to the right of the pivot is then subtracted from the trailing
||| submatrix. If the pivot column is entirely zero, the matrix is returned
||| unchanged.
-- NOTE(review): the equation head (presumably `gaussStep i lu =`) and the
-- binding of `diag` (presumably the pivot entry `lu!![ir, ic]`) are not
-- visible in this listing.
-- NOTE(review): the early exit tests the *whole* pivot column, including
-- rows above the pivot. If the pivot entry is zero while some other entry
-- in the column is not, the step still divides by `diag`.
--   * `decompLU` avoids this because `gaussStepMaybe` rejects zero pivots.
--   * `decompLUP` likely reaches it for singular inputs such as
--     [[1,1,1],[1,1,2],[1,1,3]], producing NaN entries (and a NaN `det`)
--     with `Double`.
-- Testing only the pivot entry, or only the rows from the pivot down,
-- would avoid this.
387 | gaussStep : {m,n : _} -> Field a =>
388 | Fin (minimum m n) -> Matrix m n a -> Matrix m n a
390 | if all (==0) $
getColumn (minWeakenRight i) lu then lu else
391 | let ir = minWeakenLeft i
392 | ic = minWeakenRight i
394 | coeffs = map (/diag) $
lu!!..[StartBound (FS ir), One ic]
395 | lu' = indexSetRange [StartBound (FS ir), One ic] coeffs lu
396 | pivot = lu!!..[One ir, StartBound (FS ic)]
397 | offsets = outer coeffs pivot
398 | in indexUpdateRange [StartBound (FS ir), StartBound (FS ic)]
399 | (flip (-) offsets) lu'
||| LU decomposition without pivoting.
|||
||| Runs `gaussStep` for each index below `minimum m n`. It returns `Nothing`
||| as soon as a step would use a zero diagonal entry as its pivot.
-- NOTE(review): the `where` line that scopes `gaussStepMaybe` is not
-- visible in this listing.
404 | decompLU : Field a => (mat : Matrix m n a) -> Maybe (DecompLU mat)
405 | decompLU {m,n} mat with (viewShape mat)
406 | _ | Shape [m,n] = map MkLU
407 | $
iterateN (minimum m n) (\i => (>>= gaussStepMaybe i)) (Just mat)
-- Fail with `Nothing` when the diagonal entry at step `i` is zero;
-- otherwise perform the elimination step.
409 | gaussStepMaybe : Fin (minimum m n) -> Matrix m n a -> Maybe (Matrix m n a)
410 | gaussStepMaybe i mat = if mat!#[cast i,cast i] == 0 then Nothing
411 | else Just $
gaussStep i mat
||| The LUP decomposition of `mat` (LU decomposition with row pivoting).
|||
||| The functions in this file match on a constructor
||| `MkLUP lu p sw` with three fields:
|||   * `lu`: the packed LU matrix;
|||   * `p`: the row permutation (see `permute`);
|||   * `sw`: the number of row swaps performed (see `numSwaps`).
-- NOTE(review): the constructor and field declarations are not visible in
-- this listing.
421 | record DecompLUP {0 m,n,a : _} (mat : Matrix m n a) where
-- Accessors for `DecompLUP`. They reuse the names of the `DecompLU`
-- accessors and are kept apart in their own namespace.
427 | namespace DecompLUP
||| The lower triangular factor L, of shape `m` by `minimum m n`, built from
||| the packed `lu` matrix with the same representation constraint.
-- NOTE(review): the alternatives of `case compare i j of` are not visible
-- in this listing; presumably they mirror the `DecompLU` version.
430 | lower : Num a => DecompLUP {m,n,a} mat -> Matrix m (minimum m n) a
431 | lower {m,n} (MkLUP lu _ _) with (viewShape lu)
432 | _ | Shape [m,n] = fromFunctionNB {rep=_} @{getRepC lu} _ (\[i,j] =>
433 | case compare i j of
||| Postfix form of `lower` (its body is not visible in this listing).
440 | (.lower) : Num a => DecompLUP {m,n,a} mat -> Matrix m (minimum m n) a
||| The lower factor for a square matrix, typed as `n` by `n` by rewriting
||| `minimum n n` to `n` with `minimumIdempotent n`.
-- NOTE(review): the `in ...` continuation is not visible in this listing.
447 | lower' : Num a => {0 mat : Matrix' n a} -> DecompLUP mat -> Matrix' n a
448 | lower' lu = rewrite cong (\i => Matrix n i a) $
sym (minimumIdempotent n)
||| Postfix form of `lower'` (its body is not visible in this listing).
455 | (.lower') : Num a => {0 mat : Matrix' n a} -> DecompLUP mat -> Matrix' n a
||| The upper triangular factor U of an LUP decomposition, of shape
||| `minimum m n` by `n`. Entries on or above the diagonal are read from the
||| packed matrix; entries below it are `0`.
upper : Num a => DecompLUP {m,n,a} mat -> Matrix (minimum m n) n a
upper {m,n} (MkLUP packed _ _) with (viewShape packed)
  _ | Shape [m,n] =
      fromFunctionNB {rep=_} @{getRepC packed} _ $ \[i,j] =>
        if j < i then 0 else packed!#[i,j]
||| Postfix form of `upper` (its body is not visible in this listing).
467 | (.upper) : Num a => DecompLUP {m,n,a} mat -> Matrix (minimum m n) n a
||| The upper factor for a square matrix, typed as `n` by `n` by rewriting
||| `minimum n n` to `n` with `minimumIdempotent n`.
-- NOTE(review): the `in ...` continuation is not visible in this listing.
474 | upper' : Num a => {0 mat : Matrix' n a} -> DecompLUP mat -> Matrix' n a
475 | upper' lu = rewrite cong (\i => Matrix i n a) $
sym (minimumIdempotent n)
||| Postfix form of `upper'` (its body is not visible in this listing).
482 | (.upper') : Num a => {0 mat : Matrix' n a} -> DecompLUP mat -> Matrix' n a
||| The row permutation recorded by an LUP decomposition.
permute : DecompLUP {m} mat -> Permutation m
permute (MkLUP _ perm _) = perm

||| Postfix form of `permute`.
(.permute) : DecompLUP {m} mat -> Permutation m
(.permute) = permute

||| The number of row swaps performed while computing the decomposition.
numSwaps : DecompLUP mat -> Nat
numSwaps (MkLUP _ _ swaps) = swaps

||| View an LU decomposition (computed without pivoting) as an LUP
||| decomposition with the identity permutation and no row swaps.
fromLU : DecompLU mat -> DecompLUP mat
fromLU (MkLU packed) = MkLUP packed identity Z
||| LUP decomposition with partial pivoting. Unlike `decompLU`, it always
||| succeeds.
|||
||| At each step `i`:
|||   1. Entries above the pivot row are zeroed and the rows before `i` are
|||      dropped, so only rows from the pivot down are considered.
|||   2. `maxIndex` picks the row whose pivot-column entry wins under
|||      `abslt` (presumably a comparison of absolute values).
|||   3. If that row is not the pivot row, it is swapped in, and the swap is
|||      recorded in the permutation and the swap counter.
|||   4. `gaussStep` is applied.
||| Matrices with zero rows or zero columns are returned as-is with the
||| identity permutation.
-- NOTE(review): the `where` line scoping the helpers and the first clauses
-- of `maxIndex` (empty and singleton lists) are not visible in this
-- listing. Also note that `maxIndex`'s pattern variable `a` shadows the
-- element type `a`.
511 | decompLUP : FieldCmp a => (mat : Matrix m n a) -> DecompLUP mat
512 | decompLUP {m,n} mat with (viewShape mat)
513 | decompLUP {m=0,n} mat | Shape [0,n] = MkLUP mat identity 0
514 | decompLUP {m=S m,n=0} mat | Shape [S m,0] = MkLUP mat identity 0
515 | decompLUP {m=S m,n=S n} mat | Shape [S m,S n] =
516 | iterateN (S $
minimum m n) gaussStepSwap (MkLUP mat identity 0)
-- Scan a list of (index, value) pairs and keep the pair whose value wins
-- under `abslt`; the first argument is presumably the fallback result.
518 | maxIndex : (s,a) -> List (s,a) -> (s,a)
521 | maxIndex x ((a,b)::(c,d)::xs) =
522 | if abslt b d then maxIndex x ((c,d)::xs)
523 | else assert_total $
maxIndex x ((a,b)::xs)
-- One pivoting elimination step: choose the pivot row, swap it in if
-- needed (recording the swap), then eliminate.
525 | gaussStepSwap : Fin (S $
minimum m n) -> DecompLUP mat -> DecompLUP mat
526 | gaussStepSwap i (MkLUP lu p sw) =
527 | let ir = minWeakenLeft {n=S n} i
528 | ic = minWeakenRight {m=S m} i
529 | maxi = head $
fst (maxIndex ([0],0) $
drop (cast i) $
enumerate $
530 | indexSetRange [EndBound (weaken ir)] 0 $
getColumn ic lu)
531 | in if maxi == ir then MkLUP (gaussStep i lu) p sw
532 | else MkLUP (gaussStep i $
swapRows ir maxi lu) (appendSwap maxi ir p) (S sw)
||| Determinant of `mat` from its LUP decomposition: the product of the
||| packed matrix's diagonal (U's diagonal), negated when an odd number of
||| row swaps was performed.
detWithLUP : Num a => (mat : Matrix' n a) -> DecompLUP mat -> a
detWithLUP mat lup =
  let diagProduct = product (diagonal lup.lu)
  in (if numSwaps lup `mod` 2 == 0 then 1 else -1) * diagProduct
||| Determinant of a square matrix. Sizes 0, 1 and 2 use closed forms; larger
||| matrices go through an LUP decomposition.
det : FieldCmp a => Matrix' n a -> a
det {n} mat with (viewShape mat)
  det {n=0} mat | Shape [0,0] = 1
  det {n=1} mat | Shape [1,1] = mat!![0,0]
  -- Entries named so they do not shadow the element type `a`.
  det {n=2} mat | Shape [2,2] =
    let [p, q, r, s] = elements mat in p * s - q * r
  _ | Shape [n,n] = detWithLUP mat (decompLUP mat)
||| Forward substitution: solve `mat * x = b` for `x`, treating `mat` as
||| lower triangular. Only entries on and to the left of the diagonal are
||| read.
||| No check is made for zero diagonal entries, which are divided by
||| directly; `solveLowerTri` is the checked variant.
-- NOTE(review): the `where` line scoping `construct` and its base case for
-- the empty vector are not visible in this listing.
562 | solveLowerTri' : Field a => Matrix' n a -> Vector n a -> Vector n a
563 | solveLowerTri' {n} mat b with (viewShape b)
-- `construct` works on the reversed right-hand side, so the last equation
-- is processed first and its recursive call yields the earlier unknowns
-- (in reversed order).
564 | _ | Shape [n] = vector $
reverse $
construct $
reverse $
toVect b
566 | construct : {i : _} -> Vect i a -> Vect i a
568 | construct {i=S i} (b :: bs) =
569 | let xs = construct bs
-- Row `i`'s entries left of the diagonal are reversed to line up with the
-- reversed unknowns in `xs`. The slice's length equality is asserted with
-- `believe_me`.
570 | in (b - sum (zipWith (*) xs (reverse $
toVect $
replace {p = flip Array a} (believe_me $
Refl {x=()}) $
571 | mat !#.. [One i, EndBound i]))) / mat!#[i,i] :: xs
||| Back substitution: solve `mat * x = b` for `x`, treating `mat` as upper
||| triangular. Only entries on and to the right of the diagonal are read.
||| No check is made for zero diagonal entries, which are divided by
||| directly; `solveUpperTri` is the checked variant.
-- NOTE(review): the `where` line scoping `construct` is not visible in
-- this listing.
574 | solveUpperTri' : Field a => Matrix' n a -> Vector n a -> Vector n a
575 | solveUpperTri' {n} mat b with (viewShape b)
576 | _ | Shape [n] = vector $
construct Z $
toVect b
-- `construct i` solves rows `i` onwards. The recursive call yields the
-- unknowns after `i`, which are dotted with row `i`'s entries right of the
-- diagonal.
578 | construct : Nat -> Vect i a -> Vect i a
579 | construct _ [] = []
580 | construct i (b :: bs) =
581 | let xs = construct (S i) bs
582 | in (b - sum (zipWith (*) xs (toVect $
replace {p = flip Array a} (believe_me $
Refl {x=()}) $
583 | mat !#.. [One i, StartBound (S i)]))) / mat!#[i,i] :: xs
||| Checked forward substitution: solve `mat * x = b` with `mat` lower
||| triangular, only when every diagonal entry of `mat` is non-zero.
-- NOTE(review): the `else` branch is not visible in this listing;
-- presumably it is `else Nothing`.
589 | solveLowerTri : Field a => Matrix' n a -> Vector n a -> Maybe (Vector n a)
590 | solveLowerTri mat b = if all (/=0) (diagonal mat)
591 | then Just $
solveLowerTri' mat b
||| Checked back substitution: solve `mat * x = b` with `mat` upper
||| triangular, only when every diagonal entry of `mat` is non-zero.
-- NOTE(review): the `else` branch is not visible in this listing;
-- presumably it is `else Nothing`.
597 | solveUpperTri : Field a => Matrix' n a -> Vector n a -> Maybe (Vector n a)
598 | solveUpperTri mat b = if all (/=0) (diagonal mat)
599 | then Just $
solveUpperTri' mat b
||| Solve `mat * x = b` for `x` using a precomputed LUP decomposition of
||| `mat`: permute `b`, forward-substitute through L, then back-substitute
||| through U. No check is made for zero pivots; `solveWithLUP` is the
||| checked variant.
solveWithLUP' : Field a => (mat : Matrix' n a) -> DecompLUP mat ->
                Vector n a -> Vector n a
solveWithLUP' mat lup b =
  let permuted = permuteCoords (inverse lup.permute) b
      y = solveLowerTri' lup.lower' permuted
  in solveUpperTri' lup.upper' y
||| Solve `mat * x = b` for `x` using a precomputed LUP decomposition of
||| `mat`. Returns `Nothing` when the back-substitution step
||| (`solveUpperTri`) rejects U because of a zero diagonal entry.
solveWithLUP : Field a => (mat : Matrix' n a) -> DecompLUP mat ->
               Vector n a -> Maybe (Vector n a)
solveWithLUP mat lup b =
  let permuted = permuteCoords (inverse lup.permute) b
      y = solveLowerTri' lup.lower' permuted
  in solveUpperTri lup.upper' y
||| Solve `mat * x = b` for `x` by first computing an LUP decomposition of
||| `mat`. Returns `Nothing` in the cases where `solveWithLUP` does.
solve : FieldCmp a => Matrix' n a -> Vector n a -> Maybe (Vector n a)
solve mat b = solveWithLUP mat (decompLUP mat) b
||| Test whether a square matrix is invertible: every diagonal entry of its
||| packed LUP matrix must be non-zero.
invertible : FieldCmp a => Matrix' n a -> Bool
invertible {n} mat with (viewShape mat)
  _ | Shape [n,n] =
      let dec = decompLUP mat
      in all (\x => x /= 0) $ diagonal dec.lu
||| Checked matrix inverse. It solves `mat * x = basis i` for every index `i`
||| with `solveWithLUP`, sharing one LUP decomposition, and assembles the
||| solutions with `hstack`. Returns `Nothing` if any solve fails.
-- NOTE(review): the with-clause line introducing the body (presumably
-- `_ | Shape [n,n] =`) is not visible in this listing.
637 | tryInverse : FieldCmp a => Matrix' n a -> Maybe (Matrix' n a)
638 | tryInverse {n} mat with (viewShape mat)
640 | let lup = decompLUP mat
641 | in map hstack $
traverse (solveWithLUP mat lup) $
map basis range
-- Inversion of square matrices. For every index `i`, `mat * x = basis i` is
-- solved with a single shared LUP decomposition, and the solutions are
-- assembled with `hstack`. No singularity check is made (zero pivots are
-- divided by); `tryInverse` is the checked variant.
{n : _} -> FieldCmp a => MultGroup (Matrix' n a) where
  inverse mat =
    let lup = decompLUP mat
        solveFor = solveWithLUP' mat lup
    in hstack (map (solveFor . basis) range)