0 | ||| Weighted is an instance of MonadCond. Apply a MonadSample transformer to
 1 | ||| obtain a MonadInfer that can execute probabilistic models.
 2 |
 3 | module Control.Monad.Bayes.Weighted
 4 |
 5 | import public Control.Monad.State
 6 |
 7 | import Control.Monad.Bayes.Interface
 8 |
 9 | import public Numeric.Log
10 |
11 | ||| Execute the program using the prior distribution, while accumulating likelihood.
12 | public export
13 | record Weighted (m : Type -> Type) (a : Type) where
14 |   constructor MkWeighted
15 |   runWeighted' : StateT (Log Double) m a 
16 |
17 | ||| Obtain an explicit value of the likelihood for a given value
18 | public export
19 | runWeighted : Weighted m a -> m (Log Double, a)
20 | runWeighted (MkWeighted m) = runStateT (toLogDomain 1.0) m 
21 |
22 | export
23 | MonadTrans Weighted where
24 |   lift = MkWeighted . lift
25 |
26 | export
27 | Functor m => Functor (Weighted m) where
28 |   map f (MkWeighted s) =  (MkWeighted $ map f s) 
29 |
30 | export
31 | Monad m => Applicative (Weighted m) where
32 |   pure x = MkWeighted (pure x)
33 |   (MkWeighted mf) <*> (MkWeighted ma) = MkWeighted (mf <*> ma) 
34 |   
35 | export
36 | Monad m => Monad (Weighted m) where
37 |   (MkWeighted mx) >>= k = MkWeighted (mx >>= (runWeighted' . k))
38 |
39 | export
40 | MonadSample m => MonadSample (Weighted m) where
41 |   random = lift random
42 |
43 | export
44 | Monad m => MonadCond (Weighted m) where
45 |   score w = MkWeighted (modify (* w))
46 |
47 | export
48 | MonadSample m => MonadInfer (Weighted m) where
49 |
50 | ||| Compute the sample and discard the weight.
51 | ||| This operation introduces bias.
52 | export
53 | prior : Functor m => Weighted m a -> m a
54 | prior = map snd . runWeighted
55 |
56 | ||| Compute the weight and discard the sample.
57 | export
58 | extractWeight : Functor m => Weighted m a -> m (Log Double)
59 | extractWeight = map fst . runWeighted
60 |
61 | ||| Embed a random variable with explicitly given likelihood
62 | export
63 | withWeight : Monad m => m (Log Double, a) -> Weighted m a
64 | withWeight m = MkWeighted $ do
65 |   (w, x) <- lift m
66 |   modify (* w)
67 |   pure x
68 |
69 | ||| Combine weights from two different levels.
70 | export
71 | flatten : Monad m => Weighted (Weighted m) a -> Weighted m a
72 | flatten m = withWeight $ (\(p, (q, x)) => (p * q, x)) <$> runWeighted (runWeighted m)
73 |
74 | ||| Use the weight as a factor in the transformed monad.
75 | export
76 | applyWeight : MonadCond m => Weighted m a -> m a
77 | applyWeight m = do
78 |   (w, x) <- runWeighted m
79 |   score w
80 |   pure x
81 |
82 | ||| Apply a transformation to the transformed monad.
83 | export
84 | hoist : (forall x. m x -> n x) -> Weighted m a -> Weighted n a
85 | hoist phi (MkWeighted w) =  MkWeighted $ mapStateT phi w
86 |