0 | module Control.Monad.Bayes.Sequential
2 | import Control.Monad.Trans
3 | import Control.Monad.Bayes.Interface
||| A computation in the monad `m` that runs one step at a time.
|||
||| Running a step (inside `m`) produces either the final result (`Left`) or
||| the remainder of the computation (`Right`), which can be resumed later.
||| The step is wrapped in `Inf`, so it is delayed until it is forced.
data Sequential : (Type -> Type) -> Type -> Type where
  ||| Wrap the (delayed) next step of a computation.
  MkSeq : (step : Inf (m (Either a (Sequential m a)))) -> Sequential m a
||| Force and return the next step of a sequential computation: inside `m`,
||| either the final result (`Left`) or the rest of the computation (`Right`).
runSeq : Sequential m a -> m (Either a (Sequential m a))
runSeq computation = case computation of
  MkSeq step => step
20 | Monad m => Functor (Sequential m) where
21 | map f (MkSeq mx) = MkSeq (do
23 | case x of Left l => pure (Left $
f l)
24 | Right r => pure (Right $
map f r))
26 | Monad m => Applicative (Sequential m) where
27 | pure x = MkSeq (pure (Left x))
34 | Monad m => Monad (Sequential m) where
35 | (>>=) (MkSeq mx) f = MkSeq $
do
38 | Left l => runSeq (f l)
39 | Right seq => pure (Right ((assert_total (>>=)) seq f))
-- Lifting an `m` action gives a single-step computation: the action's result
-- is tagged `Left` (finished), so a lifted action never suspends.
MonadTrans Sequential where
  lift action = MkSeq (Left <$> action)
||| Pause the computation exactly once.
|||
||| The first step yields `Right` with a continuation; that continuation's own
||| first step immediately finishes with `()`.
suspend : Monad m => Sequential m ()
suspend = MkSeq (pure (Right (MkSeq (pure (Left ())))))
-- Sampling is delegated to the underlying monad and introduces no suspension
-- point: the sampled value is returned as a finished (`Left`) step. This is
-- `lift random` with the `lift` of `MonadTrans Sequential` unfolded.
MonadSample m => MonadSample (Sequential m) where
  random = MkSeq (Left <$> random)
-- Every conditioning statement is applied in the underlying monad and is then
-- followed by a `suspend`, so each `score` marks a step boundary.
MonadCond m => MonadCond (Sequential m) where
  score w = do
    lift (score w)
    suspend
-- `Sequential m` is an inference monad whenever `m` is.
-- NOTE(review): no methods are defined here; presumably `MonadInfer` only
-- bundles `MonadSample` and `MonadCond` (implemented above) — confirm against
-- Control.Monad.Bayes.Interface.
MonadInfer m => MonadInfer (Sequential m) where
||| Consume one suspension point.
|||
||| Runs the first step; if it has already finished, the result is kept as is.
||| If it suspended, the following step is run immediately as part of the
||| same `m` action, so the returned computation starts one step further on.
advance : Monad m => Sequential m a -> Sequential m a
advance (MkSeq step) = MkSeq $ do
  next <- step
  case next of
    Left done => pure (Left done)
    Right rest => runSeq rest
||| Run a sequential computation to completion in `m`, resuming after every
||| suspension. This does not terminate if the computation keeps suspending.
finish : Monad m => Sequential m a -> m a
finish (MkSeq step) = do
  next <- step
  case next of
    Left result => pure result
    Right rest => finish rest
||| Apply the polymorphic transformation `tau` to the first step only.
|||
||| The continuation returned by that step (under `Right`) is left untouched,
||| so later steps still run in the original `m`.
hoistFirst : (forall x. m x -> m x) -> Sequential m a -> Sequential m a
hoistFirst tau computation = MkSeq (tau (runSeq computation))
||| Compose `f` with itself `k` times, so `composeCopies k f x` applies `f` to
||| `x` exactly `k` times; `composeCopies 0 f` is `id`.
composeCopies : Nat -> (a -> a) -> (a -> a)
composeCopies Z _ = id
composeCopies (S k) f = f . composeCopies k f
86 | (forall x. m x -> m x) ->
88 | (n_timesteps : Nat) ->
91 | sis f k = finish . composeCopies k (advance . hoistFirst f)