0 | module Control.Monad.Bayes.Traced.Static
 1 |
 2 | import public Data.Vect
 3 | import Control.Monad.Bayes.Interface
 4 | import Control.Monad.Bayes.Traced.Common
 5 | import Control.Monad.Bayes.Weighted
 6 | import Control.Monad.Bayes.Free
 7 | import Control.Monad.Free
 8 |
 9 | ||| A tracing monad where only a subset of random choices are traced.
10 | ||| The random choices that are not to be traced should be lifted from the transformed monad.
11 | public export
12 | record Traced (m : Type -> Type) (a : Type) where 
13 |   constructor MkTraced
14 |   model     : Weighted (FreeSampler m) a
15 |   traceDist : m (Trace a)
16 |
17 | public export
18 | Monad m => Functor (Traced m) where
19 |   map f (MkTraced m d) = MkTraced (map f m) (map (map f) d)
20 |
21 | public export
22 | Monad m => Applicative (Traced m) where
23 |   pure x = MkTraced (pure x) (pure (pure x))
24 |   (MkTraced mf df) <*> (MkTraced mx dx) = MkTraced (mf <*> mx) ((map (<*>) df) <*> dx) 
25 |
26 | public export
27 | Monad m => Monad (Traced m) where
28 |   (MkTraced mx dx) >>= f = MkTraced my dy
29 |     where
30 |       my : Weighted (FreeSampler m) b
31 |       my =  mx >>= (model . f)
32 |       dy : m (Trace b)
33 |       dy = dx `bind` (traceDist . f)
34 |
35 | public export
36 | MonadTrans Traced where
37 |   lift m = MkTraced (lift $ lift m) (map pure m)
38 |
39 | public export
40 | MonadSample m => MonadSample (Traced m) where
41 |   random = MkTraced random (map singleton random)
42 |
43 | public export
44 | MonadCond m => MonadCond (Traced m) where
45 |   score w = MkTraced (score w) (score w >> pure (scored w))
46 |   
47 | public export
48 | MonadInfer m => MonadInfer (Traced m) where
49 |
50 | export
51 | hoistT : (forall x. m x -> m x) -> Traced m a -> Traced m a
52 | hoistT f (MkTraced m d) = MkTraced m (f d)
53 |
54 | ||| Discard the trace and supporting infrastructure.
55 | export
56 | marginal : Monad m => Traced m a -> m a
57 | marginal (MkTraced _ d) = map output d
58 |
59 | ||| A single step of the Trace Metropolis-Hastings algorithm.
60 | public export
61 | mhStep : MonadSample m => Traced m a -> Traced m a
62 | mhStep (MkTraced m d) = MkTraced m (d >>= mhTrans m)
63 |
64 | ||| Full run of the Trace Metropolis-Hastings algorithm with a specified
65 | ||| number of steps. Newest samples are at the head of the list.
66 | public export
67 | mh : MonadSample m => (n : Nat) -> Traced m a -> m (Vect (S n) a)
68 | mh n (MkTraced mod d) = map (map output) (f n)
69 |   where
70 |     f : (n : Nat) -> m (Vect (S n) (Trace a))
71 |     f Z     = map (:: []) d
72 |     f (S k) = do
73 |           (x :: xs) <- f k
74 |           y <- mhTrans mod x
75 |           pure (y :: x :: xs)