0 | module Data.Matrix
  1 |
  2 | import Data.Vect
  3 | import Data.String
  4 | import Data.Zippable
  5 | import Data.Maybe
  6 |
  7 | export
  8 | data Matrix : Nat -> Nat -> Type -> Type where
  9 |   MkMatrix : Vect rows (Vect cols a) -> Matrix rows cols a
 10 |
 11 | public export
 12 | Eq a => Eq (Matrix rows cols a) where
 13 |   (MkMatrix lv) == (MkMatrix rv) = lv == rv
 14 |
 15 | public export
 16 | Functor (Matrix rows cols) where
 17 |   map f (MkMatrix ma) = MkMatrix $ map (map f) ma
 18 |
 19 | public export
 20 | Zippable (Matrix rows cols) where
 21 |   zipWith f (MkMatrix lv) (MkMatrix rv)  = MkMatrix $ zipWith (zipWith f) lv rv
 22 |   unzipWith f v  = let ma = (map f v) in (fst <$> ma,snd <$> ma)
 23 |
 24 |   zipWith3 f (MkMatrix a) (MkMatrix b)  (MkMatrix c) = MkMatrix $ zipWith3 (zipWith3 f) a b c
 25 |   unzipWith3 f v  = let ma = (map f v) in (fst <$> ma, fst . snd <$> ma, snd . snd <$> ma)
 26 |
 27 |
 28 | public export
 29 | fromVects : Vect rows (Vect cols a) -> Matrix rows cols a
 30 | fromVects = MkMatrix
 31 |
 32 |
 33 | public export
 34 | toVects : Matrix rows cols a -> Vect rows (Vect cols a) 
 35 | toVects (MkMatrix v) = v
 36 |
 37 | prettyShow : Show n => (matrix : Matrix rows cols n) -> String
 38 | prettyShow (MkMatrix matrix) = (joinBy "\n" $ map show $ toList matrix) ++ "\n"
 39 |
 40 | public export
 41 | Show a => Show (Matrix rows cols a) where
 42 |   show = prettyShow
 43 |
 44 | private 
 45 | dotProduct : Num a => Vect n a -> Vect n a -> a
 46 | dotProduct x y = sum $ zipWith (*) x y
 47 |
 48 |
 49 | public export
 50 | fill : (rows : Nat) -> (cols : Nat) -> t -> (Matrix rows cols t)
 51 | fill rows cols t0 = MkMatrix $ replicate rows $ replicate cols t0
 52 |
 53 | public export
 54 | zeros : (rows : Nat) -> (cols : Nat) -> (Matrix rows cols Int)
 55 | zeros rows cols = fill rows cols 0
 56 |
 57 | public export
 58 | ones : (rows : Nat) -> (cols : Nat) -> (Matrix rows cols Int)
 59 | ones rows cols = fill rows cols  1
 60 |
 61 | public export
 62 | updateAt : (row : Fin rows) -> (col : Fin cols) -> (updateFun: a -> a) -> (Matrix rows cols a) -> Matrix rows cols a
 63 | updateAt row col f (MkMatrix vects) = MkMatrix $ updateAt row (updateAt col f) vects
 64 |
 65 | public export
 66 | replaceAt : (row : Fin rows) -> (col : Fin cols) -> (ele) -> (Matrix rows cols ele) -> (Matrix rows cols ele)
 67 | replaceAt row col ele = updateAt row col $ const ele
 68 |
 69 | public export
 70 | findIndex : (ele -> Bool) ->  (Matrix rows cols ele) -> Maybe (Fin rows, Fin cols)
 71 | findIndex p (MkMatrix vects) = go vects
 72 |   where
 73 |     go : Vect a (Vect cols ele)  -> Maybe (Fin a, Fin cols)
 74 |     go [] = Nothing
 75 |     go (v :: vs) = case findIndex p v of
 76 |       Just fc => Just (FZ, fc)
 77 |       Nothing => case go vs of
 78 |                   Just (fa, fb) => Just (FS fa, fb)
 79 |                   Nothing => Nothing
 80 |
 81 | -- public export
 82 | mapVectWithIndex : {auto len: Nat} -> (ele -> Fin len -> b) -> Vect len ele -> Vect len b
 83 | mapVectWithIndex f as = zip range as <&> \(idx, e) => f e idx
 84 |
 85 | public export
 86 | mapWithIndex : {auto rows: Nat} -> {auto cols: Nat} -> (ele -> (Fin rows, Fin cols) -> b) ->  Matrix rows cols ele ->  Matrix rows cols b
 87 | mapWithIndex f (MkMatrix vects) = MkMatrix $ zip range vects <&> \(row, v) => zip range v <&> \(col, e) => f e (row, col)
 88 |
 89 |
 90 | filter' : (ele -> Bool) ->  (Matrix rows cols ele) -> List ele
 91 | filter' p (MkMatrix vects) = filter p lists
 92 |   where
 93 |     lists : List ele
 94 |     lists = concat . map toList $ vects
 95 |
 96 | public export
 97 | findIndices : {auto rows: Nat} -> {auto cols: Nat} ->  (ele -> Bool) ->  (Matrix rows cols ele) -> List (Fin rows, Fin cols)
 98 | findIndices p matrix =  mapMaybe id $ filter' isJust mapped
 99 |   where
100 |     mapped : Matrix rows cols (Maybe (Fin rows, Fin cols))
101 |     mapped = mapWithIndex {rows = rows} {cols = cols} (\el,(fr, fc) => if p el then (Just (fr, fc)) else Nothing) matrix
102 |
103 | -- a : Matrix rows cols String -> Matrix rows cols String
104 | public export
105 | eye : (rows : Nat) -> (cols : Nat) -> Matrix rows cols Int
106 | eye rows cols  = go (rangeFromTo 0 $ min rows cols) (zeros rows cols)
107 |   where
108 |     go : List Nat -> Matrix rows cols Int -> Matrix rows cols Int
109 |     go [] matrix = matrix
110 |     go (x :: xs) matrix = case (natToFin x rows, natToFin x cols) of
111 |                         (Just fx, Just fy) => replaceAt fx fy 1 $ go xs matrix 
112 |                         (_, _) => go xs matrix
113 |
114 | public export
115 | identity : (rows : Nat)  -> Matrix rows rows Int
116 | identity rows = eye rows rows
117 |
118 |
119 | public export
120 | repmat : Num ele => Vect n ele -> (rows : Nat) -> (cols : Nat) -> Matrix rows (cols * n) ele
121 | repmat vect rows cols = MkMatrix $ replicate rows $ concat $  Data.Vect.replicate cols vect
122 |
123 | public export
124 | (+) : Num n => (matrix1 : Matrix rows cols n) ->  (matrix2 : Matrix rows cols n)  -> Matrix rows cols n
125 | (+) = zipWith (+)
126 |
127 | public export
128 | (-) : Neg n => (matrix1 : Matrix rows cols n) ->  (matrix2 : Matrix rows cols n)  -> Matrix rows cols n
129 | (-) = zipWith (-)
130 |
131 | public export
132 | scalarMulti : Num a => a -> (matrix1 : Matrix rows cols a) -> Matrix rows cols a
133 | scalarMulti a = map (* a)
134 |
135 |
136 | public export
137 | transpose : {rows: Nat} -> {cols: Nat} ->  (matrix : Matrix rows cols n)  -> Matrix cols rows n
138 | transpose (MkMatrix vects) = MkMatrix $ Data.Vect.transpose vects
139 |
140 | -- 共轭
141 | -- 共轭转置
142 | -- 行列式
143 | -- 特征值与特征向量
144 |
145 |
146 | public export
147 | (*) : Num a => {l: Nat} -> {m: Nat} -> {n: Nat} -> (matrix1 : Matrix m l a) ->  (matrix2 : Matrix l n a)  -> Matrix m n a
148 | (*) (MkMatrix matrix1) matrix2 = let (MkMatrix matrix2') = transpose matrix2 in
149 |   MkMatrix $ matrix1 <&> \r1 => 
150 |                           matrix2' <&> \c1 => 
151 |                                         dotProduct r1 c1
152 |
153 | public export
154 | minor : (row: Fin (S n)) -> (col: Fin (S n)) -> (matrix1 : Matrix (S n) (S n) a)  -> Matrix n n a 
155 | minor row col (MkMatrix matrix1)  = fromVects . deleteAt row . map (deleteAt col) $ matrix1
156 |
157 | public export
158 | algeCofactor : Neg a => (row: Fin (S n)) -> (col: Fin (S n)) -> (matrix1 : Matrix (S n) (S n) a)  -> Matrix n n a
159 | algeCofactor row col matrix1 = minor row col matrix1 <&> (* sign)
160 |   where
161 |     sign : a
162 |     sign = if mod ((cast row) + (cast col)) 2 == 0 then 1 else -1
163 |
164 | public export
165 | trace : Num a => Matrix n n a -> a
166 | trace (MkMatrix matrix1) = sum $ diag matrix1
167 |
168 |
169 | export
170 | flipVer : Matrix m n a -> Matrix m n a
171 | flipVer (MkMatrix vects) = MkMatrix $ reverse vects
172 |
173 | export
174 | flipHor : Matrix m n a -> Matrix m n a
175 | flipHor (MkMatrix vects) = MkMatrix $ map reverse vects