0 | module Control.Monad.Bayes.Sequential
 1 |
 2 | import Control.Monad.Trans
 3 | import Control.Monad.Bayes.Interface
 4 |
 5 | ||| Represents a computation that can be suspended at certain points.
 6 | -- The intermediate monadic effects can be extracted, which is particularly
 7 | -- useful for implementation of Sequential Monte Carlo related methods.
 8 | -- All the probabilistic effects are lifted from the transformed monad, but
 9 | -- also `suspend` is inserted after each `score`.
10 | export
11 | data Sequential : (m : Type -> Type) -> (a : Type) -> Type where
12 |   MkSeq : Inf (m (Either a (Sequential m a))) -> Sequential m a
13 |
14 | export
15 | runSeq : Sequential m a -> m (Either a (Sequential m a))
16 | runSeq (MkSeq m) = m
17 |
18 | mutual
19 |   export
20 |   Monad m => Functor (Sequential m) where
21 |     map f (MkSeq mx) = MkSeq (do
22 |       x <- mx
23 |       case x of Left l  => pure (Left $ f l)
24 |                 Right r => pure (Right $ map f r))
25 |   export
26 |   Monad m => Applicative (Sequential m) where
27 |     pure x = MkSeq (pure (Left x))
28 |     mf <*> mv = do
29 |       f' <- mf
30 |       v' <- mv
31 |       pure $ f' v'
32 |
33 |   export covering
34 |   Monad m => Monad (Sequential m) where
35 |     (>>=) (MkSeq mx) f = MkSeq $ do
36 |       x <- mx
37 |       case x of
38 |         Left l    => runSeq (f l)
39 |         Right seq => pure (Right ((assert_total (>>=)) seq  f))
40 |
41 |   export
42 |   MonadTrans Sequential where
43 |     lift mx = MkSeq (map Left mx)
44 |
45 | export
46 | suspend : Monad m => Sequential m ()
47 | suspend = MkSeq (pure (Right (pure ())))
48 |
49 | export
50 | MonadSample m => MonadSample (Sequential m) where
51 |   random = lift random
52 |
53 | ||| Execution is suspended after each 'score'.
54 | export
55 | MonadCond m => MonadCond (Sequential m) where
56 |   score w = lift (score w) >> suspend
57 |
58 | export
59 | MonadInfer m => MonadInfer (Sequential m) where
60 |
61 | export
62 | ||| Execute to the next suspension point. If the computation is finished, do nothing.
63 | advance : Monad m => Sequential m a -> Sequential m a
64 | advance (MkSeq m) = MkSeq (m >>= either ( pure . Left ) runSeq )
65 |
66 | export
67 | ||| Remove the remaining suspension points.
68 | finish : Monad m => Sequential m a -> m a
69 | finish (MkSeq m) = (m >>= either pure finish)
70 |
71 | export
72 | ||| Transform the inner monad. This operation only applies to computation up to the first suspension.
73 | hoistFirst : (forall x. m x -> m x) -> Sequential m a -> Sequential m a
74 | hoistFirst tau (MkSeq m) = MkSeq (tau m)
75 |
76 | ||| Apply a function a given number of times.
77 | export
78 | composeCopies : Nat -> (a -> a) -> (a -> a)
79 | composeCopies k f = foldr (.) id (List.replicate k f)
80 |
81 | ||| Sequential importance sampling. Applies a given transformation after each time step.
82 | export
83 | sis :
84 |   Monad m =>
85 |   -- | transformation
86 |   (forall x. m x -> m x) ->
87 |   -- | number of time steps
88 |   (n_timesteps : Nat) ->
89 |   Sequential m a ->
90 |   m a
91 | sis f k = finish . composeCopies k (advance . hoistFirst f)
92 |