0 | module Control.Monad.Bayes.Interface
2 | import Control.Monad.Maybe
3 | import Control.Monad.Reader
4 | import Control.Monad.RWS
5 | import Control.Monad.State
6 | import Control.Monad.Trans
7 | import Control.Monad.Writer
8 | import public Data.List
9 | import public Data.Vect
10 | import System.Random
12 | import public Statistics.Distribution
14 | import public Numeric.Log
-- NOTE(review): this listing skips gutter lines (e.g. 20-23, 49-50, 68, 83,
-- 88-89); the declaration of `random` and parts of several default bodies
-- below are not visible here.
--
-- Monads that can draw random samples. Every default below is built from
-- `random` (an `m Double`, presumably uniform on [0, 1) -- TODO confirm
-- against its declaration), so an instance only has to supply `random`.
19 | interface Monad m => MonadSample m where
-- Continuous distributions: apply the GSL quantile function (inverse CDF)
-- to a `random` draw.
24 | uniform : (min, max : Double) -> m Double
25 | uniform min max = map (gsl_uniform_cdf_inv min max) random
28 | normal : (mean, sd : Double) -> m Double
29 | normal m s = map (gsl_normal_cdf_inv m s) random
32 | gamma : (a, b : Double) -> m Double
33 | gamma a b = map (gsl_gamma_cdf_inv a b) random
36 | beta : (a, b : Double) -> m Double
37 | beta a b = map (gsl_beta_cdf_inv a b) random
-- True exactly when the `random` draw is below p.
40 | bernoulli : (p : Double) -> m Bool
41 | bernoulli p = map (< p) random
-- Number of True results among n independent `bernoulli p` draws.
44 | binomial : (n : Nat) -> (p : Double) -> m Nat
45 | binomial n p = (pure . length . List.filter (== True)) !(sequence . replicate n $
bernoulli p)
-- Picks an index with probability proportional to its weight: weights are
-- normalised by their sum, then the running sum is scanned until it exceeds
-- `r` (bound on a line missing from this listing, presumably a `random`
-- draw). Crashes with "categorical: bad weights!" when no index qualifies,
-- e.g. for an empty vector or all-zero / non-finite weights (they normalise
-- to NaN and `NaN > r` is False).
48 | categorical : {n : Nat} -> Vect n Double -> m (Fin n)
51 | let normalised_ps = map (/(sum ps)) ps
53 | cmf : Double -> Nat -> List Double -> Maybe (Fin n)
54 | cmf acc idx (x :: xs) = let acc' = acc + x
55 | in if acc' > r then natToFin idx n else cmf acc' (S idx) xs
56 | cmf acc idx [] = Nothing
58 | case cmf 0 0 (toList normalised_ps) of
60 | Nothing => assert_total $
idris_crash $
"categorical: bad weights!" ++ show ps
-- Like `categorical`, but the weights are given in log space. The largest
-- log-weight is subtracted before exponentiating. This leaves the
-- normalised probabilities unchanged, but very negative log-weights no longer
-- underflow to 0 and very large ones no longer overflow to Infinity (either
-- used to make `categorical` crash).
63 | logCategorical : {n : _} -> Vect n (Log Double) -> m (Fin n)
64 | logCategorical logps =
let lws = map ln logps
    top = foldl max (negate (1 / 0)) lws
in categorical (map (\lw => exp (lw - top)) lws)
-- Picks one element of a non-empty vector uniformly at random. Each of the
-- S n entries gets weight 1 / (S n). The previous weight 1 / n was infinite
-- for a singleton vector (n = 0), normalised to NaN and always crashed
-- `categorical`.
67 | uniformD : {n : Nat} -> Vect (S n) a -> m a
69 | idx <- categorical $
replicate (S n) (1 / cast (S n))
-- Draws one Gamma(a_i, 1) variate per concentration parameter; the rest of
-- the definition is not visible in this listing (presumably it normalises
-- the draws to sum to 1 -- TODO confirm).
73 | dirichlet : Vect n Double -> m (Vect n Double)
75 | xs <- sequence $
map (`gamma` 1) as
-- floor (range * r) as a Nat, with r bound on a line missing from this
-- listing (presumably a `random` draw, which would give 0 .. range - 1).
81 | discreteUniform : (range : Nat) -> m Nat
82 | discreteUniform range = do
84 | pure $
cast (floor (cast range * r))
-- Samples a Nat from a probability mass function by sequential search.
-- At n, accept with probability prob / marginal (marginal = mass not yet
-- consumed); otherwise continue at n + 1 with prob removed. Crashes once the
-- remaining mass goes negative, i.e. the PMF sums above 1.
-- NOTE(review): prob and the initial marginal are bound on missing lines
-- (presumably `pmf n` and 1). If the PMF sums *below* 1, the search may never
-- accept and would not terminate -- confirm intended.
87 | fromPMF : (pmf : Nat -> Double) -> m Nat
90 | f : Nat -> Double -> m Nat
91 | f n marginal with (marginal < 0)
94 | b <- bernoulli (prob / marginal)
95 | if b then pure n else assert_total f (n + 1) (marginal - prob)
96 | _ | True = assert_total $
idris_crash "fromPMF: total PMF above 1"
-- Discrete distributions sampled through `fromPMF` with the matching GSL pdf.
99 | geometric : (p : Double) -> m Nat
100 | geometric p = fromPMF (gsl_geometric_pdf p)
103 | hypergeometric : (n1, n2, t : Nat) -> m Nat
104 | hypergeometric n1 n2 t = fromPMF (gsl_hypergeometric_pdf n1 n2 t)
107 | poisson : (p : Double) -> m Nat
108 | poisson p = fromPMF (gsl_poisson_pdf p)
||| Monads that can weight the current execution. The weight is a
||| `Log Double`; `condition` uses 1 to keep an execution and 0 to rule it out.
interface Monad m => MonadCond (m : Type -> Type) where
  ||| Attach the given weight to the current execution.
  score : Log Double -> m ()
||| Hard conditioning: scores the current execution with weight 1 when the
||| condition holds and with weight 0 when it does not.
condition : MonadCond m => Bool -> m ()
condition True  = score 1
condition False = score 0
||| Monads that support both sampling (`MonadSample`) and weighting
||| (`MonadCond`); adds no methods of its own.
interface (MonadSample m, MonadCond m) => MonadInfer (m : Type -> Type) where
-- MaybeT: sampling and scoring are forwarded unchanged to the base monad.
MonadSample m => MonadSample (MaybeT m) where
  random = lift random
  bernoulli = lift . bernoulli
  -- Forward categorical as the WriterT/StateT instances do, so a base monad
  -- that overrides it is used instead of the `random`-based default.
  categorical = lift . categorical

MonadCond m => MonadCond (MaybeT m) where
  score = lift . score

MonadInfer m => MonadInfer (MaybeT m) where
-- ReaderT: sampling and scoring are forwarded unchanged to the base monad.
MonadSample m => MonadSample (ReaderT r m) where
  random = lift random
  bernoulli = lift . bernoulli
  -- Forward categorical as the WriterT/StateT instances do, so a base monad
  -- that overrides it is used instead of the `random`-based default.
  categorical = lift . categorical

MonadCond m => MonadCond (ReaderT r m) where
  score = lift . score

MonadInfer m => MonadInfer (ReaderT r m) where
-- WriterT: each sampling primitive and `score` runs in the base monad and is
-- lifted through the writer layer.
MonadSample m => MonadSample (WriterT w m) where
  random = lift random
  bernoulli p = lift (bernoulli p)
  categorical ps = lift (categorical ps)

MonadCond m => MonadCond (WriterT w m) where
  score weight = lift (score weight)

MonadInfer m => MonadInfer (WriterT w m) where
-- StateT: each sampling primitive and `score` runs in the base monad and is
-- lifted through the state layer.
MonadSample m => MonadSample (StateT s m) where
  random = lift random
  bernoulli p = lift (bernoulli p)
  categorical ps = lift (categorical ps)

MonadCond m => MonadCond (StateT s m) where
  score weight = lift (score weight)

MonadInfer m => MonadInfer (StateT s m) where
-- RWST: sampling and scoring are forwarded unchanged to the base monad.
MonadSample m => MonadSample (RWST r w s m) where
  random = lift random
  bernoulli = lift . bernoulli
  -- Forward categorical as the WriterT/StateT instances do, so a base monad
  -- that overrides it is used instead of the `random`-based default.
  categorical = lift . categorical

MonadCond m => MonadCond (RWST r w s m) where
  score = lift . score

MonadInfer m => MonadInfer (RWST r w s m) where