0 | module Control.Monad.Bayes.Traced.Dynamic
2 | import public Data.Vect
3 | import Control.Monad.Bayes.Interface
4 | import Control.Monad.Bayes.Traced.Common
5 | import Control.Monad.Bayes.Weighted
6 | import Control.Monad.Bayes.Free
7 | import Control.Monad.Free
||| A traced computation over the base monad `m`.
|||
||| Running `runTraced` in `m` yields a pair of:
|||  * a program of type `Weighted (FreeSampler m) a` (presumably used to
|||    re-execute the model during inference — TODO confirm against `mhStep`),
|||  * a `Trace a` for this run. The instances in this module build it from
|||    sampled values (`singleton`) and scores (`scored`), and `output`
|||    projects its result.
record Traced (m : Type -> Type) (a : Type) where
  constructor MkTraced
  runTraced : m (Weighted (FreeSampler m) a, Trace a)
||| Pushes an effect of the base monad `m` inside a weighted free-sampler
||| program.
|||
||| The `m` action is lifted twice, which produces a program whose result is
||| itself a program. That nested program is then flattened with `join`.
pushM : Monad m => m (Weighted (FreeSampler m) a) -> Weighted (FreeSampler m) a
pushM action = join (lift (lift action))
20 | Monad m => Functor (Traced m) where
21 | map f (MkTraced c) = MkTraced $
do
28 | Monad m => Applicative (Traced m) where
29 | pure x = MkTraced $
pure (pure x, pure x)
30 | (MkTraced cf) <*> (MkTraced cx) = MkTraced $
do
33 | pure (mf <*> mx, tf <*> tx)
36 | Monad m => Monad (Traced m) where
37 | (MkTraced cx) >>= f = MkTraced $
do
39 | let m = mx >>= pushM . map fst . runTraced . f
40 | t <- pure tx `bind` (map snd . runTraced . f)
||| Lifts a base-monad action into `Traced`.
|||
||| Each value `x` produced by running `act` becomes a single-value trace
||| (`pure x`). The stored program is the same action `act`, lifted through
||| both transformer layers.
MonadTrans Traced where
  lift act = MkTraced (map (\x => (lift (lift act), pure x)) act)
||| A primitive random draw.
|||
||| The value `draw` sampled in the base monad is recorded as a one-element
||| trace. The stored program performs its own `random` call.
MonadSample m => MonadSample (Traced m) where
  random = MkTraced ((\draw => (random, singleton draw)) <$> random)
||| Conditioning by a score `w`.
|||
||| The score is applied in the base monad, and the trace produced is
||| `scored w`. The stored program applies the same score again when it is
||| run.
MonadCond m => MonadCond (Traced m) where
  score w = MkTraced ((\tr => (score w, tr)) <$> (score w >> pure (scored w)))
||| `Traced m` is a `MonadInfer` whenever `m` is.
|||
||| No methods are visible in this implementation.
||| NOTE(review): it presumably relies on the `MonadSample` and `MonadCond`
||| implementations for `Traced m`; confirm against the `MonadInfer`
||| interface definition.
MonadInfer m => MonadInfer (Traced m) where
||| Applies a polymorphic transformation `f` of the base monad to the outer
||| computation of a `Traced` value.
|||
||| Only the outer `m` layer is rewritten. The `Weighted (FreeSampler m)`
||| program and the trace it returns pass through `f` untouched.
hoistT : (forall x. m x -> m x) -> Traced m a -> Traced m a
hoistT f t = MkTraced (f (runTraced t))
||| Runs the underlying computation and keeps only the result recorded in the
||| trace (via `output`), discarding the stored program.
marginal : Monad m => Traced m a -> m a
marginal t = map (\(_, tr) => output tr) (runTraced t)
69 | freeze : Monad m => Traced m a -> Traced m a
70 | freeze (MkTraced c) = MkTraced $
do
73 | pure (pure x, pure x)
77 | mhStep : MonadSample m => Traced m a -> Traced m a
78 | mhStep (MkTraced c) = MkTraced $
do
86 | mh : MonadSample m => (n : Nat) -> Traced m a -> m (Vect (S n) a)
87 | mh n (MkTraced c) = do
89 | let f : (n : Nat) -> m (Vect (S n) (Trace a))
91 | f (S k) = do (x :: xs) <- f k
94 | map (map output) (f n)