3 | module Control.Monad.Bayes.Weighted
5 | import public Control.Monad.State
7 | import Control.Monad.Bayes.Interface
9 | import public Numeric.Log
||| Weighting transformer: a computation in `m` that also threads an
||| accumulated weight of type `Log Double` as `StateT` state.
record Weighted (m : Type -> Type) (a : Type) where
  constructor MkWeighted
  ||| The underlying state computation whose state is the running weight.
  runWeighted' : StateT (Log Double) m a
||| Execute a weighted computation.
|||
||| The weight state is seeded with `toLogDomain 1.0` (presumably the
||| multiplicative identity in the log domain). The result pairs the final
||| accumulated weight with the computed value.
runWeighted : Weighted m a -> m (Log Double, a)
runWeighted w = runStateT (toLogDomain 1.0) (runWeighted' w)
||| Lift an inner action through the underlying `StateT` without touching
||| the weight.
MonadTrans Weighted where
  lift act = MkWeighted (lift act)
||| Map over the result; the weight is carried along unchanged.
Functor m => Functor (Weighted m) where
  map f w = MkWeighted (f <$> runWeighted' w)
||| Applicative structure inherited from the underlying `StateT`.
Monad m => Applicative (Weighted m) where
  pure = MkWeighted . pure
  wf <*> wx = MkWeighted (runWeighted' wf <*> runWeighted' wx)
||| Sequencing: run the first computation, then unwrap the continuation's
||| result so both steps share one weight state.
Monad m => Monad (Weighted m) where
  w >>= k = MkWeighted $ runWeighted' w >>= \x => runWeighted' (k x)
||| Randomness comes from the base sampler, lifted through `StateT`.
MonadSample m => MonadSample (Weighted m) where
  random = MkWeighted (lift random)
||| Conditioning multiplies the running weight by the supplied score.
Monad m => MonadCond (Weighted m) where
  score factor = MkWeighted $ modify (\acc => acc * factor)
||| `Weighted m` has both a `MonadSample` and a `MonadCond` implementation,
||| so it is declared a `MonadInfer` whenever the base monad can sample.
||| NOTE(review): presumably `MonadInfer` has no methods of its own, since
||| no body is given here; confirm against Control.Monad.Bayes.Interface.
MonadSample m => MonadInfer (Weighted m) where
||| Run a weighted computation and discard the accumulated weight.
prior : Functor m => Weighted m a -> m a
prior w = snd <$> runWeighted w
||| Run a weighted computation and keep only the accumulated weight.
extractWeight : Functor m => Weighted m a -> m (Log Double)
extractWeight w = fst <$> runWeighted w
63 | withWeight : Monad m => m (Log Double, a) -> Weighted m a
64 | withWeight m = MkWeighted $
do
||| Collapse two stacked weighting layers into one by multiplying their
||| weights.
flatten : Monad m => Weighted (Weighted m) a -> Weighted m a
flatten nested =
  -- Running twice peels both layers: the first component is the weight of
  -- the inner `Weighted m` layer, the nested pair holds the outer layer's
  -- weight and the value.
  let peeled = runWeighted (runWeighted nested)
  in withWeight (map (\(wInner, (wOuter, x)) => (wInner * wOuter, x)) peeled)
76 | applyWeight : MonadCond m => Weighted m a -> m a
78 | (w, x) <- runWeighted m
||| Change the base monad by applying the polymorphic transformation `phi`
||| to the underlying state computation via `mapStateT`.
hoist : (forall x. m x -> n x) -> Weighted m a -> Weighted n a
hoist phi w = MkWeighted (mapStateT phi (runWeighted' w))