0 | module Data.Container.Base.Morphism.Instances
  1 |
  2 | import Data.Fin
  3 | import Data.Fin.Split
  4 | import Data.Vect
  5 | import Data.List.Elem
  6 | import Data.List.Quantifiers
  7 |
  8 | import Data.Container.Base.Object.Definition
  9 | import Data.Container.Base.Morphism.Definition
 10 | import Data.Container.Base.Extension.Definition
 11 | import Data.Container.Base.Properties.Definitions
 12 | import Data.Container.Base.Product.Definitions
 13 |
 14 | import Data.Container.Base.Object.Instances
 15 |
 16 | import Data.Container.Base.Quantifiers
 17 | import Data.Container.Base.TreeUtils
 18 |
 19 | import Control.Monad.Distribution
 20 | import Control.Monad.Sample.Definition
 21 |
 22 | import Data.Num
 23 | import Data.Layout
 24 | import Misc
 25 |
 26 | namespace State
 27 |   ||| "State" as defined in https://arxiv.org/abs/2403.13001 and open games 
 28 |   ||| Given a shape of any container, state can be defined
 29 |   public export
 30 |   State : Cont -> Type
 31 |   State c = Scalar =%> c
 32 |
 33 |   public export
 34 |   toState : {0 c : Cont} -> (x : c.Shp) -> State c
 35 |   toState x = !% \() => (x ** \_ => ())
 36 |
 37 |   public export
 38 |   fromState : {0 c : Cont} ->
 39 |     State c ->
 40 |     c.Shp
 41 |   fromState f = f.fwd ()
 42 |
 43 | namespace Costate
 44 |   public export
 45 |   Costate : Cont -> Type
 46 |   Costate c = c =%> Scalar
 47 |
 48 |   public export
 49 |   fromCostate : {0 c : Cont} ->
 50 |     Costate c ->
 51 |     (x : c.Shp) -> c.Pos x
 52 |   fromCostate f x = f.bwd x ()
 53 |   
 54 |   public export
 55 |   toCostate : {0 c : Cont} ->
 56 |     ((x : c.Shp) -> c.Pos x) ->
 57 |     Costate c
 58 |   toCostate s = !% \x => (() ** \() => s x)
 59 |
 60 | public export
 61 | fromNapCostateToState : Costate (Nap c.Shp) -> State c
 62 | fromNapCostateToState f = toState (f.bwd () ())
 63 |
 64 | public export
 65 | fromStateToNapCostate : State c -> Costate (Nap c.Shp)
 66 | fromStateToNapCostate f = toCostate f.fwd
 67 |
 68 | public export
 69 | pushDown : Cont -> Cont
 70 | pushDown c = Const2 Unit c.Shp
 71 |
 72 | public export
 73 | pushIntoContinuation : {0 d, p, l : Cont} ->
 74 |   (d >< p =%> l) ->
 75 |   (p =%> (pushDown d) >@ l)
 76 | pushIntoContinuation f = !% \p => (() <| \d => f.fwd (d, p) **
 77 |   \(d ** l'=> snd $ f.bwd (d, p) l')
 78 |
 79 |
 80 | namespace CategoricalProduct
 81 |   public export
 82 |   terminal : c =%> UnitCont
 83 |   terminal = !% \_ => (() ** absurd)
 84 |
 85 |
 86 | namespace HancockTensorProduct
 87 |   public export
 88 |   leftUnit : Scalar >< c =%> c
 89 |   leftUnit = !% \((), s) => (s ** \p => ((), p))
 90 |   
 91 |   public export
 92 |   rightUnit : c >< Scalar =%> c
 93 |   rightUnit = !% \(x, ()) => (x ** \x' => (x', ()))
 94 |
 95 |   public export
 96 |   leftUnitInv : c =%> Scalar >< c
 97 |   leftUnitInv = !% \x => (((), x) ** \((), x') => x')
 98 |
 99 |   public export
100 |   rightUnitInv : c =%> c >< Scalar
101 |   rightUnitInv = !% \x => ((x, ()) ** \(x', ()) => x')
102 |
103 |   public export
104 |   assocL : (a >< b) >< c =%> a >< (b >< c)
105 |   assocL = !% \((a, b), c) => ((a, (b, c)) ** \(a', (b', c')) => ((a', b'), c'))
106 |
107 |   public export
108 |   assocR : a >< (b >< c) =%> (a >< b) >< c
109 |   assocR = !% \(a, (b, c)) => (((a, b), c) ** \((a', b'), c') => (a', (b', c')))
110 |
111 |   public export
112 |   swap : a >< b =%> b >< a
113 |   swap = !% \(a, b) => ((b, a) ** \(b', a') => (a', b'))
114 |
115 | namespace CompositionProduct
116 |   public export
117 |   leftUnit : Scalar >@ c =%> c
118 |   leftUnit = !% \(() <| cShp) => (cShp () ** \c' => (() ** c'))
119 |
120 |   public export
121 |   rightUnit : c >@ Scalar =%> c
122 |   rightUnit = !% \(s <| _) => (s ** \cp => (cp ** ()))
123 |
124 |   public export
125 |   leftUnitInv : c =%> Scalar >@ c
126 |   leftUnitInv = !% \x => (() <| (\_ => x) ** \(() ** c') => c')
127 |   
128 |   public export
129 |   rightUnitInv : c =%> c >@ Scalar
130 |   rightUnitInv = !% \s => (s <| const () ** fst)
131 |
132 | namespace Coproduct
133 |   public export
134 |   elim : c >+< c =%> c
135 |   elim = !% \case
136 |     Left x => (x ** id)
137 |     Right y => (y ** id)
138 |
139 |   public export
140 |   initial : Empty =%> c
141 |   initial = !% absurd
142 |
143 |
144 |
145 | namespace CartesianClosure
146 |   ||| The following is the proof that for any container `c` there is an
147 |   ||| isomorphism in `Cont` between `c` and `CartesianClosure UnitCont c`
148 |   ||| This holds in any monoidal closed category: `X ≅ [I, X]`
149 |   namespace StateIsomorphismProof
150 |     stateToCartClosureFw : c =%> (CartesianClosure UnitCont c)
151 |     stateToCartClosureFw = !% \cShp => (!% \() => (cShp ** \_ => Nothing)
152 |                                        ** \(() ** cPos ** ItIsNothing) => cPos)
153 |
154 |     stateToCartClosureBw : CartesianClosure UnitCont c =%> c
155 |     stateToCartClosureBw = !% \l => (l.fwd () ** \cPos =>
156 |       (() ** cPos ** maybeVoidIsNothing (l.bwd () cPos)))
157 |
158 |
159 | ||| For a overview of this interaction from the categorical perspective, see
160 | ||| the Poly book (https://arxiv.org/abs/2312.00990) (Section 6.3.4)
161 | namespace CompositionTensorInteraction
162 |   ||| Interaction between composition and tensor product
163 |   ||| Swaps the operations, and middle two containers
164 |   ||| Not an isomorphism!
165 |   public export
166 |   duoidal : (c >@ d) >< (e >@ f) =%> (c >< e) >@ (d >< f)
167 |   duoidal = !% \((sc <| idxC), (se <| idxE)) =>
168 |     ((sc, se) <| \(cp, ep) => (idxC cp, idxE ep) **
169 |       \((cp, ep) ** (dp, fp)) => ((cp ** dp), (ep ** fp)))
170 |   
171 |   ||| Tensor product embeds into composition
172 |   ||| A special case of `duoidal`
173 |   public export
174 |   tensorToComp : c >< f =%> c >@ f
175 |   tensorToComp =   (rightUnitInv >< leftUnitInv)
176 |                %>> duoidal {d=Scalar,e=Scalar}
177 |                %>> (rightUnit >@ leftUnit)
178 |
179 |   ||| Going the other way is impossible without any constraints 
180 |   ||| Two possibilities on constraints (this, and `compToTensor2`)
181 |   public export 
182 |   compToTensor : IsNaperian d =>
183 |     (c >@ d) =%> (c >< d)
184 |   compToTensor @{(MkIsNaperian dPos)} = !% \(cShp <| content) =>
185 |     ((cShp,()) ** \(cPos, dPos) => (cPos ** dPos))
186 |   
187 |   public export
188 |   compToTensor2 : IsFlat c =>
189 |     (c >@ d) =%> (c >< d)
190 |   compToTensor2 @{(ItIsFlat cShp)} = !% \(cShp <| dShp) =>
191 |     ((cShp, dShp ()) ** \((), dPos') => (() ** dPos'))
192 |   
193 |   ||| Specific distributive law we need
194 |   public export
195 |   distribute : (c >< e) =%> s ->
196 |     c >< (e >@ g) =%> s >@ g
197 |   distribute f = (rightUnitInv >< id {a=e >@ g})
198 |                %>> duoidal {d = Scalar}
199 |                %>> (f >@ leftUnit)
200 |
201 |
202 | ||| Wraps a dependent lens `c =%> d`
203 | ||| into one of type `c >@ Scalar =%> d >@ Scalar`
204 | ||| Needed because `c >@ Scalar` isn't automatically reduced to `c`
205 | public export
206 | wrapIntoVector : c =%> d ->
207 |   Tensor [c] =%> Tensor [d]
208 | wrapIntoVector f = rightUnit %>> f %>> rightUnitInv
209 |
210 | public export
211 | wrapIntoMatrix : (c >@ c') =%> (d >@ d') ->
212 |   Tensor [c, c'] =%> Tensor [d, d']
213 | wrapIntoMatrix f =   (id >@ rightUnit)
214 |                  %>> f
215 |                  %>> (id >@ rightUnitInv)
216 |
217 | ||| Wraps a dependent lens `c =%> d`
218 | ||| into one of type `c >< Scalar =%> d >< Scalar`
219 | ||| Needed because `c >< Scalar` isn't automatically reduced to `c`
220 | public export
221 | wrapIntoVectorHancock : c =%> d ->
222 |   HancockTensor [c] =%> HancockTensor [d]
223 | wrapIntoVectorHancock f = rightUnit %>> f %>> rightUnitInv
224 |
225 | namespace CubicalHelpers
226 |   ||| Helper function allowing `shape` in `cubicalShape` to have zero annotation
227 |   public export
228 |   cubicalShapeHelper : All IsCubical shape -> List Nat
229 |   cubicalShapeHelper [] = []
230 |   cubicalShapeHelper (ic :: ics) = dimHelper ic :: cubicalShapeHelper ics
231 |     
232 |   ||| Given a list of cubical containers, return the list of their dimensions
233 |   public export
234 |   cubicalShape : (0 shape : List Cont) -> All IsCubical shape => List Nat
235 |   cubicalShape _ @{ac} = cubicalShapeHelper ac
236 |     
237 |   ||| Size of a list of cubical containers is the product of their dimensions
238 |   public export
239 |   size : (0 shape : List Cont) -> All IsCubical shape => Nat
240 |   size shape = prod (cubicalShape shape)
241 |
242 | ||| Layout-aware dependent lens flattening a cubical tensor
243 | public export
244 | flattenCubical : {shape : List Cont} ->
245 |   (ac : All IsCubical shape) =>
246 |   LayoutOrder ->
247 |   Tensor shape =%> Vect (size shape)
248 | flattenCubical {shape = [], ac=[]} _ = !% \() => (() ** \FZ => ())
249 | flattenCubical {shape = (_ :: ss), ac=(MkIsCubical n :: as)} lo
250 |   = !% \(() <| t) => (() ** \idx =>
251 |       let (!% recBackward) = flattenCubical {shape = ss} lo
252 |           (i, rest) = splitFinProd lo idx
253 |           (_ ** backRec= recBackward (t i)
254 |       in (i ** backRec rest))
255 |
256 | ||| Layout-aware dependent lens unflattening a tensor
257 | public export
258 | unflattenCubical : {shape : List Cont} ->
259 |   (ac : All IsCubical shape) =>
260 |   LayoutOrder ->
261 |   Vect (size shape) =%> Tensor shape
262 | unflattenCubical {shape = [], ac=[]} lo = !% \() => (() ** \() => FZ)
263 | unflattenCubical {shape = (_ :: ss), ac=((MkIsCubical n) :: as)} lo =
264 |   let (!% f) = unflattenCubical {shape = ss} lo
265 |       (innerShape ** innerBack= f ()
266 |   in !% \() => ((() <| \_ => innerShape) ** (\(cp ** restPos=>
267 |     indexFinProd lo cp (innerBack restPos)))
268 |
269 | ||| This is simply a rewrite!
270 | public export
271 | recastFlattenedTensor : {oldShape, newShape : List Cont} ->
272 |   (oldAc : All IsCubical oldShape) => (newAc : All IsCubical newShape) =>
273 |   {auto prf : size oldShape = size newShape} ->
274 |   Vect (size oldShape) =%> Vect (size newShape)
275 | recastFlattenedTensor = !% \() => (() ** \i => rewrite prf in i)
276 |
277 | ||| Reshapes a cubical tensor by first flattening it to a linear representation,
278 | ||| casting the type to the new shape, and then unflattening it back
279 | ||| Is generic over layout order
280 | public export
281 | reshape : {oldShape, newShape : List Cont} ->
282 |   (oldAc : All IsCubical oldShape) => (newAc : All IsCubical newShape) =>
283 |   LayoutOrder ->
284 |   {auto prf : size oldShape = size newShape} ->
285 |   Tensor oldShape =%> Tensor newShape
286 | reshape lo = flattenCubical lo
287 |          %>> recastFlattenedTensor
288 |          %>> unflattenCubical lo
289 |
290 |
291 | namespace Transpose
292 |   public export
293 |   transposeLens : IsNaperian c => IsNaperian d => c >@ d =%> d >@ c
294 |   transposeLens @{MkIsNaperian _} @{MkIsNaperian _} = !% \(() <| _) =>
295 |     (() <| (\_ => ()) ** \(dInd ** cInd=> (cInd ** dInd))
296 |
297 |   public export
298 |   transpose : IsNaperian c => IsNaperian d =>
299 |     Tensor [c, d] =%> Tensor [d, c]
300 |   transpose @{MkIsNaperian _} @{MkIsNaperian _} = wrapIntoMatrix transposeLens
301 |
302 |   -- ||| experiment, does this work?
303 |   -- public export
304 |   -- transposeMiddle : IsNaperian c => IsNaperian e =>
305 |   --   Tensor [c, e, d] =%> 
306 |   
307 |
308 |   --||| Transpose a given element to the front of the shape
309 |   --public export
310 |   --transposeToFront : (shape : List Cont) ->
311 |   --  (c : Cont) ->
312 |   --  (elem : Elem c shape) =>
313 |   --  All IsNaperian (dropAfterElem shape elem) =>
314 |   --  Tensor shape =%> Tensor (c :: dropElem shape elem)
315 |   --transposeToFront (_ :: xs) c @{Here} @{allNap} = ?transposeToFront_rhs_0
316 |   --transposeToFront (y :: xs) c @{(There x)} @{allNap} = ?transposeToFront_rhs_1
317 |   
318 | ||| Functionality for transforming a tensor into a hancock tensor
319 | namespace TransformIntoHancockTensor
320 |   public export
321 |   hancockTensorNaperianShape : {shape : List Cont} ->
322 |     (allNap : All IsNaperian shape) =>
323 |     (HancockTensor shape).Shp
324 |   hancockTensorNaperianShape {shape = []} = ()
325 |   hancockTensorNaperianShape {allNap = ((MkIsNaperian _) :: _)}
326 |     = ((), hancockTensorNaperianShape)
327 |   
328 |   ||| Helper to compute the unique shape of Tensor when all containers are Naperian
329 |   public export
330 |   tensorNaperianShape : {shape : List Cont} ->
331 |     (allNap : All IsNaperian shape) =>
332 |     (Tensor shape).Shp
333 |   tensorNaperianShape {shape = []} = ()
334 |   tensorNaperianShape {shape = (_ :: ss), allNap = ((MkIsNaperian _) :: ns)}
335 |     = () <| \_ => tensorNaperianShape {shape = ss} @{ns}
336 |   
337 |   ||| Analogous to `naperianPosEq` but for the HancockTensor structure
338 |   ||| We can't use `naperianPosEq` directly because the shape of the resulting
339 |   ||| container is not Unit, it is only isomorphic to it
340 |   public export
341 |   hancockTensorPosEq : {shape : List Cont} ->
342 |     (allNap : All IsNaperian shape) =>
343 |     {0 x, y : (HancockTensor shape).Shp} ->
344 |     (HancockTensor shape).Pos x = (HancockTensor shape).Pos y
345 |   hancockTensorPosEq {allNap = []} = Refl
346 |   hancockTensorPosEq {allNap = ((MkIsNaperian _) :: _)} = cong2 Pair
347 |     (naperianPosEq @{MkIsNaperian _} {x=()} {y=()})
348 |     hancockTensorPosEq
349 |   
350 |   ||| Tensor shape is isomorphic to HancockTensor shape when all containers in
351 |   ||| the shape are Naperian. This is one arrow of that isomorphism
352 |   public export
353 |   transformToHancock : {shape : List Cont} ->
354 |     All IsNaperian shape =>
355 |     Tensor shape =%> HancockTensor shape
356 |   transformToHancock {shape = []} = id
357 |   transformToHancock {shape = (_ :: _)} @{((MkIsNaperian _) :: _)}
358 |     = !% \(() <| content) => (((), hancockTensorNaperianShape) **
359 |        \(p, restPos) =>
360 |          let (_ ** recBack= (%!) transformToHancock (content p)
361 |          in (p ** recBack $ replace {p = id} hancockTensorPosEq restPos))
362 |
363 |   public export
364 |   transformFromHancock : {shape : List Cont} ->
365 |     All IsNaperian shape =>
366 |     HancockTensor shape =%> Tensor shape
367 |   transformFromHancock {shape = []} = id
368 |   transformFromHancock {shape = (Nap s :: ss)} @{((MkIsNaperian s) :: _)}
369 |     = !% \((), hShp) =>
370 |         let (tShp ** recBack= (%!) transformFromHancock hShp
371 |         in (() <| (\_ => tShp) ** \(p ** restPos=> (p, recBack restPos))
372 |
373 |     
374 |
375 |   -- ||| Technically this is Unit, but hard to prove
376 |   -- public export
377 |   -- foldOverNaperianShapeComp : {shape : List Cont} ->
378 |   --   (allNap : All IsNaperian shape) =>
379 |   --   (Tensor shape).Shp
380 |   -- foldOverNaperianShapeComp {shape = []} = ()
381 |   -- foldOverNaperianShapeComp {allNap = ((MkIsNaperian pos) :: ns)}
382 |   --   = () <| \_ => foldOverNaperianShapeComp
383 |   -- 
384 |   -- public export
385 |   -- naperianHancockShape : {shape : List Cont} ->
386 |   --   (allNap : All IsNaperian shape) =>
387 |   --   (HancockTensor shape).Shp = Unit
388 |   -- naperianHancockShape = believe_me ()
389 |   -- 
390 |   -- public export
391 |   -- foldOverNaperianShapeHancock : {shape : List Cont} ->
392 |   --   (allNap : All IsNaperian shape) =>
393 |   --   (HancockTensor shape).Shp
394 |   -- foldOverNaperianShapeHancock {shape = []} = ()
395 |   -- foldOverNaperianShapeHancock {allNap = ((MkIsNaperian _) :: _)}
396 |   --   = ((), foldOverNaperianShapeHancock)
397 |
398 |
399 | -- public export
400 | -- tensorIsNaperianShape : {shape : List Cont} ->
401 | --   (allNap : All IsNaperian shape) =>
402 | --   IsNaperian (Tensor shape)
403 | -- tensorIsNaperianShape {shape = []} = MkIsNaperian ()
404 | -- tensorIsNaperianShape {shape = (_ :: ss), allNap = ((MkIsNaperian pos) :: ns)}
405 | --   = let tg = tensorIsNaperianShape {shape = ss} 
406 | --     in ?tensorIsNaperianShape_rhs_1
407 | --     --in rewrite naperianShpEq @{tg}
408 | --     --in (rewrite (EmptyExtEq {c=(Nap pos)})
409 | --     --in let tg = MkIsNaperian in ?tensorIsNaperianShape_rhs_2)
410 |
411 | -- public export
412 | -- transformToHancock : {shape : List Cont} ->
413 | --   All IsNaperian shape =>
414 | --   Tensor shape =%> HancockTensor shape
415 | -- transformToHancock {shape = []} = id
416 | -- transformToHancock {shape = (_ :: ss)} @{((MkIsNaperian pos) :: ns)}
417 | --   = let f = (%!) (transformToHancock {shape = ss} @{ns})
418 | --         (_ ** h) = f (foldOverNaperianShapeComp {shape=ss})
419 | --     in !% \(() <| content) => (((), foldOverNaperianShapeHancock) **
420 | --       \(p, fld) => (p ** ?hhh))
421 | --       -- (((), rewrite -- foldOverNaperianShapeHancock {shape=ss} @{ns} in ()) **
422 | --     --   \(p, fld) => (p ** ?bnn))
423 |
424 | -- need to organise this
425 | namespace BinTree
426 |   public export
427 |   inorderBackward : (b : BinTreeShape) ->
428 |     Fin (numNodesAndLeaves b) ->
429 |     BinTreePos b
430 |   inorderBackward LeafS FZ = AtLeaf
431 |   inorderBackward (NodeS lt rt) n with (strengthenN {m=numNodesAndLeaves lt} n)
432 |      _ | Left p = GoLeft (inorderBackward lt p)
433 |      _ | Right FZ = AtNode
434 |      _ | Right (FS g) = GoRight (inorderBackward rt g)
435 |
436 |
437 |   public export
438 |   inorder : BinTree =%> List
439 |   inorder = !% \b => (numNodesAndLeaves b ** inorderBackward b)
440 |
441 | namespace BinTreeNode
442 |   public export
443 |   inorderBackward : (b : BinTreeShape) ->
444 |     Fin (numNodes b) ->
445 |     BinTreePosNode b
446 |   inorderBackward (NodeS lt rt) n with (strengthenN {m=numNodes lt} n)
447 |     _ | Left p = GoLeft (inorderBackward lt p)
448 |     _ | Right FZ = AtNode
449 |     _ | Right (FS g) = GoRight (inorderBackward rt g)
450 |
451 |   ||| Traverses a binary tree container in order, producing a list container
452 |   public export
453 |   inorder : BinTreeNode =%> List
454 |   inorder = !% \b => (numNodes b ** inorderBackward b)
455 |
456 |   -- Need to do some rewriting for preorder
457 |   public export
458 |   preorderBinTreeNode : (b : BinTreeShape) ->
459 |     Fin (numNodes b) -> BinTreePosNode b
460 |   preorderBinTreeNode (NodeS lt rt) x = ?preorderBinTreeNode_rhs_1
461 |   --preorderBinTreeNode (NodeS lt rt) n with (strengthenN {m=numNodes lt} n)
462 |   --  _ | Left p = ?whl
463 |   --  _ | Right FZ = ?whn
464 |   --  _ | Right (FS g) = ?whr
465 |
466 | namespace BinTreeLeaf
467 |   public export
468 |   inorderBackward : (b : BinTreeShape) ->
469 |     Fin (numLeaves b) ->
470 |     BinTreePosLeaf b
471 |   inorderBackward LeafS 0 = AtLeaf
472 |   inorderBackward (NodeS lt rt) i with (strengthenN {m=numLeaves lt} i)
473 |     _ | (Left indLeft) = GoLeft (inorderBackward lt indLeft)
474 |     _ | (Right indRight) = GoRight (inorderBackward rt indRight)
475 |
476 |   public export
477 |   inorder : BinTreeLeaf =%> List
478 |   inorder = !% \b => (numLeaves b ** inorderBackward b)
479 |
480 | -- public export
481 | -- traverseLeaf : (x : BinTreeShape) -> FinBinTreeLeaf x -> Fin (numLeaves x)
482 | -- traverseLeaf LeafS Done = FZ
483 | -- traverseLeaf (NodeS lt rt) (GoLeft x) = weakenN (numLeaves rt) (traverseLeaf lt x)
484 | -- traverseLeaf (NodeS lt rt) (GoRight x) = shift (numLeaves lt) (traverseLeaf rt x)
485 | -- 
486 |
487 | public export
488 | vectToList : {n : Nat} -> Vect n =%> List
489 | vectToList = !% \() => (n ** id)
490 |
491 | public export
492 | maybeToList : Maybe =%> List
493 | maybeToList = !% \b => case b of 
494 |   False => (0 ** absurd)
495 |   True => (1 ** \_ => ())
496 |
497 | public export
498 | Sample : MonadSample m => {n : Nat} -> IsSucc n =>
499 |   (m <!> Sample n) =%> Scalar
500 | Sample = toCostate sample
501 |
502 | -- TODO here maybe need to uncomment during merge?
503 | -- public export
504 | -- selectShape : {cs : Vect k Cont} ->
505 | --   (shapes : All Shp cs) -> (i : Fin k) -> Any Shp cs
506 | -- selectShape (s :: _) FZ = Here s
507 | -- selectShape (_ :: ss) (FS j) = There (selectShape ss j)
508 | -- 
509 | -- ||| Extract the position from an AnyPos at a given index
510 | -- public export
511 | -- extractPos : {n : Nat} -> {xs : Vect n Cont} ->
512 | --   {shapes : All Shp xs} ->
513 | --   (i : Fin n) ->
514 | --   AnyShpPos (selectShape shapes i) ->
515 | --   AnyPos shapes
516 | -- extractPos {shapes = (_ :: _)} FZ (Here x) = Here x
517 | -- extractPos {shapes = (_ :: _)} (FS j) (There rest)
518 | --   = There $ extractPos j rest
519 | -- 
520 | -- public export
521 | -- SampleAndChoose : {n : Nat} -> {xs : Vect n Cont} ->
522 | --   ConvexComb xs =%> (Sample n >@ Any xs)
523 | -- SampleAndChoose = !% \(d, shapes) =>
524 | --   (d <| selectShape shapes ** \(i ** grad) => (0, [extractPos i grad]))
525 |
526 | -- SampleAndChooseWithDist = !% \(d, shapes) =>
527 | --   (d <| electShape shapes ** \(i ** grad) => (0, [(i ** extractPos i grad)]))
528 |
529 | -- public export
530 | -- GetDist : {n : Nat} -> {xs : Vect n Cont} ->
531 | --   ConvexComb xs =%> Simplex n
532 | -- GetDist = !% \(d, shapes) => (d ** \d' => (d', ?GetDist_rhs))
533 |
534 | public export
535 | handleEffect : Monad m =>
536 |   (handler : (m <!> effect) =%> Scalar) ->
537 |   (program : a =%> effect) ->
538 |   m <!> a =%> Scalar
539 | handleEffect handler program = !% \x =>
540 |   let (ef ** nn= (%! program) x
541 |       (() ** rest= (%! handler) ef
542 |   in (() ** \() => do 
543 |     e <- rest ()
544 |     pure (nn e))
545 |