0 | module Libraries.Data.SparseMatrix
2 | import Algebra.Semiring
8 | import Libraries.Text.PrettyPrint.Prettyprinter
21 | Vector : Type -> Type
22 | Vector a = List (Nat, a)
27 | fromList : (Eq a, Semiring a) => List a -> Vector a
30 | go : Nat -> List a -> Vector a
33 | = if x == plusNeutral
35 | else (i, x) :: go (S i) xs
42 | insert : (i : Nat) -> (x : a) -> Vector a -> Vector a
43 | insert i x [] = [(i,x)]
44 | insert i x ys@((j, y) :: ys') =
48 | GT => (j, y) :: insert i x ys'
51 | lookupOrd : Ord k => k -> List (k, a) -> Maybe a
52 | lookupOrd i [] = Nothing
53 | lookupOrd i ((k, x) :: xs) =
57 | GT => lookupOrd i xs
60 | maxIndex : Vector a -> Maybe Nat
61 | maxIndex xs = map fst (last' xs)
64 | length : Vector a -> Nat
65 | length = maybe 0 ((+) 1) . maxIndex
68 | dot : Semiring a => Vector a -> Vector a -> a
69 | dot [] _ = plusNeutral
70 | dot _ [] = plusNeutral
71 | dot xs@((k, x) :: xs') ys@((k', y) :: ys') =
72 | case compare k k' of
74 | EQ => (x |*| y) |+| dot xs' ys'
79 | Vector1 : Type -> Type
80 | Vector1 a = List1 (Nat, a)
83 | fromList : (Eq a, Semiring a) => List a -> Maybe (Vector1 a)
84 | fromList = Data.List1.fromList . Vector.fromList
87 | insert : Nat -> a -> Vector1 a -> Vector1 a
88 | insert i x ys@((j, y) ::: ys') =
90 | LT => (i, x) ::: (j, y) :: ys'
92 | GT => (j, y) ::: insert i x ys'
95 | lookupOrd : Ord k => k -> List1 (k, a) -> Maybe a
96 | lookupOrd i ((k, x) ::: xs) =
100 | GT => lookupOrd i xs
103 | maxIndex : Vector1 a -> Nat
104 | maxIndex ((i, x) ::: xs) = maybe i fst (last' xs)
108 | Matrix : Type -> Type
109 | Matrix a = Vector (Vector1 a)
112 | fromListList : (Eq a, Semiring a) => List (List a) -> Matrix a
113 | fromListList = mapMaybe (\(i, xs) => map (i,) (Vector1.fromList xs)) . withIndex
116 | withIndex : List (List a) -> List (Nat, List a)
119 | go : Nat -> List (List a) -> List (Nat, List a)
121 | go i (x :: xs) = (i, x) :: go (S i) xs
124 | transpose : Matrix a -> Matrix a
126 | transpose ((i, xs) :: xss) = spreadHeads i (toList xs) (transpose xss) where
127 | spreadHeads : Nat -> Vector a -> Matrix a -> Matrix a
128 | spreadHeads i [] yss = yss
129 | spreadHeads i xs [] = map (\(j,x) => (j, singleton (i,x))) xs
130 | spreadHeads i xs@((j, x) :: xs') yss@((j', ys) :: yss') =
131 | case compare j j' of
132 | LT => (j, singleton (i,x)) :: spreadHeads i xs' yss
133 | EQ => (j', insert i x ys) :: spreadHeads i xs' yss'
134 | GT => (j', ys) :: spreadHeads i xs yss'
136 | multRow : (Eq a, Semiring a) => Matrix a -> Vector1 a -> Vector a
138 | multRow ((i, xs) :: xss) ys =
139 | let z = dot (toList xs) (toList ys) in
140 | if z == plusNeutral then
143 | (i, z) :: multRow xss ys
146 | multTranspose : (Eq a, Semiring a) => (xss : Matrix a) -> (yss : Matrix a) -> Matrix a
147 | multTranspose xss [] = []
148 | multTranspose xss ((j, ys) :: yss) =
149 | case multRow xss ys of
150 | [] => multTranspose xss yss
151 | y' :: ys' => (j, y' ::: ys') :: multTranspose xss yss
154 | mult : (Eq a, Semiring a) => Matrix a -> Matrix a -> Matrix a
155 | mult = multTranspose . transpose
158 | maxColumnIndex : Matrix a -> Maybe Nat
159 | maxColumnIndex = foldMap @{%search} @{Monoid.Deep @{MkSemigroup max}} (Just . maxIndex . snd)
162 | header : (columnWidth : Int) -> (length : Nat) -> List (Doc ann)
163 | header columnWidth length = map (fill columnWidth . byShow) (take length [0..])
165 | row : Pretty ann a => Vector a -> List (Doc ann)
166 | row xs = go rowLength xs
169 | rowLength = Vector.length xs
171 | prettyNeutral : Doc ann
172 | prettyNeutral = space
174 | go : Nat -> Vector a -> List (Doc ann)
177 | go i@(S i') ys@((j, y) :: ys') =
178 | case compare (minus rowLength i) j of
179 | LT => prettyNeutral :: go i' ys
180 | EQ => pretty y :: go i' ys'
184 | columnSpacing : Int
187 | columnSep : List (Doc ann) -> Doc ann
188 | columnSep = concatWith (\x, y => x <+> spaces columnSpacing <+> y)
190 | row1 : Pretty ann a => (columnWidth : Int) -> Vector1 a -> List (Doc ann)
191 | row1 columnWidth ys = map (fill columnWidth) (row (toList ys))
193 | hardlineSep : List (Doc ann) -> Doc ann
194 | hardlineSep = concatWith (\x, y => x <+> hardline <+> y)
212 | prettyTable : Pretty ann a => (rowDesc, columnDesc : String) -> (maxWidthA : Nat) -> Matrix a -> Doc ann
213 | prettyTable rowDesc columnDesc maxWidthA m = hardlineSep $
215 | (spaces rowLabelWidth <++> columnSep (pretty0 columnDesc :: header columnWidth columnCount))
217 | :: (fill rowLabelWidth (pretty0 rowDesc) <++> Chara intersectionLabelSep <+> replicateChar columnLabelSepLength columnLabelSep)
219 | :: map (\(j, r) => fill rowLabelWidth (byShow j) <++> columnSep (fill (cast (length columnDesc)) (Chara rowLabelSep) :: row1 columnWidth r)) m
221 | rowLabelSep, columnLabelSep, intersectionLabelSep : Char
223 | columnLabelSep = '-'
224 | intersectionLabelSep = '+'
226 | rowMax, columnMax : Maybe Nat
227 | rowMax = maxIndex m
228 | columnMax = maxColumnIndex m
230 | rowMaxIndexWidth, columnMaxIndexWidth : Nat
231 | rowMaxIndexWidth = maybe 0 (length . show) rowMax
232 | columnMaxIndexWidth = maybe 0 (length . show) columnMax
234 | rowLabelWidth : Int
235 | rowLabelWidth = cast $
max rowMaxIndexWidth (length rowDesc)
238 | columnWidth = cast $
max columnMaxIndexWidth maxWidthA
241 | columnCount = maybe 0 ((+) 1) columnMax
243 | columnLabelSepLength : Int
244 | columnLabelSepLength =
245 | cast (minus (length columnDesc) 1)
246 | + (columnWidth + columnSpacing) * cast columnCount