0 | module Control.Monad.Bayes.Free
 1 |
 2 | import Control.Monad.Identity
 3 | import Control.Monad.Trans
 4 | import Control.Monad.State
 5 | import Control.Monad.Writer
 6 |
 7 | import Control.Monad.Bayes.Interface
 8 | import Control.Monad.Free
 9 | import public Control.Monad.Trans.Free.Church
10 |
11 | ||| Random sampling functor
12 | public export
13 | data SamF a = Random (Double -> a)
14 |
15 | public export
16 | Functor SamF where
17 |   map f (Random k) = Random (f . k)
18 |
19 | ||| Free monad transformer over random sampling.
20 | ||| Uses the Church-encoded version of the free monad for efficiency.
21 | public export
22 | FreeSampler : (m : Type -> Type) -> (a : Type) -> Type
23 | FreeSampler = FT SamF
24 |
25 | export
26 | (Monad m, MonadFree SamF (FreeSampler m)) => MonadSample (FreeSampler m) where
27 |   random = liftF $ Random id
28 |
29 | ||| Hoist 'FreeSampler' through a monad transform.
30 | export
31 | hoist : (Monad m, Monad n) => (forall x. m x -> n x) -> FreeSampler m a -> FreeSampler n a
32 | hoist f m = hoistFT f m
33 |
34 | ||| Execute random sampling in the transformed monad.
35 | export
36 | interpret : MonadSample m => FreeSampler m a -> m a
37 | interpret c = iterT (\(Random k) => random >>= k) c
38 |
39 | ||| Execute computation with supplied values for random choices.
40 | export
41 | withRandomness : Monad m => List Double -> FreeSampler m a -> m a
42 | withRandomness randomness = evalStateT randomness . iterTM f
43 |   where f : MonadState (List Double) n => SamF (n b) -> n b
44 |         f (Random k) = do
45 |           xs <- the (n (List Double)) get
46 |           case xs of
47 |             []      => assert_total $ idris_crash ("withRandomness: randomness too short!")
48 |             y :: ys => put ys >> k y
49 |
50 | MonadTrans (\m => StateT (List Double) (WriterT (List Double) m)) where
51 |   lift = lift . lift
52 |
53 | ||| Execute computation with supplied values for a subset of random choices.
54 | ||| Return the output value and a record of all random choices used, whether
55 | ||| taken as input or drawn using the transformed monad.
56 | public export
57 | withPartialRandomness : MonadSample m => List Double -> FreeSampler m a -> m (a, List Double)
58 | withPartialRandomness randomness k = 
59 |   runWriterT $ 
60 |     evalStateT randomness $ 
61 |       iterTM {t = \m => StateT (List Double) (WriterT (List Double) m) } f k
62 |   where f : (MonadSample n, MonadWriter (List Double) n, MonadState (List Double) n) => SamF (n a) -> n a
63 |         f (Random k) = do
64 |           xs <- the (n (List Double)) get
65 |           x <- case xs of
66 |             []      => random
67 |             y :: ys => put ys >> pure y
68 |           tell [x]
69 |           k x
70 |
71 | ||| Like `withPartialRandomness`, but use an arbitrary sampling monad.
72 | public export
73 | runWith : MonadSample m => List Double -> FreeSampler Identity a -> m (a, List Double)
74 | runWith randomness = withPartialRandomness randomness . hoist {n=m} (pure . runIdentity)
75 |