0 | module Data.Container.Additive.Morphism.Instances
  1 |
  2 | import Data.Vect
  3 | import Data.List.Quantifiers
  4 | import Data.Vect.Quantifiers
  5 |
  6 | import Data.Container.Base
  7 | import Data.ComMonoid
  8 | import Data.Num
  9 | import Data.Container.Additive.Object.Definition
 10 | import Data.Container.Additive.Object.Instances
 11 | import Data.Container.Additive.Morphism.Definition
 12 | import Data.Container.Additive.Product.Definitions
 13 |
 14 | import Data.Container.Additive.Quantifiers
 15 |
 16 | import Control.Monad.Distribution
 17 | import Control.Monad.Sample.Definition
 18 |
 19 | import Misc
 20 |
 21 | %hide Data.Container.Base.Object.Definition.Const
 22 | %hide Data.Vect.Quantifiers.All.index
 23 |
 24 | public export
 25 | toState : {0 c : AddCont} -> (x : c.Shp) -> Scalar =%> c
 26 | toState x = !%+ \() => (x ** \_ => ())
 27 |
 28 | public export
 29 | fromState : {0 c : AddCont} -> Scalar =%> c -> c.Shp
 30 | fromState f = f.fwd ()
 31 |
 32 | public export
 33 | toCostate : {0 c : AddCont} ->
 34 |   (s : (x : c.Shp) -> c.Pos x) ->
 35 |   c =%> Scalar
 36 | toCostate s = !%+ \x => (() ** \() => s x)
 37 |
 38 | public export
 39 | fromCostate : {0 c : AddCont} ->
 40 |   c =%> Scalar ->
 41 |   (x : c.Shp) -> c.Pos x
 42 | fromCostate f x = f.bwd x ()
 43 |
 44 |
 45 | public export
 46 | pushDown : AddCont -> AddCont
 47 | pushDown d = !! (Const2 Unit d.Shp)
 48 |
 49 | public export
 50 | pushIntoContinuationList : {d, p, l : AddCont} ->
 51 |   (f : d >< p =%> l) ->
 52 |   (p =%> (pushDown d) >@ (List l))
 53 | pushIntoContinuationList f = !%+ \param => (() <|
 54 |   \ds => map (\dShp => f.fwd (dShp, param)) ds **
 55 |     \ll => sum @{UMon p param} (ll >>=
 56 |       \(ds ** grads=> extractPGrads param ds grads))
 57 |   where
 58 |     extractPGrads : (param : p.Shp) ->
 59 |       (ds : List d.Shp) ->
 60 |       All l.Pos ((\dShp => f.fwd (dShp, param)) <$> ds) ->
 61 |       List (p.Pos param)
 62 |     extractPGrads param [] [] = []
 63 |     extractPGrads param (dShp :: ds) (grad :: grads) =
 64 |       snd (f.bwd (dShp, param) grad) :: extractPGrads param ds grads
 65 |
 66 | public export
 67 | pushIntoContinuation : {p : AddCont} -> (flat : IsFlat l) => Num l.Shp =>
 68 |   (f : d >< p =%> l) ->
 69 |   (p =%> (pushDown d) >@ l)
 70 | pushIntoContinuation {flat = MkIsFlat _} f = !%+ \param => (() <|
 71 |   \ds => sum @{numIsMonoid} ((\dShp => f.fwd (dShp, param)) <$> ds) **
 72 |     \ll => sum @{UMon p param} (ll >>=
 73 |       \(ds ** grad=> (\dShp => snd (f.bwd (dShp, param) grad)) <$> ds))
 74 |
 75 | public export
 76 | constantOne : {c : AddCont} ->
 77 |   InterfaceOnPositions c Num =>
 78 |   c =%> Scalar
 79 | constantOne @{MkI @{p}} = toCostate {c=c} (\x => let numPos = p x in 1)
 80 |
 81 | namespace HancockTensorProduct
 82 |   public export
 83 |   leftUnit : {c : AddCont} -> (Scalar >< c) =%> c
 84 |   leftUnit = !%+ \((), s) => (s ** \p => ((), p))
 85 |   
 86 |   public export
 87 |   rightUnit : {c : AddCont} -> (c >< Scalar) =%> c
 88 |   rightUnit = !%+ \(x, ()) => (x ** \x' => (x', ()))
 89 |
 90 |   public export
 91 |   leftUnitInv : {c : AddCont} -> c =%> (Scalar >< c)
 92 |   leftUnitInv = !%+ \x => (((), x) ** \((), x') => x')
 93 |   
 94 |   public export
 95 |   rightUnitInv : {c : AddCont} -> c =%> (c >< Scalar)
 96 |   rightUnitInv = !%+ \x => ((x, ()) ** \(x', ()) => x')
 97 |
 98 |   public export
 99 |   assocL : {0 a, b, c : AddCont} -> ((a >< b) >< c) =%> (a >< (b >< c))
100 |   assocL = !%+ \((a, b), c) => ((a, (b, c)) ** \(a', (b', c')) => ((a', b'), c'))
101 |
102 |   public export
103 |   assocR : {0 a, b, c : AddCont} -> (a >< (b >< c)) =%> ((a >< b) >< c)
104 |   assocR = !%+ \(a, (b, c)) => (((a, b), c) ** \((a', b'), c') => (a', (b', c')))
105 |
106 |   public export
107 |   swap : {0 a, b : AddCont} -> (a >< b) =%> (b >< a)
108 |   swap = !%+ \(a, b) => ((b, a) ** \(b', a') => (a', b'))
109 |
110 | namespace CompositionProduct
111 |   public export
112 |   leftUnit : {0 c : AddCont} -> (Scalar >@ c) =%> c
113 |   leftUnit = !%+ \(() <| cShp) => (cShp () ** \c' => [(() ** c')])
114 |
115 |   public export
116 |   rightUnit : {0 c : AddCont} -> (c >@ Scalar) =%> c
117 |   rightUnit = !%+ \(s <| _) => (s ** \p => [(p ** ())])
118 |
119 |   public export
120 |   leftUnitInv : {c : AddCont} -> c =%> (Scalar >@ c)
121 |   leftUnitInv = !%+ \s => (() <| (\_ => s) ** \ll =>
122 |     sum @{UMon c s} (snd <$> ll))
123 |   
124 |   ||| Right unit inverse: c =%> c >@ I
125 |   public export
126 |   rightUnitInv : {c : AddCont} -> c =%> (c >@ Scalar)
127 |   rightUnitInv = !%+ \s => (s <| const () ** \ll =>
128 |     sum @{UMon c s} (fst <$> ll))
129 |
130 |
131 | namespace Coproduct
132 |   public export
133 |   elim : {c : AddCont} ->
134 |     (c >+< c) =%> c
135 |   elim = !%+ \case
136 |     (Left x) => (x ** id)
137 |     (Right y) => (y ** id)
138 |
139 | public export
140 | duoidal : {c, d, e, f : AddCont} ->
141 |   ((c >@ d) >< (e >@ f)) =%> ((c >< e) >@ (d >< f))
142 | duoidal = !%+ \((sc <| idxC), (se <| idxE)) =>
143 |   ((sc, se) <| \(cp, ep) => (idxC cp, idxE ep) **
144 |     \ll => ((\((cp, ep) ** (dp, fp)) => (cp ** dp)) <$> ll,
145 |             (\((cp, ep) ** (dp, fp)) => (ep ** fp)) <$> ll))
146 |
147 | public export
148 | coprodDistrOverTensor : {a, b, p, q : AddCont} ->
149 |   ((a >+< b) >< (p >< q)) =%> ((a >< p) >+< (b >< q))
150 | coprodDistrOverTensor = !%+ \case
151 |   (Left a, (p, _)) => (Left (a, p) ** \(a', p') => (a', (p', q.Zero _)))
152 |   (Right b, (_, q)) => (Right (b, q) ** \(b', q') => (b', (p.Zero _, q')))
153 |
154 | ||| Not an isomorphism, arising from duoidal structure between >@ and ><
155 | public export
156 | rebracketcomptensor: {e, y : AddCont} -> ((e >@ y) >< y) =%> (e >@ (y >< y))
157 | rebracketcomptensor = (id {c=e >@ y} >< leftUnitInv {c=y})
158 |                       %>> duoidal {c=e} {d=y} {e=Scalar} {f=y}
159 |                       %>> (rightUnit {c=e} >@ id {c=(y><y)})
160 |
161 |
162 | public export
163 | distribute : {c, e, g : AddCont} ->
164 |   ((c >< e) =%> s) ->
165 |   ((c >< (e >@ g)) =%> (s >@ g))
166 | distribute f = (rightUnitInv >< id {c=e >@ g})
167 |              %>> duoidal {d = Scalar}
168 |              %>> (f >@ leftUnit)
169 |
170 | public export
171 | extractEffect : {d, e, f : AddCont} ->
172 |   (d >< (e >@ f)) =%> (e >@ (d >< f))
173 | extractEffect = (leftUnitInv >< (id {c=e >@ f}))
174 |             %>> duoidal {c=Scalar}
175 |             %>> (leftUnit >@ (id {c=d><f}))
176 |
177 |   
178 | public export
179 | swapMiddle : {c1, c2, c3, c4 : AddCont} ->
180 |   ((c1 >< c2) >< (c3 >< c4)) =%> ((c1 >< c3) >< (c2 >< c4))
181 | swapMiddle = !%+ \((x, y), (z, w)) => (((x, z), (y, w)) **
182 |   \((x', z'), (y', w')) => ((x', y'), (z', w')))
183 |
184 | public export
185 | Copy : {c : AddCont} ->
186 |   c =%> (c >< c)
187 | Copy = !%+ \x => ((x, x) ** uncurry (c.Plus x))
188 |
189 | public export
190 | PairMaps : {c, d, e : AddCont} ->
191 |   (f : c =%> d) ->
192 |   (g : c =%> e) ->
193 |   c =%> (d >< e)
194 | PairMaps f g = Copy %>> (f >< g)
195 |
196 | public export
197 | Delete : {c : AddCont} ->
198 |   c =%> Scalar
199 | Delete = !%+ \x => (() ** \() => c.Zero x)
200 |
201 | ||| Note that this doesn't exist for ordinary containers!
202 | public export
203 | ProjLeft : {c, d : AddCont} ->
204 |   (c >< d) =%> c
205 | ProjLeft = !%+ \(x, y) => (x ** \x' => (x', d.Zero y))
206 |
207 | public export
208 | ProjRight : {c, d : AddCont} ->
209 |   (c >< d) =%> d
210 | ProjRight = !%+ \(x, y) => (y ** \y' => (c.Zero x, y'))
211 |
212 | public export
213 | Sum : Num a =>
214 |   (Const a >< Const a) =%> Const a
215 | Sum = !%+ \(x1, x2) => (x1 + x2 ** \x' => (x', x'))
216 |
217 | public export
218 | bwSumList : {l : Type} -> ComMonoid l =>
219 |   (xs : List l) ->
220 |   (d' : l) ->
221 |   All (const l) xs
222 | bwSumList [] d' = []
223 | bwSumList (x :: xs) d' = x :: bwSumList xs x
224 |
225 |
226 | public export
227 | SumList : {l : Type} -> ComMonoid l =>
228 |   List (Const l) =%> Const l
229 | SumList = !%+ \xs => (sum xs ** \d' => bwSumList xs d')
230 |
231 | public export
232 | Negate : Num a => Neg a =>
233 |   Const a =%> Const a
234 | Negate = !%+ \x => (-x ** \x' => -x')
235 |
236 | public export
237 | Zero : {c : AddCont} -> Num a =>
238 |   c =%> Const a
239 | Zero = !%+ \_ => (0 ** \_ => c.Zero _)
240 |
241 | public export
242 | Mul : Num a =>
243 |   (Const a >< Const a) =%> Const a
244 | Mul = !%+ \(x1, x2) => (x1 * x2 ** \x' => (x' * x2, x' * x1))
245 |
246 | ||| Mean squared error
247 | public export
248 | SquaredDifference : {a : Type} -> Num a => Neg a => (Const a >< Const a) =%> (Const a)
249 | SquaredDifference = ((id {c=Const a}) >< Negate) %>> Sum %>> Copy %>> Mul
250 |
251 | namespace Sample
252 |   ||| Select a shape from All to produce an Any at the given index
253 |   ||| Same as `index i (allAnies shapes)` but reduces better
254 |   public export
255 |   selectShape : {cs : Vect k AddCont} ->
256 |     (shapes : All (.Shp) cs) -> (i : Fin k) -> Any (.Shp) cs
257 |   selectShape (s :: ss) FZ = Here s
258 |   selectShape (s :: ss) (FS j) = There (selectShape ss j)
259 |
260 |   ||| Extract the position from an AnyPos at a given index
261 |   public export
262 |   extractPos : {n : Nat} -> {xs : Vect n AddCont} -> {shapes : All (.Shp) xs} ->
263 |     (i : Fin n) ->
264 |     AnyShpPos (selectShape shapes i) ->
265 |     (index i xs).Pos (index i shapes)
266 |   extractPos {shapes = (_ :: _)} FZ (Here x') = x'
267 |   extractPos {shapes = (_ :: _)} (FS j) (There rest) = extractPos j rest
268 |
269 | -- parameters (f : Type -> Type)
270 | --   ||| These are all of the morphisms in the cokleisli category of (f <!> -)  
271 | --   public export
272 | --   MonLens : Cont -> Cont -> Type
273 | --   MonLens c d = (f <!> c) =%> d
274 | -- 
275 | --   public export
276 | --   counit : Monad f => f <!> c =%> c
277 | --   counit = !% \x => (x ** pure)
278 | -- 
279 | --   public export
280 | --   cojoin : Monad f => (f <!> c) =%> (f <!> (f <!> c))
281 | --   cojoin = !% \x => (x ** join)
282 |
283 |   
284 | public export
285 | record FCoAlgCont (f : Type -> Type) where
286 |   constructor MkFCoAlgCont
287 |   carrier : Cont
288 |   coalg : (: carrier.Shp) -> f (carrier.Pos a) -> carrier.Pos a
289 |
290 | public export
291 | coAlgMorphism : (c, d : FCoAlgCont f) -> Type
292 | coAlgMorphism c d = c.carrier =%> d.carrier
293 |
294 | convert : FCoAlgCont List -> AddCont
295 | convert (MkFCoAlgCont carrier coalg) = MkAddCont
296 |   carrier
297 |   {mon=(MkI @{\s => MkComMonoid
298 |     (\l, r => coalg s [l, r])
299 |     (coalg s [])})}