0 | module Algebra.Solver.Semiring.Sum
  1 |
  2 | import Algebra.Solver.Semiring.Expr
  3 | import Algebra.Solver.Semiring.Prod
  4 | import Algebra.Solver.Semiring.SolvableSemiring
  5 | import Algebra.Solver.Semiring.Util
  6 |
  7 | %default total
  8 |
  9 | ||| A single term in a normalized arithmetic expressions.
 10 | |||
 11 | ||| This is a product of all variables each raised to
 12 | ||| a given power, multiplied with a factors (which is supposed
 13 | ||| to reduce during elaboration).
 14 | public export
 15 | record Term (a : Type) (as : List a) where
 16 |   constructor T
 17 |   factor : a
 18 |   prod   : Prod a as
 19 |
 20 | ||| Evaluate a term.
 21 | public export
 22 | eterm : Semiring a => {as : List a} -> Term a as -> a
 23 | eterm (T f p) = f * eprod p
 24 |
 25 | ||| Normalized arithmetic expression in a commutative
 26 | ||| ring (represented as an (ordered) sum of terms).
 27 | public export
 28 | data Sum : (a : Type) -> (as : List a) -> Type where
 29 |   Nil  : {0 as : List a} -> Sum a as
 30 |   (::) : {0 as : List a} -> Term a as -> Sum a as -> Sum a as
 31 |
 32 | ||| Evaluate a sum of terms.
 33 | public export
 34 | esum : Semiring a => {as : List a} -> Sum a as -> a
 35 | esum []        = 0
 36 | esum (x :: xs) = eterm x + esum xs
 37 |
 38 | --------------------------------------------------------------------------------
 39 | --          Normalizer
 40 | --------------------------------------------------------------------------------
 41 |
 42 | ||| Add two sums of terms.
 43 | |||
 44 | ||| The order of terms will be kept. If two terms have identical
 45 | ||| products of variables, they will be unified by adding their
 46 | ||| factors.
 47 | public export
 48 | add : SolvableSemiring a => Sum a as -> Sum a as -> Sum a as
 49 | add []        ys                = ys
 50 | add xs        []                = xs
 51 | add (T m x :: xs) (T n y :: ys) = case compProd x y of
 52 |   LT => T m x :: add xs (T n y :: ys)
 53 |   GT => T n y :: add (T m x :: xs) ys
 54 |   EQ => T (m + n) y :: add xs ys
 55 |
 56 | ||| Normalize a sum of terms by removing all terms with a
 57 | ||| `zero` factor.
 58 | public export
 59 | normSum : SolvableSemiring a => Sum a as -> Sum a as
 60 | normSum []           = []
 61 | normSum (T f p :: y) = case isZero f of
 62 |   Just refl => normSum y
 63 |   Nothing   => T f p :: normSum y
 64 |
 65 | ||| Multiplies a single term with a sum of terms.
 66 | public export
 67 | mult1 : SolvableSemiring a => Term a as -> Sum a as -> Sum a as
 68 | mult1 (T f p) (T g q :: xs) = T (f * g) (mult p q) :: mult1 (T f p) xs
 69 | mult1 _       []            = []
 70 |
 71 | ||| Multiplies two sums of terms.
 72 | public export
 73 | mult : SolvableSemiring a => Sum a as -> Sum a as -> Sum a as
 74 | mult []        ys = []
 75 | mult (x :: xs) ys = add (mult1 x ys) (mult xs ys)
 76 |
 77 | ||| Normalizes an arithmetic expression to a sum of products.
 78 | public export
 79 | norm : SolvableSemiring a => {as : List a} -> Expr a as -> Sum a as
 80 | norm (Lit n)     = [T n one]
 81 | norm (Var x y)   = [T 1 $ fromVar y]
 82 | norm (Plus x y)  = add (norm x) (norm y)
 83 | norm (Mult x y)  = mult (norm x) (norm y)
 84 |
 85 | ||| Like `norm` but removes all `zero` terms.
 86 | public export
 87 | normalize : SolvableSemiring a => {as : List a} -> Expr a as -> Sum a as
 88 | normalize e = normSum (norm e)
 89 |
 90 | --------------------------------------------------------------------------------
 91 | --          Proofs
 92 | --------------------------------------------------------------------------------
 93 |
 94 | -- Adding two sums via `add` preserves the evaluation result.
 95 | -- Note: `assert_total` in here is a temporary fix for idris issue #2954
 96 | 0 padd :
 97 |      {auto _ : SolvableSemiring a}
 98 |   -> (x,y : Sum a as)
 99 |   -> esum x + esum y === esum (add x y)
100 | padd []            xs = plusZeroLeftNeutral
101 | padd (x :: xs)     [] = plusZeroRightNeutral
102 | padd (T m x :: xs) (T n y :: ys) with (compProd x y) proof eq
103 |   _ | LT = Calc $
104 |     |~ (m * eprod x + esum xs) + (n * eprod y + esum ys)
105 |     ~~ m * eprod x + (esum xs + (n * eprod y + esum ys))
106 |        ..< plusAssociative
107 |     ~~ m * eprod x + esum (add xs (T n y :: ys))
108 |        ... cong (m * eprod x +) (padd xs (T n y :: ys))
109 |
110 |   _ | GT = Calc $
111 |     |~ (m * eprod x + esum xs) + (n * eprod y + esum ys)
112 |     ~~ n * eprod y + ((m * eprod x + esum xs) + esum ys)
113 |        ..< p213
114 |     ~~ n * eprod y + esum (add (T m x :: xs) ys)
115 |        ... cong (n * eprod y +) (assert_total $ padd (T m x :: xs) ys)
116 |
117 |   _ | EQ = case pcompProd x y eq of
118 |         Refl => Calc $
119 |           |~ (m * eprod x + esum xs) + (n * eprod x + esum ys)
120 |           ~~ (m * eprod x + n * eprod x) + (esum xs + esum ys)
121 |              ... p1324
122 |           ~~ (m + n) * eprod x + (esum xs + esum ys)
123 |              ..< cong (+ (esum xs + esum ys)) rightDistributive
124 |           ~~ (m + n) * eprod x + esum (add xs ys)
125 |              ... cong ((m + n) * eprod x +) (padd xs ys)
126 |
127 | -- Small utility lemma
128 | 0 psum0 :
129 |      {auto _ : SolvableSemiring a}
130 |   -> {x,y,z : a}
131 |   -> x === y
132 |   -> x === 0 * z + y
133 | psum0 prf = Calc $
134 |   |~ x
135 |   ~~ y          ... prf
136 |   ~~ 0 + y      ..< plusZeroLeftNeutral
137 |   ~~ 0 * z + y  ..< cong (+ y) multZeroLeftAbsorbs
138 |
139 | -- Multiplying a sum with a term preserves the evaluation result.
140 | 0 pmult1 :
141 |      {auto _ : SolvableSemiring a}
142 |   -> (m : a)
143 |   -> (p : Prod a as)
144 |   -> (s : Sum a as)
145 |   -> esum (mult1 (T m p) s) === (m * eprod p) * esum s
146 | pmult1 m p []            = sym multZeroRightAbsorbs
147 | pmult1 m p (T n q :: xs) = Calc $
148 |   |~ (m * n) * (eprod (mult p q)) + esum (mult1 (T m p) xs)
149 |   ~~ (m * n) * (eprod p * eprod q) + esum (mult1 (T m p) xs)
150 |      ... cong (\x => (m*n) * x + esum (mult1 (T m p) xs)) (pmult p q)
151 |   ~~ (m * eprod p) * (n * eprod q) + esum (mult1 (T m p) xs)
152 |      ..< cong (+ esum (mult1 (T m p) xs)) m1324
153 |   ~~ (m * eprod p) * (n * eprod q) + (m * eprod p) * esum xs
154 |      ... cong ((m * eprod p) * (n * eprod q) +) (pmult1 m p xs)
155 |   ~~ (m * eprod p) * (n * eprod q + esum xs)
156 |      ..< leftDistributive
157 |
158 | -- Multiplying two sums of terms preserves the evaluation result.
159 | 0 pmult :
160 |      {auto _ : SolvableSemiring a}
161 |   -> (x,y : Sum a as)
162 |   -> esum x * esum y === esum (mult x y)
163 | pmult []            y = multZeroLeftAbsorbs
164 | pmult (T n x :: xs) y = Calc $
165 |   |~ (n * eprod x + esum xs) * esum y
166 |   ~~ (n * eprod x) * esum y + esum xs * esum y
167 |      ... rightDistributive
168 |   ~~ (n * eprod x) * esum y + esum (mult xs y)
169 |      ... cong ((n * eprod x) * esum y +) (pmult xs y)
170 |   ~~ esum (mult1 (T n x) y) + esum (mult xs y)
171 |      ..< cong (+ esum (mult xs y)) (pmult1 n x y)
172 |   ~~ esum (add (mult1 (T n x) y) (mult xs y))
173 |      ... padd (mult1 (T n x) y) (mult xs y)
174 |
175 | -- Removing zero values from a sum of terms does not
176 | -- affect the evaluation result.
177 | 0 pnormSum :
178 |      {auto _ : SolvableSemiring a}
179 |   -> (s : Sum a as)
180 |   -> esum (normSum s) === esum s
181 | pnormSum []           = Refl
182 | pnormSum (T f p :: y) with (isZero f)
183 |   _ | Nothing   = Calc $
184 |     |~ f * eprod p + esum (normSum y)
185 |     ~~ f * eprod p + esum y ... cong ((f * eprod p) +) (pnormSum y)
186 |
187 |   _ | Just refl = Calc $
188 |     |~ esum (normSum y)
189 |     ~~ esum y               ... pnormSum y
190 |     ~~ 0 + esum y           ..< plusZeroLeftNeutral
191 |     ~~ 0 * eprod p + esum y ..< cong (+ esum y) multZeroLeftAbsorbs
192 |     ~~ f * eprod p + esum y ..< cong (\x => x * eprod p + esum y) refl
193 |
194 | -- Evaluating an expression gives the same result as
195 | -- evaluating its normalized form.
196 | 0 pnorm :
197 |      {auto _ : SolvableSemiring a}
198 |   -> (e : Expr a as)
199 |   -> eval e === esum (norm e)
200 | pnorm (Lit n)    = Calc $
201 |   |~ n
202 |   ~~ n * 1                    ..< multOneRightNeutral
203 |   ~~ n * eprod (one {as})     ..< cong (n *) (pone as)
204 |   ~~ n * eprod (one {as}) + 0 ..< plusZeroRightNeutral
205 |
206 | pnorm (Var x y)  = Calc $
207 |   |~ x
208 |   ~~ eprod (fromVar y)          ..< pvar as y
209 |   ~~ 1 * eprod (fromVar y)      ..< multOneLeftNeutral
210 |   ~~ 1 * eprod (fromVar y) + 0  ..< plusZeroRightNeutral
211 |
212 | pnorm (Plus x y) = Calc $
213 |   |~ eval x + eval y
214 |   ~~ esum (norm x) + eval y
215 |      ... cong (+ eval y) (pnorm x)
216 |   ~~ esum (norm x) + esum (norm y)
217 |      ... cong (esum (norm x) +) (pnorm y)
218 |   ~~ esum (add (norm x) (norm y))
219 |      ... padd _ _
220 |
221 | pnorm (Mult x y) = Calc $
222 |   |~ eval x * eval y
223 |   ~~ esum (norm x) * eval y
224 |      ... cong (* eval y) (pnorm x)
225 |   ~~ esum (norm x) * esum (norm y)
226 |      ... cong (esum (norm x) *) (pnorm y)
227 |   ~~ esum (mult (norm x) (norm y))
228 |      ... Sum.pmult _ _
229 |
230 | -- Evaluating an expression gives the same result as
231 | -- evaluating its normalized form.
232 | 0 pnormalize :
233 |      {auto _ : SolvableSemiring a}
234 |   -> (e : Expr a as)
235 |   -> eval e === esum (normalize e)
236 | pnormalize e = Calc $
237 |   |~ eval e
238 |   ~~ esum (norm e)           ... pnorm e
239 |   ~~ esum (normSum (norm e)) ..< pnormSum (norm e)
240 |
241 | --------------------------------------------------------------------------------
242 | --          Solver
243 | --------------------------------------------------------------------------------
244 |
245 | ||| Given a list `as` of variables and two arithmetic expressions
246 | ||| over these variables, if the expressions are converted
247 | ||| to the same normal form, evaluating them gives the same
248 | ||| result.
249 | |||
250 | ||| This simple fact allows us to conveniently and quickly
251 | ||| proof arithmetic equalities required in other parts of
252 | ||| our code. For instance:
253 | |||
254 | ||| ```idris example
255 | ||| 0 binom1 : {x,y : Bits8}
256 | |||          -> (x + y) * (x + y) === x * x + 2 * x * y + y * y
257 | ||| binom1 = solve [x,y]
258 | |||                ((x .+. y) * (x .+. y))
259 | |||                (x .*. x + 2 *. x *. y + y .*. y)
260 | ||| ```
261 | export
262 | 0 solve :
263 |      {auto _ : SolvableSemiring a}
264 |   -> (as : List a)
265 |   -> (e1,e2 : Expr a as)
266 |   -> {auto prf : normalize e1 === normalize e2}
267 |   -> eval e1 === eval e2
268 | solve _ e1 e2 = Calc $
269 |   |~ eval e1
270 |   ~~ esum (normalize e1) ...(pnormalize e1)
271 |   ~~ esum (normalize e2) ...(cong esum prf)
272 |   ~~ eval e2             ..<(pnormalize e2)
273 |