0 | module Data.Container.Additive.Morphism.Instances
  1 |
  2 | import Data.Vect
  3 | import Data.List.Quantifiers
  4 | import Data.Vect.Quantifiers
  5 |
  6 | import Data.Container.Base
  7 | import Data.ComMonoid
  8 | import Data.Num
  9 | import Data.Container.Additive.Object.Definition
 10 | import Data.Container.Additive.Object.Instances
 11 | import Data.Container.Additive.Morphism.Definition
 12 | import Data.Container.Additive.Product.Definitions
 13 | import Data.Container.Additive.Properties.Definitions
 14 |
 15 | import Data.Container.Additive.Quantifiers
 16 |
 17 | import Control.Monad.Distribution
 18 | import Control.Monad.Sample.Definition
 19 |
 20 | import Misc
 21 |
 22 | %hide Data.Container.Base.Object.Instances.Const
 23 | %hide Data.Vect.Quantifiers.All.index
 24 |
 25 | public export
 26 | toState : {0 c : AddCont} -> (x : c.Shp) -> Scalar =%> c
 27 | toState x = !%+ \() => (x ** \_ => ())
 28 |
 29 | public export
 30 | fromState : {0 c : AddCont} -> Scalar =%> c -> c.Shp
 31 | fromState f = f.fwd ()
 32 |
 33 | public export
 34 | toCostate : {0 c : AddCont} ->
 35 |   (s : (x : c.Shp) -> c.Pos x) ->
 36 |   c =%> Scalar
 37 | toCostate s = !%+ \x => (() ** \() => s x)
 38 |
 39 | public export
 40 | fromCostate : {0 c : AddCont} ->
 41 |   c =%> Scalar ->
 42 |   (x : c.Shp) -> c.Pos x
 43 | fromCostate f x = f.bwd x ()
 44 |
 45 |
 46 | public export
 47 | pushDown : AddCont -> AddCont
 48 | pushDown d = !! (Const2 Unit d.Shp)
 49 |
 50 | public export
 51 | pushIntoContinuationList : {d, p, l : AddCont} ->
 52 |   (f : d >< p =%> l) ->
 53 |   (p =%> (pushDown d) >@ (List l))
 54 | pushIntoContinuationList f = !%+ \param => (() <|
 55 |   \ds => map (\dShp => f.fwd (dShp, param)) ds **
 56 |     \ll => sum @{UMon p param} (ll >>=
 57 |       \(ds ** grads=> extractPGrads param ds grads))
 58 |   where
 59 |     extractPGrads : (param : p.Shp) ->
 60 |       (ds : List d.Shp) ->
 61 |       All l.Pos ((\dShp => f.fwd (dShp, param)) <$> ds) ->
 62 |       List (p.Pos param)
 63 |     extractPGrads param [] [] = []
 64 |     extractPGrads param (dShp :: ds) (grad :: grads) =
 65 |       snd (f.bwd (dShp, param) grad) :: extractPGrads param ds grads
 66 |
 67 | public export
 68 | pushIntoContinuation : {p : AddCont} -> (flat : IsFlat l) => Num l.Shp =>
 69 |   (f : d >< p =%> l) ->
 70 |   (p =%> (pushDown d) >@ l)
 71 | pushIntoContinuation {flat = MkIsFlat _} f = !%+ \param => (() <|
 72 |   \ds => sum @{numIsMonoid} ((\dShp => f.fwd (dShp, param)) <$> ds) **
 73 |     \ll => sum @{UMon p param} (ll >>=
 74 |       \(ds ** grad=> (\dShp => snd (f.bwd (dShp, param) grad)) <$> ds))
 75 |
 76 | public export
 77 | constantOne : {c : AddCont} ->
 78 |   InterfaceOnPositions c Num =>
 79 |   c =%> Scalar
 80 | constantOne @{MkI @{p}} = toCostate {c=c} (\x => let numPos = p x in 1)
 81 |
 82 | namespace HancockTensorProduct
 83 |   public export
 84 |   leftUnit : {c : AddCont} -> (Scalar >< c) =%> c
 85 |   leftUnit = !%+ \((), s) => (s ** \p => ((), p))
 86 |   
 87 |   public export
 88 |   rightUnit : {c : AddCont} -> (c >< Scalar) =%> c
 89 |   rightUnit = !%+ \(x, ()) => (x ** \x' => (x', ()))
 90 |
 91 |   public export
 92 |   leftUnitInv : {c : AddCont} -> c =%> (Scalar >< c)
 93 |   leftUnitInv = !%+ \x => (((), x) ** \((), x') => x')
 94 |   
 95 |   public export
 96 |   rightUnitInv : {c : AddCont} -> c =%> (c >< Scalar)
 97 |   rightUnitInv = !%+ \x => ((x, ()) ** \(x', ()) => x')
 98 |
 99 |   public export
100 |   assocL : {0 a, b, c : AddCont} -> ((a >< b) >< c) =%> (a >< (b >< c))
101 |   assocL = !%+ \((a, b), c) => ((a, (b, c)) ** \(a', (b', c')) => ((a', b'), c'))
102 |
103 |   public export
104 |   assocR : {0 a, b, c : AddCont} -> (a >< (b >< c)) =%> ((a >< b) >< c)
105 |   assocR = !%+ \(a, (b, c)) => (((a, b), c) ** \((a', b'), c') => (a', (b', c')))
106 |
107 |   public export
108 |   swap : {0 a, b : AddCont} -> (a >< b) =%> (b >< a)
109 |   swap = !%+ \(a, b) => ((b, a) ** \(b', a') => (a', b'))
110 |
111 | namespace CompositionProduct
112 |   public export
113 |   leftUnit : {0 c : AddCont} -> (Scalar >@ c) =%> c
114 |   leftUnit = !%+ \(() <| cShp) => (cShp () ** \c' => [(() ** c')])
115 |
116 |   public export
117 |   rightUnit : {0 c : AddCont} -> (c >@ Scalar) =%> c
118 |   rightUnit = !%+ \(s <| _) => (s ** \p => [(p ** ())])
119 |
120 |   public export
121 |   leftUnitInv : {c : AddCont} -> c =%> (Scalar >@ c)
122 |   leftUnitInv = !%+ \s => (() <| (\_ => s) ** \ll =>
123 |     sum @{UMon c s} (snd <$> ll))
124 |   
125 |   ||| Right unit inverse: c =%> c >@ I
126 |   public export
127 |   rightUnitInv : {c : AddCont} -> c =%> (c >@ Scalar)
128 |   rightUnitInv = !%+ \s => (s <| const () ** \ll =>
129 |     sum @{UMon c s} (fst <$> ll))
130 |
131 |
132 | namespace Coproduct
133 |   public export
134 |   elim : {c : AddCont} ->
135 |     (c >+< c) =%> c
136 |   elim = !%+ \case
137 |     (Left x) => (x ** id)
138 |     (Right y) => (y ** id)
139 |
140 | public export
141 | duoidal : {c, d, e, f : AddCont} ->
142 |   ((c >@ d) >< (e >@ f)) =%> ((c >< e) >@ (d >< f))
143 | duoidal = !%+ \((sc <| idxC), (se <| idxE)) =>
144 |   ((sc, se) <| \(cp, ep) => (idxC cp, idxE ep) **
145 |     \ll => ((\((cp, ep) ** (dp, fp)) => (cp ** dp)) <$> ll,
146 |             (\((cp, ep) ** (dp, fp)) => (ep ** fp)) <$> ll))
147 |
148 | public export
149 | coprodDistrOverTensor : {a, b, p, q : AddCont} ->
150 |   ((a >+< b) >< (p >< q)) =%> ((a >< p) >+< (b >< q))
151 | coprodDistrOverTensor = !%+ \case
152 |   (Left a, (p, _)) => (Left (a, p) ** \(a', p') => (a', (p', q.Zero _)))
153 |   (Right b, (_, q)) => (Right (b, q) ** \(b', q') => (b', (p.Zero _, q')))
154 |
155 | ||| Not an isomorphism, arising from duoidal structure between >@ and ><
156 | public export
157 | rebracketcomptensor: {e, y : AddCont} -> ((e >@ y) >< y) =%> (e >@ (y >< y))
158 | rebracketcomptensor = (id {c=e >@ y} >< leftUnitInv {c=y})
159 |                       %>> duoidal {c=e} {d=y} {e=Scalar} {f=y}
160 |                       %>> (rightUnit {c=e} >@ id {c=(y><y)})
161 |
162 |
163 | public export
164 | distribute : {c, e, g : AddCont} ->
165 |   ((c >< e) =%> s) ->
166 |   ((c >< (e >@ g)) =%> (s >@ g))
167 | distribute f = (rightUnitInv >< id {c=e >@ g})
168 |              %>> duoidal {d = Scalar}
169 |              %>> (f >@ leftUnit)
170 |
171 | public export
172 | extractEffect : {d, e, f : AddCont} ->
173 |   (d >< (e >@ f)) =%> (e >@ (d >< f))
174 | extractEffect = (leftUnitInv >< (id {c=e >@ f}))
175 |             %>> duoidal {c=Scalar}
176 |             %>> (leftUnit >@ (id {c=d><f}))
177 |
178 |   
179 | public export
180 | swapMiddle : {c1, c2, c3, c4 : AddCont} ->
181 |   ((c1 >< c2) >< (c3 >< c4)) =%> ((c1 >< c3) >< (c2 >< c4))
182 | swapMiddle = !%+ \((x, y), (z, w)) => (((x, z), (y, w)) **
183 |   \((x', z'), (y', w')) => ((x', y'), (z', w')))
184 |
185 | public export
186 | Copy : {c : AddCont} ->
187 |   c =%> (c >< c)
188 | Copy = !%+ \x => ((x, x) ** uncurry (c.Plus x))
189 |
190 | public export
191 | PairMaps : {c, d, e : AddCont} ->
192 |   (f : c =%> d) ->
193 |   (g : c =%> e) ->
194 |   c =%> (d >< e)
195 | PairMaps f g = Copy %>> (f >< g)
196 |
197 | public export
198 | Delete : {c : AddCont} ->
199 |   c =%> Scalar
200 | Delete = !%+ \x => (() ** \() => c.Zero x)
201 |
202 | ||| Note that this doesn't exist for ordinary containers!
203 | public export
204 | ProjLeft : {c, d : AddCont} ->
205 |   (c >< d) =%> c
206 | ProjLeft = !%+ \(x, y) => (x ** \x' => (x', d.Zero y))
207 |
208 | public export
209 | ProjRight : {c, d : AddCont} ->
210 |   (c >< d) =%> d
211 | ProjRight = !%+ \(x, y) => (y ** \y' => (c.Zero x, y'))
212 |
213 | public export
214 | Sum : Num a =>
215 |   (Const a >< Const a) =%> Const a
216 | Sum = !%+ \(x1, x2) => (x1 + x2 ** \x' => (x', x'))
217 |
218 | public export
219 | bwSumList : {l : Type} -> ComMonoid l =>
220 |   (xs : List l) ->
221 |   (d' : l) ->
222 |   All (const l) xs
223 | bwSumList [] d' = []
224 | bwSumList (x :: xs) d' = x :: bwSumList xs x
225 |
226 |
227 | public export
228 | SumList : {l : Type} -> ComMonoid l =>
229 |   List (Const l) =%> Const l
230 | SumList = !%+ \xs => (sum xs ** \d' => bwSumList xs d')
231 |
232 | public export
233 | Negate : Num a => Neg a =>
234 |   Const a =%> Const a
235 | Negate = !%+ \x => (-x ** \x' => -x')
236 |
237 | public export
238 | Zero : {c : AddCont} -> Num a =>
239 |   c =%> Const a
240 | Zero = !%+ \_ => (0 ** \_ => c.Zero _)
241 |
242 | public export
243 | Mul : Num a =>
244 |   (Const a >< Const a) =%> Const a
245 | Mul = !%+ \(x1, x2) => (x1 * x2 ** \x' => (x' * x2, x' * x1))
246 |
247 | ||| Mean squared error
248 | public export
249 | SquaredDifference : {a : Type} -> Num a => Neg a => (Const a >< Const a) =%> (Const a)
250 | SquaredDifference = ((id {c=Const a}) >< Negate) %>> Sum %>> Copy %>> Mul
251 |
252 | namespace Sample
253 |   ||| Select a shape from All to produce an Any at the given index
254 |   ||| Same as `index i (allAnies shapes)` but reduces better
255 |   public export
256 |   selectShape : {cs : Vect k AddCont} ->
257 |     (shapes : All (.Shp) cs) -> (i : Fin k) -> Any (.Shp) cs
258 |   selectShape (s :: ss) FZ = Here s
259 |   selectShape (s :: ss) (FS j) = There (selectShape ss j)
260 |
261 |   ||| Extract the position from an AnyPos at a given index
262 |   public export
263 |   extractPos : {n : Nat} -> {xs : Vect n AddCont} -> {shapes : All (.Shp) xs} ->
264 |     (i : Fin n) ->
265 |     AnyShpPos (selectShape shapes i) ->
266 |     (index i xs).Pos (index i shapes)
267 |   extractPos {shapes = (_ :: _)} FZ (Here x') = x'
268 |   extractPos {shapes = (_ :: _)} (FS j) (There rest) = extractPos j rest
269 |
270 | -- parameters (f : Type -> Type)
271 | --   ||| These are all of the morphisms in the cokleisli category of (f <!> -)  
272 | --   public export
273 | --   MonLens : Cont -> Cont -> Type
274 | --   MonLens c d = (f <!> c) =%> d
275 | -- 
276 | --   public export
277 | --   counit : Monad f => f <!> c =%> c
278 | --   counit = !% \x => (x ** pure)
279 | -- 
280 | --   public export
281 | --   cojoin : Monad f => (f <!> c) =%> (f <!> (f <!> c))
282 | --   cojoin = !% \x => (x ** join)
283 |
284 |   
285 | public export
286 | record FCoAlgCont (f : Type -> Type) where
287 |   constructor MkFCoAlgCont
288 |   carrier : Cont
289 |   coalg : (: carrier.Shp) -> f (carrier.Pos a) -> carrier.Pos a
290 |
291 | public export
292 | coAlgMorphism : (c, d : FCoAlgCont f) -> Type
293 | coAlgMorphism c d = c.carrier =%> d.carrier
294 |
295 | convert : FCoAlgCont List -> AddCont
296 | convert (MkFCoAlgCont carrier coalg) = MkAddCont
297 |   carrier
298 |   {mon=(MkI $ \s => MkComMonoid
299 |     (\l, r => coalg s [l, r])
300 |     (coalg s []))}