0 | module Libraries.Data.SparseMatrix
  1 |
  2 | import Algebra.Semiring
  3 |
  4 | import Data.List1
  5 | import Data.Maybe
  6 | import Data.String
  7 |
  8 | import Libraries.Text.PrettyPrint.Prettyprinter
  9 |
 10 | %default total
 11 |
 12 | namespace Vector
 13 |   ||| A sparse vector is a list of pairs consisting of an index and its
 14 |   ||| corresponding element.
 15 |   |||
 16 |   ||| Invariants:
 17 |   ||| - indices must appear in order and should to be duplicate-free,
 18 |   ||| - elements must be additively non-neutral,
 19 |   ||| - missing entries are assumed to be neutral.
 20 |   public export
 21 |   Vector : Type -> Type
 22 |   Vector a = List (Nat, a)
 23 |
 24 |   ||| Turns a list into a sparse vector, discarding neutral elements in
 25 |   ||| the process.
 26 |   export
 27 |   fromList : (Eq a, Semiring a) => List a -> Vector a
 28 |   fromList = go Z
 29 |     where
 30 |       go : Nat -> List a -> Vector a
 31 |       go i [] = []
 32 |       go i (x :: xs)
 33 |           = if x == plusNeutral
 34 |                then go (S i) xs
 35 |                else (i, x) :: go (S i) xs
 36 |
 37 |   ||| Insert `x` at index `i`. Ignore if the `i`th element already
 38 |   ||| exists.
 39 |   |||
 40 |   ||| @ x must not be neutral
 41 |   export
 42 |   insert : (i : Nat) -> (x : a) -> Vector a -> Vector a
 43 |   insert i x [] = [(i,x)]
 44 |   insert i x ys@((j, y) :: ys') =
 45 |     case compare i j of
 46 |       LT => (i, x) :: ys
 47 |       EQ => ys -- keep
 48 |       GT => (j, y) :: insert i x ys'
 49 |
 50 |   export
 51 |   lookupOrd : Ord k => k -> List (k, a) -> Maybe a
 52 |   lookupOrd i [] = Nothing
 53 |   lookupOrd i ((k, x) :: xs) =
 54 |     case compare i k of
 55 |       LT => Nothing
 56 |       EQ => Just x
 57 |       GT => lookupOrd i xs
 58 |
 59 |   export
 60 |   maxIndex : Vector a -> Maybe Nat
 61 |   maxIndex xs = map fst (last' xs)
 62 |
 63 |   export
 64 |   length : Vector a -> Nat
 65 |   length = maybe 0 ((+) 1) . maxIndex
 66 |
 67 |   export
 68 |   dot : Semiring a => Vector a -> Vector a -> a
 69 |   dot [] _ = plusNeutral
 70 |   dot _ [] = plusNeutral
 71 |   dot xs@((k, x) :: xs') ys@((k', y) :: ys') =
 72 |     case compare k k' of
 73 |       LT => dot xs' ys
 74 |       EQ => (x |*| y) |+| dot xs' ys'
 75 |       GT => dot xs ys'
 76 |
 77 | namespace Vector1
 78 |   public export
 79 |   Vector1 : Type -> Type
 80 |   Vector1 a = List1 (Nat, a)
 81 |
 82 |   export
 83 |   fromList : (Eq a, Semiring a) => List a -> Maybe (Vector1 a)
 84 |   fromList = Data.List1.fromList . Vector.fromList
 85 |
 86 |   export
 87 |   insert : Nat -> a -> Vector1 a -> Vector1 a
 88 |   insert i x ys@((j, y) ::: ys') =
 89 |     case compare i j of
 90 |       LT => (i, x) ::: (j, y) :: ys'
 91 |       EQ => ys -- keep
 92 |       GT => (j, y) ::: insert i x ys'
 93 |
 94 |   export
 95 |   lookupOrd : Ord k => k -> List1 (k, a) -> Maybe a
 96 |   lookupOrd i ((k, x) ::: xs) =
 97 |     case compare i k of
 98 |       LT => Nothing
 99 |       EQ => Just x
100 |       GT => lookupOrd i xs
101 |
102 |   export
103 |   maxIndex : Vector1 a -> Nat
104 |   maxIndex ((i, x) ::: xs) = maybe i fst (last' xs)
105 |
106 | ||| A sparse matrix is a sparse vector of (non-empty) sparse vectors.
107 | public export
108 | Matrix : Type -> Type
109 | Matrix a = Vector (Vector1 a)
110 |
111 | export
112 | fromListList : (Eq a, Semiring a) => List (List a) -> Matrix a
113 | fromListList = mapMaybe (\(i, xs) => map (i,) (Vector1.fromList xs)) . withIndex
114 |   where
115 |     -- may contain empty lists
116 |     withIndex : List (List a) -> List (Nat, List a)
117 |     withIndex = go Z
118 |       where
119 |         go : Nat -> List (List a) -> List (Nat, List a)
120 |         go i [] = []
121 |         go i (x :: xs) = (i, x) :: go (S i) xs
122 |
123 | export
124 | transpose : Matrix a -> Matrix a
125 | transpose [] = []
126 | transpose ((i, xs) :: xss) = spreadHeads i (toList xs) (transpose xss) where
127 |   spreadHeads : Nat -> Vector a -> Matrix a -> Matrix a
128 |   spreadHeads i [] yss = yss
129 |   spreadHeads i xs [] = map (\(j,x) => (j, singleton (i,x))) xs
130 |   spreadHeads i xs@((j, x) :: xs') yss@((j', ys) :: yss') =
131 |     case compare j j' of
132 |       LT => (j, singleton (i,x)) :: spreadHeads i xs' yss
133 |       EQ => (j', insert i x ys) :: spreadHeads i xs' yss'
134 |       GT => (j', ys) :: spreadHeads i xs yss'
135 |
136 | multRow : (Eq a, Semiring a) => Matrix a -> Vector1 a -> Vector a
137 | multRow [] ys = []
138 | multRow ((i, xs) :: xss) ys =
139 |   let z = dot (toList xs) (toList ys) in
140 |   if z == plusNeutral then
141 |     multRow xss ys
142 |   else
143 |     (i, z) :: multRow xss ys
144 |
145 | ||| Given matrices `xss` and `yss`, computes `xss^T * yss`.
146 | multTranspose : (Eq a, Semiring a) => (xss : Matrix a) -> (yss : Matrix a) -> Matrix a
147 | multTranspose xss [] = []
148 | multTranspose xss ((j, ys) :: yss) =
149 |   case multRow xss ys of
150 |     [] => multTranspose xss yss -- discard empty rows
151 |     y' :: ys' => (j, y' ::: ys') :: multTranspose xss yss
152 |
153 | export
154 | mult : (Eq a, Semiring a) => Matrix a -> Matrix a -> Matrix a
155 | mult = multTranspose . transpose
156 |
157 | ||| Find largest column index.
158 | maxColumnIndex : Matrix a -> Maybe Nat
159 | maxColumnIndex = foldMap @{%search} @{Monoid.Deep @{MkSemigroup max}} (Just . maxIndex . snd)
160 |
161 | namespace Pretty
162 |   header : (columnWidth : Int) -> (length : Nat) -> List (Doc ann)
163 |   header columnWidth length = map (fill columnWidth . byShow) (take length [0..])
164 |
165 |   row : Pretty ann a => Vector a -> List (Doc ann)
166 |   row xs = go rowLength xs
167 |     where
168 |       rowLength : Nat
169 |       rowLength = Vector.length xs
170 |
171 |       prettyNeutral : Doc ann
172 |       prettyNeutral = space
173 |
174 |       go : Nat -> Vector a -> List (Doc ann)
175 |       go _ [] = []
176 |       go Z _ = []
177 |       go i@(S i') ys@((j, y) :: ys') =
178 |         case compare (minus rowLength i) j of
179 |           LT => prettyNeutral :: go i' ys
180 |           EQ => pretty y :: go i' ys'
181 |           GT => go i ys'
182 |
183 |   -- space between columns
184 |   columnSpacing : Int
185 |   columnSpacing = 1
186 |
187 |   columnSep : List (Doc ann) -> Doc ann
188 |   columnSep = concatWith (\x, y => x <+> spaces columnSpacing <+> y)
189 |
190 |   row1 : Pretty ann a => (columnWidth : Int) -> Vector1 a -> List (Doc ann)
191 |   row1 columnWidth ys = map (fill columnWidth) (row (toList ys))
192 |
193 |   hardlineSep : List (Doc ann) -> Doc ann
194 |   hardlineSep = concatWith (\x, y => x <+> hardline <+> y)
195 |
196 |   ||| Renders a matrix as an ASCII table of the following shape:
197 |   |||
198 |   ||| ```
199 |   |||                columnSpacing
200 |   |||                      __
201 |   |||         columnDesc  0  1  ...
202 |   ||| rowDesc +--------------------
203 |   ||| 0       |           a  b
204 |   ||| 1       |           c  d
205 |   ||| ...     |           ...
206 |   |||```
207 |   |||
208 |   ||| Note: everything is sadly left-aligned.
209 |   |||
210 |   ||| @ maxWidthA Maximal length rendering something of type `a` might reach.
211 |   export
212 |   prettyTable : Pretty ann a => (rowDesc, columnDesc : String) -> (maxWidthA : Nat) -> Matrix a -> Doc ann
213 |   prettyTable rowDesc columnDesc maxWidthA m = hardlineSep $
214 |       -- header
215 |       (spaces rowLabelWidth <++> columnSep (pretty0 columnDesc :: header columnWidth columnCount))
216 |       -- separator
217 |         :: (fill rowLabelWidth (pretty0 rowDesc) <++> Chara intersectionLabelSep <+> replicateChar columnLabelSepLength columnLabelSep)
218 |       -- content
219 |         :: map (\(j, r) => fill rowLabelWidth (byShow j) <++> columnSep (fill (cast (length columnDesc)) (Chara rowLabelSep) :: row1 columnWidth r)) m
220 |     where
221 |       rowLabelSep, columnLabelSep, intersectionLabelSep : Char
222 |       rowLabelSep = '|'
223 |       columnLabelSep = '-'
224 |       intersectionLabelSep = '+'
225 |
226 |       rowMax, columnMax : Maybe Nat
227 |       rowMax = maxIndex m
228 |       columnMax = maxColumnIndex m
229 |
230 |       rowMaxIndexWidth, columnMaxIndexWidth : Nat
231 |       rowMaxIndexWidth = maybe 0 (length . show) rowMax
232 |       columnMaxIndexWidth = maybe 0 (length . show) columnMax
233 |
234 |       rowLabelWidth : Int
235 |       rowLabelWidth = cast $ max rowMaxIndexWidth (length rowDesc)
236 |
237 |       columnWidth : Int
238 |       columnWidth = cast $ max columnMaxIndexWidth maxWidthA
239 |
240 |       columnCount : Nat
241 |       columnCount = maybe 0 ((+) 1) columnMax
242 |
243 |       columnLabelSepLength : Int
244 |       columnLabelSepLength =
245 |         cast (minus (length columnDesc) 1) -- columnDesc overlaps with *LabelSep
246 |         + (columnWidth + columnSpacing) * cast columnCount
247 |