0 | module Algebra.Solver.Ring.Sum
  1 |
  2 | import Algebra.Solver.Ring.Expr
  3 | import Algebra.Solver.Ring.Prod
  4 | import Algebra.Solver.Ring.SolvableRing
  5 | import Algebra.Solver.Ring.Util
  6 |
  7 | %default total
  8 |
  9 | ||| A single term in a normalized arithmetic expressions.
 10 | |||
 11 | ||| This is a product of all variables each raised to
 12 | ||| a given power, multiplied with a factors (which is supposed
 13 | ||| to reduce during elaboration).
 14 | public export
 15 | record Term (a : Type) (as : List a) where
 16 |   constructor T
 17 |   factor : a
 18 |   prod   : Prod a as
 19 |
 20 | ||| Evaluate a term.
 21 | public export
 22 | eterm : Ring a => {as : List a} -> Term a as -> a
 23 | eterm (T f p) = f * eprod p
 24 |
 25 | ||| Negate a term.
 26 | public export
 27 | negTerm : Ring a => Term a as -> Term a as
 28 | negTerm (T f p) = T (negate f) p
 29 |
 30 | ||| Normalized arithmetic expression in a commutative
 31 | ||| ring (represented as an (ordered) sum of terms).
 32 | public export
 33 | data Sum : (a : Type) -> (as : List a) -> Type where
 34 |   Nil  : {0 as : List a} -> Sum a as
 35 |   (::) : {0 as : List a} -> Term a as -> Sum a as -> Sum a as
 36 |
 37 | ||| Evaluate a sum of terms.
 38 | public export
 39 | esum : Ring a => {as : List a} -> Sum a as -> a
 40 | esum []        = 0
 41 | esum (x :: xs) = eterm x + esum xs
 42 |
 43 | ||| Negate a sum of terms.
 44 | public export
 45 | negate : Ring a => Sum a as -> Sum a as
 46 | negate []       = []
 47 | negate (x :: y) = negTerm x :: negate y
 48 |
 49 |
 50 | --------------------------------------------------------------------------------
 51 | --          Normalizer
 52 | --------------------------------------------------------------------------------
 53 |
 54 | ||| Add two sums of terms.
 55 | |||
 56 | ||| The order of terms will be kept. If two terms have identical
 57 | ||| products of variables, they will be unified by adding their
 58 | ||| factors.
 59 | public export
 60 | add : SolvableRing a => Sum a as -> Sum a as -> Sum a as
 61 | add []        ys                = ys
 62 | add xs        []                = xs
 63 | add (T m x :: xs) (T n y :: ys) = case compProd x y of
 64 |   LT => T m x :: add xs (T n y :: ys)
 65 |   GT => T n y :: add (T m x :: xs) ys
 66 |   EQ => T (m + n) y :: add xs ys
 67 |
 68 | ||| Normalize a sum of terms by removing all terms with a
 69 | ||| `zero` factor.
 70 | public export
 71 | normSum : SolvableRing a => Sum a as -> Sum a as
 72 | normSum []           = []
 73 | normSum (T f p :: y) = case isZero f of
 74 |   Just refl => normSum y
 75 |   Nothing   => T f p :: normSum y
 76 |
 77 | ||| Multiplies a single term with a sum of terms.
 78 | public export
 79 | mult1 : SolvableRing a => Term a as -> Sum a as -> Sum a as
 80 | mult1 (T f p) (T g q :: xs) = T (f * g) (mult p q) :: mult1 (T f p) xs
 81 | mult1 _       []            = []
 82 |
 83 | ||| Multiplies two sums of terms.
 84 | public export
 85 | mult : SolvableRing a => Sum a as -> Sum a as -> Sum a as
 86 | mult []        ys = []
 87 | mult (x :: xs) ys = add (mult1 x ys) (mult xs ys)
 88 |
 89 | ||| Normalizes an arithmetic expression to a sum of products.
 90 | public export
 91 | norm : SolvableRing a => {as : List a} -> Expr a as -> Sum a as
 92 | norm (Lit n)     = [T n one]
 93 | norm (Var x y)   = [T 1 $ fromVar y]
 94 | norm (Neg x)     = negate $ norm x
 95 | norm (Plus x y)  = add (norm x) (norm y)
 96 | norm (Mult x y)  = mult (norm x) (norm y)
 97 | norm (Minus x y) = add (norm x) (negate $ norm y)
 98 |
 99 | ||| Like `norm` but removes all `zero` terms.
100 | public export
101 | normalize : SolvableRing a => {as : List a} -> Expr a as -> Sum a as
102 | normalize e = normSum (norm e)
103 |
104 | --------------------------------------------------------------------------------
105 | --          Proofs
106 | --------------------------------------------------------------------------------
107 |
108 | -- Adding two sums via `add` preserves the evaluation result.
109 | -- Note: `assert_total` in here is a temporary fix for idris issue #2954
110 | 0 padd :
111 |      {auto _ : SolvableRing a}
112 |   -> (x,y : Sum a as)
113 |   -> esum x + esum y === esum (add x y)
114 | padd []            xs = plusZeroLeftNeutral
115 | padd (x :: xs)     [] = plusZeroRightNeutral
116 | padd (T m x :: xs) (T n y :: ys) with (compProd x y) proof eq
117 |   _ | LT = Calc $
118 |     |~ (m * eprod x + esum xs) + (n * eprod y + esum ys)
119 |     ~~ m * eprod x + (esum xs + (n * eprod y + esum ys))
120 |        ..< plusAssociative
121 |     ~~ m * eprod x + esum (add xs (T n y :: ys))
122 |        ... cong (m * eprod x +) (padd xs (T n y :: ys))
123 |
124 |   _ | GT = Calc $
125 |     |~ (m * eprod x + esum xs) + (n * eprod y + esum ys)
126 |     ~~ n * eprod y + ((m * eprod x + esum xs) + esum ys)
127 |        ..< p213
128 |     ~~ n * eprod y + esum (add (T m x :: xs) ys)
129 |        ... cong (n * eprod y +) (assert_total $ padd (T m x :: xs) ys)
130 |
131 |   _ | EQ = case pcompProd x y eq of
132 |         Refl => Calc $
133 |           |~ (m * eprod x + esum xs) + (n * eprod x + esum ys)
134 |           ~~ (m * eprod x + n * eprod x) + (esum xs + esum ys)
135 |              ... p1324
136 |           ~~ (m + n) * eprod x + (esum xs + esum ys)
137 |              ..< cong (+ (esum xs + esum ys)) rightDistributive
138 |           ~~ (m + n) * eprod x + esum (add xs ys)
139 |              ... cong ((m + n) * eprod x +) (padd xs ys)
140 |
141 | -- Small utility lemma
142 | 0 psum0 :
143 |      {auto _ : SolvableRing a}
144 |   -> {x,y,z : a}
145 |   -> x === y
146 |   -> x === 0 * z + y
147 | psum0 prf = Calc $
148 |   |~ x
149 |   ~~ y          ... prf
150 |   ~~ 0 + y      ..< plusZeroLeftNeutral
151 |   ~~ 0 * z + y  ..< cong (+ y) multZeroLeftAbsorbs
152 |
153 | -- Multiplying a sum with a term preserves the evaluation result.
154 | 0 pmult1 :
155 |      {auto _ : SolvableRing a}
156 |   -> (m : a)
157 |   -> (p : Prod a as)
158 |   -> (s : Sum a as)
159 |   -> esum (mult1 (T m p) s) === (m * eprod p) * esum s
160 | pmult1 m p []            = sym multZeroRightAbsorbs
161 | pmult1 m p (T n q :: xs) = Calc $
162 |   |~ (m * n) * (eprod (mult p q)) + esum (mult1 (T m p) xs)
163 |   ~~ (m * n) * (eprod p * eprod q) + esum (mult1 (T m p) xs)
164 |      ... cong (\x => (m*n) * x + esum (mult1 (T m p) xs)) (pmult p q)
165 |   ~~ (m * eprod p) * (n * eprod q) + esum (mult1 (T m p) xs)
166 |      ..< cong (+ esum (mult1 (T m p) xs)) m1324
167 |   ~~ (m * eprod p) * (n * eprod q) + (m * eprod p) * esum xs
168 |      ... cong ((m * eprod p) * (n * eprod q) +) (pmult1 m p xs)
169 |   ~~ (m * eprod p) * (n * eprod q + esum xs)
170 |      ..< leftDistributive
171 |
172 | -- Multiplying two sums of terms preserves the evaluation result.
173 | 0 pmult :
174 |      {auto _ : SolvableRing a}
175 |   -> (x,y : Sum a as)
176 |   -> esum x * esum y === esum (mult x y)
177 | pmult []            y = multZeroLeftAbsorbs
178 | pmult (T n x :: xs) y = Calc $
179 |   |~ (n * eprod x + esum xs) * esum y
180 |   ~~ (n * eprod x) * esum y + esum xs * esum y
181 |      ... rightDistributive
182 |   ~~ (n * eprod x) * esum y + esum (mult xs y)
183 |      ... cong ((n * eprod x) * esum y +) (pmult xs y)
184 |   ~~ esum (mult1 (T n x) y) + esum (mult xs y)
185 |      ..< cong (+ esum (mult xs y)) (pmult1 n x y)
186 |   ~~ esum (add (mult1 (T n x) y) (mult xs y))
187 |      ... padd (mult1 (T n x) y) (mult xs y)
188 |
189 | -- Evaluating a negated term is equivalent to negate the
190 | -- result of evaluating the term.
191 | 0 pnegTerm :
192 |      {auto _ : SolvableRing a}
193 |   -> (x : Term a as)
194 |   -> eterm (negTerm x) === neg (eterm x)
195 | pnegTerm (T f p) = multNegLeft
196 |
197 | -- Evaluating a negated sum of terms is equivalent to negate the
198 | -- result of evaluating the sum of terms.
199 | 0 pneg :
200 |      {auto _ : SolvableRing a}
201 |   -> (x : Sum a as)
202 |   -> esum (negate x) === neg (esum x)
203 | pneg []       = sym $ negZero
204 | pneg (x :: y) = Calc $
205 |   |~ eterm (negTerm x) + esum (negate y)
206 |   ~~ neg (eterm x) + esum (negate y) ... cong (+ esum (negate y)) (pnegTerm x)
207 |   ~~ neg (eterm x) + neg (esum y)    ... cong (neg (eterm x) +) (pneg y)
208 |   ~~ neg (eterm x + esum y)          ..< negDistributes
209 |
210 | -- Removing zero values from a sum of terms does not
211 | -- affect the evaluation result.
212 | 0 pnormSum :
213 |      {auto _ : SolvableRing a}
214 |   -> (s : Sum a as)
215 |   -> esum (normSum s) === esum s
216 | pnormSum []           = Refl
217 | pnormSum (T f p :: y) with (isZero f)
218 |   _ | Nothing   = Calc $
219 |     |~ f * eprod p + esum (normSum y)
220 |     ~~ f * eprod p + esum y ... cong ((f * eprod p) +) (pnormSum y)
221 |
222 |   _ | Just refl = Calc $
223 |     |~ esum (normSum y)
224 |     ~~ esum y               ... pnormSum y
225 |     ~~ 0 + esum y           ..< plusZeroLeftNeutral
226 |     ~~ 0 * eprod p + esum y ..< cong (+ esum y) multZeroLeftAbsorbs
227 |     ~~ f * eprod p + esum y ..< cong (\x => x * eprod p + esum y) refl
228 |
229 | -- Evaluating an expression gives the same result as
230 | -- evaluating its normalized form.
231 | 0 pnorm :
232 |      {auto _ : SolvableRing a}
233 |   -> (e : Expr a as)
234 |   -> eval e === esum (norm e)
235 | pnorm (Lit n)    = Calc $
236 |   |~ n
237 |   ~~ n * 1                    ..< multOneRightNeutral
238 |   ~~ n * eprod (one {as})     ..< cong (n *) (pone as)
239 |   ~~ n * eprod (one {as}) + 0 ..< plusZeroRightNeutral
240 |
241 | pnorm (Var x y)  = Calc $
242 |   |~ x
243 |   ~~ eprod (fromVar y)          ..< pvar as y
244 |   ~~ 1 * eprod (fromVar y)      ..< multOneLeftNeutral
245 |   ~~ 1 * eprod (fromVar y) + 0  ..< plusZeroRightNeutral
246 |
247 | pnorm (Neg x) = Calc $
248 |   |~ neg (eval x)
249 |   ~~ neg (esum (norm x))    ... cong neg (pnorm x)
250 |   ~~ esum (negate (norm x)) ..< pneg (norm x)
251 |
252 | pnorm (Plus x y) = Calc $
253 |   |~ eval x + eval y
254 |   ~~ esum (norm x) + eval y
255 |      ... cong (+ eval y) (pnorm x)
256 |   ~~ esum (norm x) + esum (norm y)
257 |      ... cong (esum (norm x) +) (pnorm y)
258 |   ~~ esum (add (norm x) (norm y))
259 |      ... padd _ _
260 |
261 | pnorm (Mult x y) = Calc $
262 |   |~ eval x * eval y
263 |   ~~ esum (norm x) * eval y
264 |      ... cong (* eval y) (pnorm x)
265 |   ~~ esum (norm x) * esum (norm y)
266 |      ... cong (esum (norm x) *) (pnorm y)
267 |   ~~ esum (mult (norm x) (norm y))
268 |      ... Sum.pmult _ _
269 |
270 | pnorm (Minus x y) = Calc $
271 |   |~ eval x - eval y
272 |   ~~ eval x + neg (eval y)
273 |      ... minusIsPlusNeg
274 |   ~~ esum (norm x) + neg (eval y)
275 |      ... cong (+ neg (eval y)) (pnorm x)
276 |   ~~ esum (norm x) + neg (esum (norm y))
277 |      ... cong (\v => esum (norm x) + neg v) (pnorm y)
278 |   ~~ esum (norm x) + esum (negate (norm y))
279 |      ..< cong (esum (norm x) +) (pneg (norm y))
280 |   ~~ esum (add (norm x) (negate (norm y)))
281 |      ... padd _ _
282 |
283 | -- Evaluating an expression gives the same result as
284 | -- evaluating its normalized form.
285 | 0 pnormalize :
286 |      {auto _ : SolvableRing a}
287 |   -> (e : Expr a as)
288 |   -> eval e === esum (normalize e)
289 | pnormalize e = Calc $
290 |   |~ eval e
291 |   ~~ esum (norm e)           ... pnorm e
292 |   ~~ esum (normSum (norm e)) ..< pnormSum (norm e)
293 |
294 | --------------------------------------------------------------------------------
295 | --          Solver
296 | --------------------------------------------------------------------------------
297 |
298 | ||| Given a list `as` of variables and two arithmetic expressions
299 | ||| over these variables, if the expressions are converted
300 | ||| to the same normal form, evaluating them gives the same
301 | ||| result.
302 | |||
303 | ||| This simple fact allows us to conveniently and quickly
304 | ||| proof arithmetic equalities required in other parts of
305 | ||| our code. For instance:
306 | |||
307 | ||| ```idris example
308 | ||| 0 binom1 : {x,y : Bits8}
309 | |||          -> (x + y) * (x + y) === x * x + 2 * x * y + y * y
310 | ||| binom1 = solve [x,y]
311 | |||                ((x .+. y) * (x .+. y))
312 | |||                (x .*. x + 2 *. x *. y + y .*. y)
313 | ||| ```
314 | export
315 | 0 solve :
316 |      {auto _ : SolvableRing a}
317 |   -> (as : List a)
318 |   -> (e1,e2 : Expr a as)
319 |   -> {auto prf : normalize e1 === normalize e2}
320 |   -> eval e1 === eval e2
321 | solve _ e1 e2 = Calc $
322 |   |~ eval e1
323 |   ~~ esum (normalize e1) ...(pnormalize e1)
324 |   ~~ esum (normalize e2) ...(cong esum prf)
325 |   ~~ eval e2             ..<(pnormalize e2)
326 |
327 | --------------------------------------------------------------------------------
328 | --          Examples
329 | --------------------------------------------------------------------------------
330 |
331 | 0 binom1 : {x,y : Bits8} -> (x + y) * (x + y) === x * x + 2 * x * y + y * y
332 | binom1 =
333 |   solve
334 |     [x,y]
335 |     ((x .+. y) * (x .+. y))
336 |     (x .*. x + 2 *. x *. y + y .*. y)
337 |
338 | 0 binom2 : {x,y : Bits8} -> (x - y) * (x - y) === x * x - 2 * x * y + y * y
339 | binom2 =
340 |   solve
341 |     [x,y]
342 |     ((x .-. y) * (x .-. y))
343 |     (x .*. x - 2 *. x *. y + y .*. y)
344 |
345 | 0 binom3 : {x,y : Bits8} -> (x + y) * (x - y) === x * x - y * y
346 | binom3 = solve [x,y] ((x .+. y) * (x .-. y)) (x .*. x - y .*. y)
347 |