0 | module Data.Container.Additive.Morphism.Instances
  1 |
  2 | import Data.Vect
  3 | import Data.List.Quantifiers
  4 | import Data.Vect.Quantifiers
  5 |
  6 | import Data.Container.Base
  7 | import Data.ComMonoid
  8 | import Data.Num
  9 | import Data.Container.Additive.Object.Definition
 10 | import Data.Container.Additive.Object.Instances
 11 | import Data.Container.Additive.Morphism.Definition
 12 | import Data.Container.Additive.Product.Definitions
 13 | import Data.Container.Additive.Properties.Definitions
 14 |
 15 | import Data.Container.Additive.Quantifiers
 16 |
 17 | import Control.Monad.Distribution
 18 | import Control.Monad.Sample.Definition
 19 |
 20 | import Misc
 21 |
 22 | %hide Base.Object.Instances.Const
 23 | %hide Data.Vect.Quantifiers.All.index
 24 | %hide Base.Morphism.Definition.DependentLenses.(=%>)
 25 | %hide Base.Morphism.Instances.State.State
 26 | %hide Base.Morphism.Instances.Costate.Costate
 27 | %hide Base.Product.Definitions.HancockTensorProduct.(><)
 28 |
 29 | namespace State
 30 |   ||| "State" as defined in https://arxiv.org/abs/2403.13001 and open games 
 31 |   |||
 32 |   |||       ┌─────────────┐
 33 |   |||       │             ├──► (x : c.Shp)
 34 |   |||       │    State    │
 35 |   |||       │             ├◄── c.Pos x
 36 |   |||       └─────────────┘
 37 |   public export
 38 |   State : AddCont -> Type
 39 |   State c = Scalar =%> c
 40 |
 41 |   public export
 42 |   toState : (x : c.Shp) -> State c
 43 |   toState x = !% toState x
 44 |   
 45 |   public export
 46 |   fromState : State c -> c.Shp
 47 |   fromState f = f.fwd ()
 48 |
 49 | namespace Costate
 50 |   ||| "Costate" as defined in https://arxiv.org/abs/2403.13001 and open games 
 51 |   |||
 52 |   |||                  ┌─────────────┐
 53 |   |||  (x : c.Shp)  ──►┤             │
 54 |   |||                  │   Costate   │
 55 |   |||     c.Pos x   ◄──┤             │
 56 |   |||                  └─────────────┘
 57 |   public export
 58 |   Costate : AddCont -> Type
 59 |   Costate c = c =%> Scalar
 60 |   
 61 |   public export
 62 |   toCostate : ((x : c.Shp) -> c.Pos x) -> Costate c
 63 |   toCostate s = !% toCostate s
 64 |   
 65 |   public export
 66 |   fromCostate : Costate c -> (x : c.Shp) -> c.Pos x
 67 |   fromCostate f x = f.bwd x ()
 68 |
 69 |   public export
 70 |   constantOne : InterfaceOnPositions c Num => Costate c
 71 |   constantOne @{MkI p} = toCostate (\x => let numPos = p x in 1)
 72 |
 73 |   public export
 74 |   Delete : {c : AddCont} -> Costate c 
 75 |   Delete = toCostate c.Zero
 76 |   
 77 |
 78 | ||| If we model the idea of a container (S !> P) as a box
 79 | |||  ┌──────┐
 80 | |||  │ s:S  │
 81 | |||  ├──────┤
 82 | |||  │  Ps  │
 83 | |||  └──────┘
 84 | ||| then `pushDown` is interpreted as pushing down the container,
 85 | ||| pruning anything that goes out of the box, and using `Unit` for
 86 | ||| anything new that appears:
 87 | |||  ┌──────┐
 88 | |||  │ Unit │
 89 | |||  ├──────┤
 90 | |||  │ s:S  │
 91 | |||  └──────┘
 92 | |||     Ps
 93 | ||| For additive containers we need to take the free monoid
 94 | public export
 95 | pushDown : AddCont -> AddCont
 96 | pushDown c = !! pushDown (UC c)
 97 |
 98 | public export
 99 | pushIntoContinuationList : {p : AddCont} -> {0 d, l : AddCont} ->
100 |   d >< p =%> l ->
101 |   p =%> (pushDown d) >@ (List l)
102 | pushIntoContinuationList f = !%+ \param => (() <|
103 |   \ds => ds <&> (\dShp => f.fwd (dShp, param)) **
104 |     \ll => sum @{UMon p param} (ll >>=
105 |       \(ds ** grads=> extractPGrads param ds grads))
106 |   where
107 |     extractPGrads : (param : p.Shp) ->
108 |       (ds : List d.Shp) ->
109 |       All l.Pos ((\dShp => f.fwd (dShp, param)) <$> ds) ->
110 |       List (p.Pos param)
111 |     extractPGrads param [] [] = []
112 |     extractPGrads param (dShp :: ds) (grad :: grads) =
113 |       snd (f.bwd (dShp, param) grad) :: extractPGrads param ds grads
114 |
115 | public export
116 | pushIntoContinuation : {p : AddCont} -> (flat : IsFlat l) => Num l.Shp =>
117 |   (f : d >< p =%> l) ->
118 |   (p =%> (pushDown d) >@ l)
119 | pushIntoContinuation {flat = MkIsFlat _} f = !%+ \param => (() <|
120 |   \ds => sum @{numIsMonoid} ((\dShp => f.fwd (dShp, param)) <$> ds) **
121 |     \ll => sum @{UMon p param} (ll >>=
122 |       \(ds ** grad=> (\dShp => snd (f.bwd (dShp, param) grad)) <$> ds))
123 |
124 | ||| This is also the categorical product since our containers are additive
125 | namespace HancockTensorProduct
126 |   public export
127 |   leftUnit : Scalar >< c =%> c
128 |   leftUnit = !% leftUnit
129 |   
130 |   public export
131 |   rightUnit : c >< Scalar =%> c
132 |   rightUnit = !% rightUnit
133 |
134 |   public export
135 |   leftUnitInv : c =%> Scalar >< c
136 |   leftUnitInv = !% leftUnitInv
137 |   
138 |   public export
139 |   rightUnitInv : c =%> c >< Scalar
140 |   rightUnitInv = !% rightUnitInv
141 |
142 |   public export
143 |   assocL : (a >< b) >< c =%> a >< (b >< c)
144 |   assocL = !% assocL
145 |
146 |   public export
147 |   assocR : a >< (b >< c) =%> (a >< b) >< c
148 |   assocR = !% assocR
149 |
150 |   public export
151 |   swap : a >< b =%> b >< a
152 |   swap = !% swap
153 |
154 |   public export
155 |   swapMiddle : (c1 >< c2) >< (c3 >< c4) =%> (c1 >< c3) >< (c2 >< c4)
156 |   swapMiddle = !% swapMiddle
157 |
158 |   ||| These do not exist for ordinary containers!
159 |   ||| Here we need `c` not to be erased since we're using its monoid structure
160 |   public export
161 |   Copy : {c : AddCont} -> c =%> c >< c
162 |   Copy = !%+ \x => ((x, x) ** uncurry (c.Plus x))
163 |   
164 |   public export
165 |   PairMaps : {c : AddCont} ->
166 |     c =%> d ->
167 |     c =%> e ->
168 |     c =%> d >< e
169 |   PairMaps f g = Copy %>> (f >< g)
170 |   
171 |   public export
172 |   ProjLeft : {d : AddCont} -> c >< d =%> c
173 |   ProjLeft = !%+ \(x, y) => (x ** \x' => (x', d.Zero y))
174 |   
175 |   public export
176 |   ProjRight : {c : AddCont} -> c >< d =%> d
177 |   ProjRight = !%+ \(x, y) => (y ** \y' => (c.Zero x, y'))
178 |
179 |
180 | namespace CompositionProduct
181 |   public export
182 |   leftUnit : Scalar >@ c =%> c
183 |   leftUnit = !% pureBw %>> leftUnit
184 |
185 |   public export
186 |   rightUnit : c >@ Scalar =%> c
187 |   rightUnit = !% pureBw %>> rightUnit
188 |
189 |   public export
190 |   leftUnitInv : {c : AddCont} -> c =%> Scalar >@ c
191 |   leftUnitInv = !%+ \s => (() <| (\_ => s) ** \ll => 
192 |     sum @{UMon c s} (snd <$> ll))
193 |   -- leftUnitInv {c=MkAddCont uc} = (!% CompositionProduct.leftUnitInv) %>> ?eiei
194 |   
195 |   ||| Right unit inverse: c =%> c >@ I
196 |   public export
197 |   rightUnitInv : {c : AddCont} -> c =%> (c >@ Scalar)
198 |   rightUnitInv = !%+ \s => (s <| const () ** \ll =>
199 |     sum @{UMon c s} (fst <$> ll))
200 |
201 |
202 | namespace Coproduct
203 |   public export
204 |   elim : c >+< c =%> c
205 |   elim = !% elim
206 |
207 | public export
208 | duoidal : (c >@ d) >< (e >@ f) =%> (c >< e) >@ (d >< f)
209 | duoidal = !%+ \((sc <| idxC), (se <| idxE)) =>
210 |   ((sc, se) <| \(cp, ep) => (idxC cp, idxE ep) **
211 |     \ll => ((\((cp, ep) ** (dp, fp)) => (cp ** dp)) <$> ll,
212 |             (\((cp, ep) ** (dp, fp)) => (ep ** fp)) <$> ll))
213 |
214 | public export
215 | coprodDistrOverTensor : {q, p : AddCont} ->
216 |   (a >+< b) >< (p >< q) =%> (a >< p) >+< (b >< q)
217 | coprodDistrOverTensor = !%+ \case
218 |   (Left a, (p, _)) => (Left (a, p) ** \(a', p') => (a', (p', q.Zero _)))
219 |   (Right b, (_, q)) => (Right (b, q) ** \(b', q') => (b', (p.Zero _, q')))
220 |
221 | ||| Not an isomorphism, arising from duoidal structure between >@ and ><
222 | public export
223 | rebracketcomptensor: {y : AddCont} -> (e >@ y) >< y =%> e >@ (y >< y)
224 | rebracketcomptensor = (id {c=e >@ y} >< leftUnitInv {c=y})
225 |                       %>> duoidal {c=e} {d=y} {e=Scalar} {f=y}
226 |                       %>> (rightUnit {c=e} >@ id {c=(y><y)})
227 |
228 |
229 | public export
230 | distribute : {c : AddCont} ->
231 |   c >< e =%> s ->
232 |   c >< (e >@ g) =%> s >@ g
233 | distribute f = (rightUnitInv >< id {c=e >@ g})
234 |              %>> duoidal {d = Scalar}
235 |              %>> (f >@ leftUnit)
236 |
237 | public export
238 | extractEffect : {d : AddCont} ->
239 |   d >< (e >@ f) =%> e >@ (d >< f)
240 | extractEffect = (leftUnitInv >< (id {c=e >@ f}))
241 |             %>> duoidal {c=Scalar}
242 |             %>> (leftUnit >@ (id {c=d><f}))
243 |
244 |   
245 | public export
246 | Sum : Num a =>
247 |   (Const a >< Const a) =%> Const a
248 | Sum = !%+ \(x1, x2) => (x1 + x2 ** \x' => (x', x'))
249 |
250 | public export
251 | bwSumList : {l : Type} -> ComMonoid l =>
252 |   (xs : List l) ->
253 |   (d' : l) ->
254 |   All (const l) xs
255 | bwSumList [] d' = []
256 | bwSumList (x :: xs) d' = x :: bwSumList xs x
257 |
258 |
259 | public export
260 | SumList : {l : Type} -> ComMonoid l =>
261 |   List (Const l) =%> Const l
262 | SumList = !%+ \xs => (sum xs ** \d' => bwSumList xs d')
263 |
264 | public export
265 | Negate : Num a => Neg a =>
266 |   Const a =%> Const a
267 | Negate = !%+ \x => (-x ** \x' => -x')
268 |
269 | public export
270 | Zero : {c : AddCont} -> Num a =>
271 |   c =%> Const a
272 | Zero = !%+ \_ => (0 ** \_ => c.Zero _)
273 |
274 | public export
275 | Mul : Num a =>
276 |   (Const a >< Const a) =%> Const a
277 | Mul = !%+ \(x1, x2) => (x1 * x2 ** \x' => (x' * x2, x' * x1))
278 |
279 | ||| Mean squared error
280 | public export
281 | SquaredDifference : {a : Type} -> Num a => Neg a => (Const a >< Const a) =%> (Const a)
282 | SquaredDifference = ((id {c=Const a}) >< Negate) %>> Sum %>> Copy %>> Mul
283 |
284 | namespace Sample
285 |   ||| Select a shape from All to produce an Any at the given index
286 |   ||| Same as `index i (allAnies shapes)` but reduces better
287 |   public export
288 |   selectShape : {cs : Vect k AddCont} ->
289 |     (shapes : All (.Shp) cs) -> (i : Fin k) -> Any (.Shp) cs
290 |   selectShape (s :: ss) FZ = Here s
291 |   selectShape (s :: ss) (FS j) = There (selectShape ss j)
292 |
293 |   ||| Extract the position from an AnyPos at a given index
294 |   public export
295 |   extractPos : {n : Nat} -> {xs : Vect n AddCont} -> {shapes : All (.Shp) xs} ->
296 |     (i : Fin n) ->
297 |     AnyShpPos (selectShape shapes i) ->
298 |     (index i xs).Pos (index i shapes)
299 |   extractPos {shapes = (_ :: _)} FZ (Here x') = x'
300 |   extractPos {shapes = (_ :: _)} (FS j) (There rest) = extractPos j rest
301 |
302 | -- parameters (f : Type -> Type)
303 | --   ||| These are all of the morphisms in the cokleisli category of (f <!> -)  
304 | --   public export
305 | --   MonLens : Cont -> Cont -> Type
306 | --   MonLens c d = (f <!> c) =%> d
307 | -- 
308 | --   public export
309 | --   counit : Monad f => f <!> c =%> c
310 | --   counit = !% \x => (x ** pure)
311 | -- 
312 | --   public export
313 | --   cojoin : Monad f => (f <!> c) =%> (f <!> (f <!> c))
314 | --   cojoin = !% \x => (x ** join)
315 |
316 |   
317 | -- public export
318 | -- record FCoAlgCont (f : Type -> Type) where
319 | --   constructor MkFCoAlgCont
320 | --   carrier : Cont
321 | --   coalg : (a : carrier.Shp) -> f (carrier.Pos a) -> carrier.Pos a
322 |
323 | -- public export
324 | -- coAlgMorphism : (c, d : FCoAlgCont f) -> Type
325 | -- coAlgMorphism c d = c.carrier =%> d.carrier
326 | -- 
327 | -- convert : FCoAlgCont List -> AddCont
328 | -- convert (MkFCoAlgCont carrier coalg) = MkAddCont
329 | --   carrier
330 | --   {mon=(MkI $ \s => MkComMonoid
331 | --     (\l, r => coalg s [l, r])
332 | --     (coalg s []))}