0 | module Control.Monad.Bayes.Traced.Common
3 | import Control.Monad.Writer
4 | import Control.Monad.Bayes.Interface
5 | import Control.Monad.Bayes.Weighted
6 | import Control.Monad.Bayes.Free
-- Record type `Trace`, parameterised by the type `a`.
-- NOTE(review): only two field declarations are visible in this excerpt. Other
-- code in this file also uses an `output` field and a `MkTrace` constructor,
-- so the full declaration presumably contains more lines than shown here.
10 | record Trace (a : Type) where
-- List of Double values stored with the trace (presumably the random draws
-- made while running the program — confirm against the full file).
13 | variables : List Double
-- Weight attached to the trace, stored as a `Log Double`.
17 | density : Log Double
-- Maps over a trace: applies `f` to the `output` field through a record update
-- and leaves every other field of `t` as it was.
-- NOTE(review): the enclosing instance header is not visible in this excerpt.
21 | map f t = { output $= f } t
-- Applicative structure for `Trace`.
24 | Applicative Trace where
-- `pure` wraps a value in a trace with no variables and density 1.
25 | pure x = MkTrace { variables = [], output = x, density = 1 }
-- Fields of the combined trace built from `tf` and `tx`: the variable lists are
-- concatenated, the function held in `tf` is applied to the value held in `tx`,
-- and the two densities are multiplied.
-- NOTE(review): the clause head that binds `tf` and `tx` (presumably `<*>`) and
-- the closing brace are not visible in this excerpt.
28 | { variables = variables tf ++ variables tx,
29 | output = output tf (output tx),
30 | density = density tf * density tx
-- Runs the continuation `f` on the output of `t` to obtain a second trace `t'`,
-- then returns `t'` with its variables prefixed by those of `t` and its density
-- multiplied by the density of `t`. The output of `t'` is kept unchanged.
-- NOTE(review): the clause head binding `t` and `f` (presumably the Monad
-- instance's bind) is not visible in this excerpt.
36 | let t' = f (t.output)
37 | in {variables := variables t ++ variables t', density := density t * density t'} t'
-- Builds a trace from a single Double: the value is both the only recorded
-- variable and the output, and the density is `toLogDomain 1.0`.
40 | singleton : Double -> Trace Double
41 | singleton x =
  let unitDensity : Log Double = toLogDomain 1.0
   in MkTrace {density = unitDensity, output = x, variables = [x]}
-- Builds a trace that records no variables, has the unit value as output, and
-- carries `w` as its density.
-- NOTE(review): the type signature of `scored` is not visible in this excerpt.
46 | scored w = MkTrace {variables = [], output = (), density = w}
-- Sequencing for traces produced inside an arbitrary monad `m`: takes a
-- computation yielding a trace and a continuation yielding a second trace.
49 | bind : Monad m => m (Trace a) -> (a -> m (Trace b)) -> m (Trace b)
-- The result is `t2` with its variables prefixed by those of `t1` and its
-- density multiplied by the density of `t1`; the output of `t2` is kept.
-- NOTE(review): the lines that bind `t1` and `t2` are not visible in this
-- excerpt (presumably they run the first computation and then the continuation).
53 | pure $
{variables := variables t1 ++ variables t2, density := density t1 * density t2} t2
-- A single accept/reject transition on a trace (presumably a
-- Metropolis-Hastings step, going by the name — confirm).
-- Arguments: the model `mw` and the current trace `t`, destructured into its
-- variables `us`, output `x` and density `p`.
-- Result: either a new trace built from re-running `mw`, or the original `t`.
57 | mhTrans : MonadSample m => Weighted (FreeSampler m) a -> Trace a -> m (Trace a)
58 | mhTrans mw t@(MkTrace {variables = us, output = x, density = p}) = do
-- Draw an index into `us`.
-- NOTE(review): the exact range of `discreteUniform` is not visible here; the
-- crash branch below fires if `i` does not point at an existing element.
60 | i <- discreteUniform (length us)
-- Split `us` at `i`; when an element exists at that position it is dropped and
-- replaced by a fresh draw from `random`, keeping all other elements.
61 | case splitAt i us of
62 | (xs, _ :: ys) => pure $
xs ++ (!random :: ys)
-- Otherwise there is no element at index `i`: abort with a diagnostic message.
63 | _ => assert_total $
idris_crash ("impossible event: trying to split " ++ show us ++ " at index " ++ show i)
-- Re-run the model with its sampler hoisted through
-- `writerT . withPartialRandomness us'`, yielding the weight `q`, the output
-- `b` and the writer log `vs`.
-- NOTE(review): the line binding `us'` is not visible in this excerpt;
-- presumably it is the result of the `case` expression above.
64 | ((q, b), vs) <- runWriterT $
runWeighted $
Weighted.hoist (writerT . withPartialRandomness us') mw
-- Acceptance ratio: min 1 ((q * length us) / (p * length vs)), with both
-- lengths wrapped as `Exp (log n)` and the result passed through `exp . ln`
-- to obtain a Double.
-- NOTE(review): the closing parentheses of this expression are not visible in
-- this excerpt.
65 | let ratio : Double = (exp . ln) $
min 1 ( (q * (Exp $
log $
cast (length us))) /
66 | (p * (Exp $
log $
cast (length vs)))
-- Decide acceptance with a Bernoulli draw using `ratio`.
68 | accept <- bernoulli ratio
-- On acceptance return the new trace (variables `vs`, output `b`, density `q`);
-- otherwise return the original trace `t` unchanged.
69 | pure $
if accept then MkTrace vs b q else t
-- Variant of `mhTrans` for a model whose `FreeSampler` is built over `Identity`.
-- The model is first hoisted so that each `Identity` step is unwrapped with
-- `runIdentity` and re-wrapped with `pure`; the result is passed to `mhTrans`.
73 | mhTrans' : MonadSample m => Weighted (FreeSampler Identity) a -> Trace a -> m (Trace a)
74 | mhTrans' model = mhTrans $
  Weighted.hoist (Free.hoist (pure . runIdentity)) model