0 | module Control.Monad.Bayes.Traced.Common
 1 |
 2 | import Numeric.Log
 3 | import Control.Monad.Writer
 4 | import Control.Monad.Bayes.Interface
 5 | import Control.Monad.Bayes.Weighted
 6 | import Control.Monad.Bayes.Free
 7 |
 8 | ||| Collection of random variables sampled during the program's execution.
 9 | public export
10 | record Trace (a : Type) where
11 |   constructor MkTrace 
12 |   -- | Sequence of random variables sampled during the program's execution.
13 |   variables : List Double 
14 |   -- |
15 |   output    : a
16 |   -- | The probability of observing this particular sequence.
17 |   density   : Log Double     
18 |
19 | export
20 | Functor Trace where
21 |   map f t = { output $= f } t
22 |
23 | export
24 | Applicative Trace where
25 |   pure x    = MkTrace { variables = [], output = x, density = 1 }
26 |   tf <*> tx =
27 |     MkTrace
28 |       { variables = variables tf ++ variables tx,
29 |         output    = output tf (output tx),
30 |         density   = density tf * density tx
31 |       }
32 |
33 | export
34 | Monad Trace where
35 |   t >>= f = 
36 |     let t' = f (t.output)
37 |     in  {variables := variables t ++ variables t', density := density t * density t'} t' 
38 |
39 | export
40 | singleton : Double -> Trace Double
41 | singleton u = MkTrace {variables = [u], output = u, density = toLogDomain 1.0}
42 |
43 | export
44 | scored : Log Double
45 |       -> Trace ()
46 | scored w = MkTrace {variables = [], output = (), density = w}
47 |
48 | export
49 | bind : Monad m => m (Trace a) -> (a -> m (Trace b)) -> m (Trace b)
50 | bind dx f = do
51 |   t1 <- dx
52 |   t2 <- f (output t1)
53 |   pure $ {variables := variables t1 ++ variables t2, density := density t1 * density t2} t2
54 |
55 | ||| A single Metropolis-corrected transition of single-site Trace MCMC.
56 | export
57 | mhTrans : MonadSample m => Weighted (FreeSampler m) a -> Trace a -> m (Trace a)
58 | mhTrans mw t@(MkTrace {variables = us, output = x, density = p}) = do
59 |   us' <- do
60 |     i <- discreteUniform (length us) 
61 |     case splitAt i us of
62 |       (xs, _ :: ys) => pure $ xs ++ (!random :: ys)
63 |       _             => assert_total $ idris_crash ("impossible event: trying to split " ++ show us ++ " at index " ++ show i)
64 |   ((q, b), vs) <- runWriterT $ runWeighted $ Weighted.hoist (writerT . withPartialRandomness us') mw
65 |   let ratio : Double = (exp . ln) $ min 1 ( (q * (Exp $ log $ cast (length us))) / 
66 |                                             (p * (Exp $ log $ cast (length vs)))
67 |                                           )
68 |   accept <- bernoulli ratio
69 |   pure $ if accept then MkTrace vs b q else t
70 |
71 | ||| A variant of 'mhTrans' with an external sampling monad.
72 | export
73 | mhTrans' : MonadSample m => Weighted (FreeSampler Identity) a -> Trace a -> m (Trace a)
74 | mhTrans' m = mhTrans (Weighted.hoist (Free.hoist (pure . runIdentity)) m)
75 |