0 | module Data.Container.Base.Morphism.Instances
  1 |
  2 | import Data.Fin
  3 | import Data.Fin.Split
  4 | import Data.Vect
  5 | import Data.List.Quantifiers
  6 |
  7 | import Data.Container.Base.Object.Definition
  8 | import Data.Container.Base.Morphism.Definition
  9 | import Data.Container.Base.Extension.Definition
 10 | import Data.Container.Base.Product.Definitions
 11 | import Data.Container.Base.Object.Instances
 12 |
 13 | import Data.Container.Base.Quantifiers
 14 | import Data.Container.Base.TreeUtils
 15 |
 16 | import Control.Monad.Distribution
 17 | import Control.Monad.Sample.Definition
 18 |
 19 | import Data.Num
 20 | import Data.Layout
 21 | import Misc
 22 |
 23 | ||| "State" as defined in https://arxiv.org/abs/2403.13001 and open games 
 24 | ||| Given a shape of any container, state can be defined
 25 | public export
 26 | State : Cont -> Type
 27 | State c = Scalar =%> c
 28 |
 29 | public export
 30 | Costate : Cont -> Type
 31 | Costate c = c =%> Scalar
 32 |
 33 | public export
 34 | toState : {0 c : Cont} -> (x : c.Shp) -> State c
 35 | toState x = !% \() => (x ** \_ => ())
 36 |
 37 | public export
 38 | fromState : {0 c : Cont} ->
 39 |   State c ->
 40 |   c.Shp
 41 | fromState f = f.fwd ()
 42 |
 43 | public export
 44 | fromCostate : {0 c : Cont} ->
 45 |   Costate c ->
 46 |   (x : c.Shp) -> c.Pos x
 47 | fromCostate f x = f.bwd x ()
 48 |
 49 | public export
 50 | toCostate : {0 c : Cont} ->
 51 |   ((x : c.Shp) -> c.Pos x) ->
 52 |   Costate c
 53 | toCostate s = !% \x => (() ** \() => s x)
 54 |
 55 | public export
 56 | fromNapCostateToState : {0 c : Cont} ->
 57 |   Costate (Nap c.Shp) -> State c
 58 | fromNapCostateToState f = toState (f.bwd () ())
 59 |
 60 | public export
 61 | fromStateToNapCostate : {0 c : Cont} ->
 62 |   State c -> Costate (Nap c.Shp)
 63 | fromStateToNapCostate f = toCostate f.fwd
 64 |
 65 | public export
 66 | pushDown : Cont -> Cont
 67 | pushDown c = Const2 Unit c.Shp
 68 |
 69 | public export
 70 | pushIntoContinuation : {d, p, l : Cont} ->
 71 |   (d >< p =%> l) ->
 72 |   (p =%> (pushDown d) >@ l)
 73 | pushIntoContinuation f = !% \p => (() <| \d => f.fwd (d, p) **
 74 |   \(d ** l'=> snd $ f.bwd (d, p) l')
 75 |
 76 |
 77 | namespace HancockTensorProduct
 78 |   public export
 79 |   leftUnit : (Scalar >< c) =%> c
 80 |   leftUnit = !% \((), s) => (s ** \p => ((), p))
 81 |   
 82 |   public export
 83 |   rightUnit : (c >< Scalar) =%> c
 84 |   rightUnit = !% \(x, ()) => (x ** \x' => (x', ()))
 85 |
 86 |   public export
 87 |   leftUnitInv : c =%> (Scalar >< c)
 88 |   leftUnitInv = !% \x => (((), x) ** \((), x') => x')
 89 |
 90 |   public export
 91 |   rightUnitInv : c =%> (c >< Scalar)
 92 |   rightUnitInv = !% \x => ((x, ()) ** \(x', ()) => x')
 93 |
 94 |   public export
 95 |   assocL : ((a >< b) >< c) =%> (a >< (b >< c))
 96 |   assocL = !% \((a, b), c) => ((a, (b, c)) ** \(a', (b', c')) => ((a', b'), c'))
 97 |
 98 |   public export
 99 |   assocR : (a >< (b >< c)) =%> ((a >< b) >< c)
100 |   assocR = !% \(a, (b, c)) => (((a, b), c) ** \((a', b'), c') => (a', (b', c')))
101 |
102 |   public export
103 |   swap : (a >< b) =%> (b >< a)
104 |   swap = !% \(a, b) => ((b, a) ** \(b', a') => (a', b'))
105 |
106 | namespace CompositionProduct
107 |   public export
108 |   leftUnit : (Scalar >@ c) =%> c
109 |   leftUnit = !% \(() <| cShp) => (cShp () ** \c' => (() ** c'))
110 |
111 |   public export
112 |   rightUnit : (c >@ Scalar) =%> c
113 |   rightUnit = !% \(s <| _) => (s ** \cp => (cp ** ()))
114 |
115 |   public export
116 |   leftUnitInv : c =%> (Scalar >@ c)
117 |   leftUnitInv = !% \x => (() <| (\_ => x) ** \(() ** c') => c')
118 |   
119 |   public export
120 |   rightUnitInv : c =%> (c >@ Scalar)
121 |   rightUnitInv = !% \s => (s <| const () ** fst)
122 |
123 | namespace Coproduct
124 |   public export
125 |   elim : {c : Cont} ->
126 |     (c >+< c) =%> c
127 |   elim = !% \case
128 |     (Left x) => (x ** id)
129 |     (Right y) => (y ** id)
130 |
131 |
132 | ||| Interaction between composition and tensor product
133 | public export
134 | duoidal : ((c >@ d) >< (e >@ f)) =%> ((c >< e) >@ (d >< f))
135 | duoidal = !% \((sc <| idxC), (se <| idxE)) =>
136 |   ((sc, se) <| \(cp, ep) => (idxC cp, idxE ep) **
137 |     \((cp, ep) ** (dp, fp)) => ((cp ** dp), (ep ** fp)))
138 |
139 | ||| Specific distributive law we need
140 | public export
141 | distribute : ((c >< e) =%> s) ->
142 |   ((c >< (e >@ g)) =%> (s >@ g))
143 | distribute f = (rightUnitInv >< id {a=e >@ g})
144 |              %>> duoidal {d = Scalar}
145 |              %>> (f >@ leftUnit)
146 |
147 | ||| Ext is a functor of type Cont -> [Type, Type]
148 | ||| On objects it maps a container to a polynomial functor
149 | ||| On morphisms it maps a dependent lens to a natural transformation
150 | ||| This is the action on morphisms
151 | public export
152 | extMap : {0 c, d : Cont} ->
153 |   c =%> d ->
154 |   Ext c a -> Ext d a
155 | extMap f (sh <| index) = let (y ** ky= (%!) f sh
156 |                          in y <| (index . ky)
157 |
158 |
159 | ||| Wraps a dependent lens `c =%> d`
160 | ||| into one of type `c >@ Scalar =%> d >@ Scalar`
161 | ||| Needed because `c >@ Scalar` isn't automatically reduced to `c`
162 | public export
163 | wrapIntoVector : {c, d : Cont} ->
164 |   c =%> d ->
165 |   Tensor [c] =%> Tensor [d]
166 | wrapIntoVector (!% f) =
167 |   !% \e => let (y ** ky= f (shapeExt e)
168 |            in (y <| \_ => () ** \(cp ** ()=> (ky cp ** ()))
169 |
170 | ||| Wraps a dependent lens `c =%> d`
171 | ||| into one of type `c >< Scalar =%> d >< Scalar`
172 | ||| Needed because `c >< Scalar` isn't automatically reduced to `c`
173 | public export
174 | wrapIntoVectorHancock : {c, d : Cont} ->
175 |   c =%> d ->
176 |   HancockTensor [c] =%> HancockTensor [d]
177 | wrapIntoVectorHancock f = !% \(x, ()) =>
178 |   ((f.fwd x, ()) ** \(y', ()) => (f.bwd x y', ())
179 |
180 | namespace CubicalHelpers
181 |   ||| Helper function allowing `shape` in `cubicalShape` to have zero annotation
182 |   public export
183 |   cubicalShapeHelper : All IsCubical shape -> List Nat
184 |   cubicalShapeHelper [] = []
185 |   cubicalShapeHelper (ic :: ics) = dimHelper ic :: cubicalShapeHelper ics
186 |     
187 |   ||| Given a list of cubical containers, return the list of their dimensions
188 |   public export
189 |   cubicalShape : (0 shape : List Cont) -> All IsCubical shape => List Nat
190 |   cubicalShape _ @{ac} = cubicalShapeHelper ac
191 |     
192 |   ||| Size of a list of cubical containers is the product of their dimensions
193 |   public export
194 |   size : (0 shape : List Cont) -> All IsCubical shape => Nat
195 |   size shape = prod (cubicalShape shape)
196 |
197 | ||| Layout-aware dependent lens flattening a cubical tensor
198 | public export
199 | flattenCubical : {shape : List Cont} ->
200 |   (ac : All IsCubical shape) =>
201 |   LayoutOrder ->
202 |   Tensor shape =%> Vect (size shape)
203 | flattenCubical {shape = [], ac=[]} _ = !% \() => (() ** \FZ => ())
204 | flattenCubical {shape = (_ :: ss), ac=(MkIsCubical n :: as)} lo
205 |   = !% \(() <| t) => (() ** \idx =>
206 |       let (!% recBackward) = flattenCubical {shape = ss} lo
207 |           (i, rest) = splitFinProd lo idx
208 |           (_ ** backRec= recBackward (t i)
209 |       in (i ** backRec rest))
210 |
211 | ||| Layout-aware dependent lens unflattening a tensor
212 | public export
213 | unflattenCubical : {shape : List Cont} ->
214 |   (ac : All IsCubical shape) =>
215 |   LayoutOrder ->
216 |   Vect (size shape) =%> Tensor shape
217 | unflattenCubical {shape = [], ac=[]} lo = !% \() => (() ** \() => FZ)
218 | unflattenCubical {shape = (_ :: ss), ac=((MkIsCubical n) :: as)} lo =
219 |   let (!% f) = unflattenCubical {shape = ss} lo
220 |       (innerShape ** innerBack= f ()
221 |   in !% \() => ((() <| \_ => innerShape) ** (\(cp ** restPos=>
222 |     indexFinProd lo cp (innerBack restPos)))
223 |
224 | ||| This is simply a rewrite!
225 | public export
226 | recastFlattenedTensor : {oldShape, newShape : List Cont} ->
227 |   (oldAc : All IsCubical oldShape) => (newAc : All IsCubical newShape) =>
228 |   {auto prf : size oldShape = size newShape} ->
229 |   Vect (size oldShape) =%> Vect (size newShape)
230 | recastFlattenedTensor = !% \() => (() ** \i => rewrite prf in i)
231 |
232 | ||| Reshapes a cubical tensor by first flattening it to a linear representation,
233 | ||| casting the type to the new shape, and then unflattening it back
234 | ||| Is generic over layout order
235 | public export
236 | reshape : {oldShape, newShape : List Cont} ->
237 |   (oldAc : All IsCubical oldShape) => (newAc : All IsCubical newShape) =>
238 |   LayoutOrder ->
239 |   {auto prf : size oldShape = size newShape} ->
240 |   Tensor oldShape =%> Tensor newShape
241 | reshape lo = flattenCubical lo
242 |          %>> recastFlattenedTensor
243 |          %>> unflattenCubical lo
244 |
245 |
246 | namespace Transpose
247 |   public export
248 |   transposeLens : IsNaperian c => IsNaperian d => (c >@ d) =%> (d >@ c)
249 |   transposeLens @{MkIsNaperian _} @{MkIsNaperian _} = !% \(() <| _) =>
250 |     (() <| (\_ => ()) ** \(dInd ** cInd=> (cInd ** dInd))
251 |
252 |   ||| This and the above function should be one and the same, up to rebracketing
253 |   public export
254 |   transpose : IsNaperian c => IsNaperian d =>
255 |     Tensor [c, d] =%> Tensor [d, c]
256 |   transpose @{MkIsNaperian _} @{MkIsNaperian _} = !% \(() <| _) =>
257 |     (() <| (\_ => () <| (\_ => ())) ** \(dInd ** cInd ** ()=>
258 |       (cInd ** (dInd ** ())))
259 |
260 | ||| Functionality for transforming a tensor into a hancock tensor
261 | namespace TransformIntoHancockTensor
262 |   public export
263 |   hancockTensorNaperianShape : {shape : List Cont} ->
264 |     (allNap : All IsNaperian shape) =>
265 |     (HancockTensor shape).Shp
266 |   hancockTensorNaperianShape {shape = []} = ()
267 |   hancockTensorNaperianShape {allNap = ((MkIsNaperian _) :: _)}
268 |     = ((), hancockTensorNaperianShape)
269 |   
270 |   ||| Helper to compute the unique shape of Tensor when all containers are Naperian
271 |   public export
272 |   tensorNaperianShape : {shape : List Cont} ->
273 |     (allNap : All IsNaperian shape) =>
274 |     (Tensor shape).Shp
275 |   tensorNaperianShape {shape = []} = ()
276 |   tensorNaperianShape {shape = (_ :: ss), allNap = ((MkIsNaperian _) :: ns)}
277 |     = () <| \_ => tensorNaperianShape {shape = ss} @{ns}
278 |   
279 |   ||| Analogous to `naperianPosEq` but for the HancockTensor structure
280 |   ||| We can't use `naperianPosEq` directly because the shape of the resulting
281 |   ||| container is not Unit, it is only isomorphic to it
282 |   public export
283 |   hancockTensorPosEq : {shape : List Cont} ->
284 |     (allNap : All IsNaperian shape) =>
285 |     {0 x, y : (HancockTensor shape).Shp} ->
286 |     (HancockTensor shape).Pos x = (HancockTensor shape).Pos y
287 |   hancockTensorPosEq {allNap = []} = Refl
288 |   hancockTensorPosEq {allNap = ((MkIsNaperian _) :: _)} = cong2 Pair
289 |     (naperianPosEq @{MkIsNaperian _} {x=()} {y=()})
290 |     hancockTensorPosEq
291 |   
292 |   ||| Tensor shape is isomorphic to HancockTensor shape when all containers in
293 |   ||| the shape are Naperian. This is one arrow of that isomorphism
294 |   public export
295 |   transformToHancock : {shape : List Cont} ->
296 |     All IsNaperian shape =>
297 |     Tensor shape =%> HancockTensor shape
298 |   transformToHancock {shape = []} = id
299 |   transformToHancock {shape = (_ :: _)} @{((MkIsNaperian _) :: _)}
300 |     = !% \(() <| content) => (((), hancockTensorNaperianShape) **
301 |        \(p, restPos) =>
302 |          let (_ ** recBack= (%!) transformToHancock (content p)
303 |          in (p ** recBack $ replace {p = id} hancockTensorPosEq restPos))
304 |
305 |
306 | public export
307 | EmptyExtEq : {0 c : Cont} -> IsNaperian c => Ext c Unit = Unit
308 | EmptyExtEq @{(MkIsNaperian pos)} = believe_me () -- what does wrong if we do this
309 |
310 |     
311 |
312 |   -- ||| Technically this is Unit, but hard to prove
313 |   -- public export
314 |   -- foldOverNaperianShapeComp : {shape : List Cont} ->
315 |   --   (allNap : All IsNaperian shape) =>
316 |   --   (Tensor shape).Shp
317 |   -- foldOverNaperianShapeComp {shape = []} = ()
318 |   -- foldOverNaperianShapeComp {allNap = ((MkIsNaperian pos) :: ns)}
319 |   --   = () <| \_ => foldOverNaperianShapeComp
320 |   -- 
321 |   -- public export
322 |   -- naperianHancockShape : {shape : List Cont} ->
323 |   --   (allNap : All IsNaperian shape) =>
324 |   --   (HancockTensor shape).Shp = Unit
325 |   -- naperianHancockShape = believe_me ()
326 |   -- 
327 |   -- public export
328 |   -- foldOverNaperianShapeHancock : {shape : List Cont} ->
329 |   --   (allNap : All IsNaperian shape) =>
330 |   --   (HancockTensor shape).Shp
331 |   -- foldOverNaperianShapeHancock {shape = []} = ()
332 |   -- foldOverNaperianShapeHancock {allNap = ((MkIsNaperian _) :: _)}
333 |   --   = ((), foldOverNaperianShapeHancock)
334 |
335 |
336 | -- public export
337 | -- tensorIsNaperianShape : {shape : List Cont} ->
338 | --   (allNap : All IsNaperian shape) =>
339 | --   IsNaperian (Tensor shape)
340 | -- tensorIsNaperianShape {shape = []} = MkIsNaperian ()
341 | -- tensorIsNaperianShape {shape = (_ :: ss), allNap = ((MkIsNaperian pos) :: ns)}
342 | --   = let tg = tensorIsNaperianShape {shape = ss} 
343 | --     in ?tensorIsNaperianShape_rhs_1
344 | --     --in rewrite naperianShpEq @{tg}
345 | --     --in (rewrite (EmptyExtEq {c=(Nap pos)})
346 | --     --in let tg = MkIsNaperian in ?tensorIsNaperianShape_rhs_2)
347 |
348 | -- public export
349 | -- transformToHancock : {shape : List Cont} ->
350 | --   All IsNaperian shape =>
351 | --   Tensor shape =%> HancockTensor shape
352 | -- transformToHancock {shape = []} = id
353 | -- transformToHancock {shape = (_ :: ss)} @{((MkIsNaperian pos) :: ns)}
354 | --   = let f = (%!) (transformToHancock {shape = ss} @{ns})
355 | --         (_ ** h) = f (foldOverNaperianShapeComp {shape=ss})
356 | --     in !% \(() <| content) => (((), foldOverNaperianShapeHancock) **
357 | --       \(p, fld) => (p ** ?hhh))
358 | --       -- (((), rewrite -- foldOverNaperianShapeHancock {shape=ss} @{ns} in ()) **
359 | --     --   \(p, fld) => (p ** ?bnn))
360 |
361 | -- need to organise this
362 | namespace BinTree
363 |   public export
364 |   inorderBackward : (b : BinTreeShape) ->
365 |     Fin (numNodesAndLeaves b) ->
366 |     BinTreePos b
367 |   inorderBackward LeafS FZ = AtLeaf
368 |   inorderBackward (NodeS lt rt) n with (strengthenN {m=numNodesAndLeaves lt} n)
369 |      _ | Left p = GoLeft (inorderBackward lt p)
370 |      _ | Right FZ = AtNode
371 |      _ | Right (FS g) = GoRight (inorderBackward rt g)
372 |
373 |
374 |   public export
375 |   inorder : BinTree =%> List
376 |   inorder = !% \b => (numNodesAndLeaves b ** inorderBackward b)
377 |
378 | namespace BinTreeNode
379 |   public export
380 |   inorderBackward : (b : BinTreeShape) ->
381 |     Fin (numNodes b) ->
382 |     BinTreePosNode b
383 |   inorderBackward (NodeS lt rt) n with (strengthenN {m=numNodes lt} n)
384 |     _ | Left p = GoLeft (inorderBackward lt p)
385 |     _ | Right FZ = AtNode
386 |     _ | Right (FS g) = GoRight (inorderBackward rt g)
387 |
388 |   ||| Traverses a binary tree container in order, producing a list container
389 |   public export
390 |   inorder : BinTreeNode =%> List
391 |   inorder = !% \b => (numNodes b ** inorderBackward b)
392 |
393 |   -- Need to do some rewriting for preorder
394 |   public export
395 |   preorderBinTreeNode : (b : BinTreeShape) ->
396 |     Fin (numNodes b) -> BinTreePosNode b
397 |   preorderBinTreeNode (NodeS lt rt) x = ?preorderBinTreeNode_rhs_1
398 |   --preorderBinTreeNode (NodeS lt rt) n with (strengthenN {m=numNodes lt} n)
399 |   --  _ | Left p = ?whl
400 |   --  _ | Right FZ = ?whn
401 |   --  _ | Right (FS g) = ?whr
402 |
403 | -- public export
404 | -- traverseLeaf : (x : BinTreeShape) -> FinBinTreeLeaf x -> Fin (numLeaves x)
405 | -- traverseLeaf LeafS Done = FZ
406 | -- traverseLeaf (NodeS lt rt) (GoLeft x) = weakenN (numLeaves rt) (traverseLeaf lt x)
407 | -- traverseLeaf (NodeS lt rt) (GoRight x) = shift (numLeaves lt) (traverseLeaf rt x)
408 | -- 
409 |
410 | public export
411 | maybeToList : Maybe =%> List
412 | maybeToList = !% \b => case b of 
413 |   False => (0 ** absurd)
414 |   True => (1 ** \_ => ())
415 |
416 | public export
417 | Sample : MonadSample m => {n : Nat} -> IsSucc n =>
418 |   (m <!> Sample n) =%> Scalar
419 | Sample = toCostate sample
420 |
421 | -- TODO here maybe need to uncomment during merge?
422 | -- public export
423 | -- selectShape : {cs : Vect k Cont} ->
424 | --   (shapes : All Shp cs) -> (i : Fin k) -> Any Shp cs
425 | -- selectShape (s :: _) FZ = Here s
426 | -- selectShape (_ :: ss) (FS j) = There (selectShape ss j)
427 | -- 
428 | -- ||| Extract the position from an AnyPos at a given index
429 | -- public export
430 | -- extractPos : {n : Nat} -> {xs : Vect n Cont} ->
431 | --   {shapes : All Shp xs} ->
432 | --   (i : Fin n) ->
433 | --   AnyShpPos (selectShape shapes i) ->
434 | --   AnyPos shapes
435 | -- extractPos {shapes = (_ :: _)} FZ (Here x) = Here x
436 | -- extractPos {shapes = (_ :: _)} (FS j) (There rest)
437 | --   = There $ extractPos j rest
438 | -- 
439 | -- public export
440 | -- SampleAndChoose : {n : Nat} -> {xs : Vect n Cont} ->
441 | --   ConvexComb xs =%> (Sample n >@ Any xs)
442 | -- SampleAndChoose = !% \(d, shapes) =>
443 | --   (d <| selectShape shapes ** \(i ** grad) => (0, [extractPos i grad]))
444 |
445 | -- SampleAndChooseWithDist = !% \(d, shapes) =>
446 | --   (d <| electShape shapes ** \(i ** grad) => (0, [(i ** extractPos i grad)]))
447 |
448 | -- public export
449 | -- GetDist : {n : Nat} -> {xs : Vect n Cont} ->
450 | --   ConvexComb xs =%> Simplex n
451 | -- GetDist = !% \(d, shapes) => (d ** \d' => (d', ?GetDist_rhs))
452 |
453 | public export
454 | handleEffect : Monad m =>
455 |   (handler : (m <!> effect) =%> Scalar) ->
456 |   (program : a =%> effect) ->
457 |   m <!> a =%> Scalar
458 | handleEffect handler program = !% \x =>
459 |   let (ef ** nn= (%! program) x
460 |       (() ** rest= (%! handler) ef
461 |   in (() ** \() => do 
462 |     e <- rest ()
463 |     pure (nn e))
464 |