0 | module Control.Monad.Bayes.Traced.Dynamic
 1 |
 2 | import public Data.Vect
 3 | import Control.Monad.Bayes.Interface
 4 | import Control.Monad.Bayes.Traced.Common
 5 | import Control.Monad.Bayes.Weighted
 6 | import Control.Monad.Bayes.Free
 7 | import Control.Monad.Free
 8 |
 9 | ||| A tracing monad where only a subset of random choices are traced and this subset can be adjusted dynamically.
10 | public export
11 | record Traced (m : Type -> Type) (a : Type) where 
12 |   constructor MkTraced
13 |   runTraced : m (Weighted (FreeSampler m) a, Trace a)
14 |
15 | export
16 | pushM : Monad m => m (Weighted (FreeSampler m) a) -> Weighted (FreeSampler m) a
17 | pushM = join . lift . lift
18 |
19 | export
20 | Monad m => Functor (Traced m) where
21 |   map f (MkTraced c) = MkTraced $ do
22 |     (m, t) <- c
23 |     let m' = map f m
24 |         t' = map f t
25 |     pure (m', t')
26 |
27 | export
28 | Monad m => Applicative (Traced m) where
29 |   pure x = MkTraced $ pure (pure x, pure x)
30 |   (MkTraced cf) <*> (MkTraced cx) = MkTraced $ do
31 |     (mf, tf) <- cf
32 |     (mx, tx) <- cx
33 |     pure (mf <*> mx, tf <*> tx)
34 |
35 | export
36 | Monad m => Monad (Traced m) where
37 |   (MkTraced cx) >>= f = MkTraced $ do
38 |     (mx, tx) <- cx
39 |     let m = mx >>= pushM . map fst . runTraced . f
40 |     t <- pure tx `bind` (map snd . runTraced . f)
41 |     pure (m, t)
42 |
43 | export
44 | MonadTrans Traced where
45 |   lift m = MkTraced $ map (\a => (lift $ lift m, pure a)) m
46 |
47 | export
48 | MonadSample m => MonadSample (Traced m) where
49 |   random = MkTraced $ map (\r => (random, singleton r)) random
50 |
51 | export
52 | MonadCond m => MonadCond (Traced m) where
53 |   score w = MkTraced $ map (score w,) (score w >> pure (scored w))
54 |
55 | export
56 | MonadInfer m => MonadInfer (Traced m) where
57 |
58 | export
59 | hoistT : (forall x. m x -> m x) -> Traced m a -> Traced m a
60 | hoistT f (MkTraced c) = MkTraced (f c)
61 |
62 | ||| Discard the trace and supporting infrastructure.
63 | export
64 | marginal : Monad m => Traced m a -> m a
65 | marginal (MkTraced c) = map (output . snd) c
66 |
67 | ||| Freeze all traced random choices to their current values and stop tracing them.
68 | export
69 | freeze : Monad m => Traced m a -> Traced m a
70 | freeze (MkTraced c) = MkTraced $ do
71 |   (_, t) <- c
72 |   let x = t.output
73 |   pure (pure x, pure x)
74 |
75 | ||| A single step of the Trace Metropolis-Hastings algorithm.
76 | export
77 | mhStep : MonadSample m => Traced m a -> Traced m a
78 | mhStep (MkTraced c) = MkTraced $ do
79 |   (m, t) <- c
80 |   t' <- mhTrans m t
81 |   pure (m, t')
82 |
83 | -- | Full run of the Trace Metropolis-Hastings algorithm with a specified
84 | -- number of steps.
85 | public export
86 | mh : MonadSample m => (n : Nat) -> Traced m a -> m (Vect (S n) a)
87 | mh n (MkTraced c) = do
88 |   (mod, tr) <- c
89 |   let f : (n : Nat) -> m (Vect (S n) (Trace a))
90 |       f Z     = pure [tr]
91 |       f (S k) = do  (x :: xs) <- f k
92 |                     y <- mhTrans mod x
93 |                     pure (y :: x :: xs)
94 |   map (map output) (f n)
95 |