0 | module Control.Monad.Bayes.Traced.Basic
2 | import public Data.Vect
3 | import Control.Monad.Bayes.Interface
4 | import Control.Monad.Bayes.Traced.Common
5 | import Control.Monad.Bayes.Weighted
6 | import Control.Monad.Bayes.Free
7 | import Control.Monad.Free
||| A traced probabilistic program over a base monad `m`.
|||
||| It carries two views of the same program: one interpreted in
||| `Weighted (FreeSampler Identity)`, and one as an `m`-computation that yields
||| a `Trace` of the program.
record Traced (m : Type -> Type) (a : Type) where
  constructor MkTraced
  ||| The program, interpreted in `Weighted (FreeSampler Identity)`.
  model : Weighted (FreeSampler Identity) a
  ||| A computation in the base monad producing a `Trace` of the program.
  traceDist : m (Trace a)
-- Functor: apply `f` to the model's result and, via `Trace`'s own Functor
-- implementation, to every trace produced by the trace computation.
Monad m => Functor (Traced m) where
  map f (MkTraced prog traces) = MkTraced (f <$> prog) (map f <$> traces)
-- Applicative: both components are combined independently. The model side
-- uses the model's `<*>`; the trace side lifts `Trace`'s `<*>` through `m`.
Monad m => Applicative (Traced m) where
  pure x = MkTraced (pure x) (pure $ pure x)

  (MkTraced progF tracesF) <*> (MkTraced progX tracesX) =
    MkTraced (progF <*> progX) (((<*>) <$> tracesF) <*> tracesX)
-- Monad: the model side binds into the `model` of each continuation.
-- The trace side uses `bind` (from Traced.Common) into the `traceDist`
-- of each continuation.
Monad m => Monad (Traced m) where
  (MkTraced prog traces) >>= k =
    MkTraced (prog >>= \x => model (k x))
             (bind traces (\x => traceDist (k x)))
-- Sampling: draw with `random` in both components. On the trace side the
-- drawn value is wrapped with `singleton` (presumably a one-draw `Trace` --
-- confirm against Traced.Common).
MonadSample m => MonadSample (Traced m) where
  random = MkTraced random (singleton <$> random)
-- Conditioning: apply `score w` in both components. The trace side then
-- returns `scored w` as its trace.
MonadCond m => MonadCond (Traced m) where
  score w = MkTraced (score w) $ do
    score w
    pure (scored w)
-- Full inference interface; it needs no methods beyond the MonadSample and
-- MonadCond implementations above.
implementation MonadInfer m => MonadInfer (Traced m) where
||| Transform the trace computation with an `m`-to-`m` function that works
||| for every result type, leaving the model untouched.
hoistT : (forall x. m x -> m x) -> Traced m a -> Traced m a
hoistT nat tr = MkTraced (model tr) (nat (traceDist tr))
||| Discard the model and project each trace in the trace computation to its
||| `output`.
marginal : Monad m => Traced m a -> m a
marginal = map output . traceDist
||| Feed every trace from the trace computation through `mhTrans'`, using this
||| program's model. The model itself is kept unchanged.
||| (`mhTrans'` lives in Traced.Common; presumably one Metropolis-Hastings
||| transition -- confirm there.)
mhStep : MonadSample m => Traced m a -> Traced m a
mhStep (MkTraced prog traces) = MkTraced prog $ do
  t <- traces
  mhTrans' prog t
63 | mh : MonadSample m => (n : Nat) -> Traced m a -> m (Vect (S n) a)
64 | mh n (MkTraced mod d) = map (map output) (f n)
66 | f : (n : Nat) -> m (Vect (S n) (Trace a))