0 | module Control.Monad.Bayes.Traced.Basic
 1 |
 2 | import public Data.Vect
 3 | import Control.Monad.Bayes.Interface
 4 | import Control.Monad.Bayes.Traced.Common
 5 | import Control.Monad.Bayes.Weighted
 6 | import Control.Monad.Bayes.Free
 7 | import Control.Monad.Free
 8 |
 9 | ||| A tracing monad where only a subset of random choices are traced.
10 | ||| The random choices that are not to be traced should be lifted from the transformed monad.
11 | public export
12 | record Traced (m : Type -> Type) (a : Type) where 
13 |   constructor MkTraced
14 |   model     : Weighted (FreeSampler Identity) a
15 |   traceDist : m (Trace a)
16 |
17 | export
18 | Monad m => Functor (Traced m) where
19 |   map f (MkTraced m d) = MkTraced (map f m) (map (map f) d)
20 |
21 | export
22 | Monad m => Applicative (Traced m) where
23 |   pure x = MkTraced (pure x) (pure (pure x))
24 |   (MkTraced mf df) <*> (MkTraced mx dx) = MkTraced (mf <*> mx) ((map (<*>) df) <*> dx) 
25 |
26 | export
27 | Monad m => Monad (Traced m) where
28 |   (MkTraced mx dx) >>= f = MkTraced my dy
29 |     where
30 |       my : Weighted (FreeSampler Identity) b
31 |       my = mx >>= (model . f)
32 |       dy : m (Trace b)
33 |       dy = dx `bind` (traceDist . f)
34 |
35 | export
36 | MonadSample m => MonadSample (Traced m) where
37 |   random = MkTraced random (map singleton random)
38 |
39 | export
40 | MonadCond m => MonadCond (Traced m) where
41 |   score w = MkTraced (score w) (score w >> pure (scored w))
42 |   
43 | export
44 | MonadInfer m => MonadInfer (Traced m) where
45 |
46 | export
47 | hoistT : (forall x. m x -> m x) -> Traced m a -> Traced m a
48 | hoistT f (MkTraced m d) = MkTraced m (f d)
49 |
50 | ||| Discard the trace and supporting infrastructure.
51 | export
52 | marginal : Monad m => Traced m a -> m a
53 | marginal (MkTraced _ d) = map output d
54 |
55 | ||| A single step of the Trace Metropolis-Hastings algorithm.
56 | public export
57 | mhStep : MonadSample m => Traced m a -> Traced m a
58 | mhStep (MkTraced m d) = MkTraced m (d >>= mhTrans' m)
59 |
60 | ||| Full run of the Trace Metropolis-Hastings algorithm with a specified
61 | ||| number of steps. 
62 | public export
63 | mh : MonadSample m => (n : Nat) -> Traced m a -> m (Vect (S n) a)
64 | mh n (MkTraced mod d) = map (map output) (f n)
65 |   where
66 |     f : (n : Nat) -> m (Vect (S n) (Trace a))
67 |     f Z     = map (:: []) d
68 |     f (S k) = do
69 |           (x :: xs) <- f k
70 |           y <- mhTrans' mod x
71 |           pure (y :: x :: xs)