0 | module Control.Monad.Bayes.Interface
  1 |
  2 | import Control.Monad.Maybe
  3 | import Control.Monad.Reader
  4 | import Control.Monad.RWS
  5 | import Control.Monad.State
  6 | import Control.Monad.Trans
  7 | import Control.Monad.Writer
  8 | import public Data.List
  9 | import public Data.Vect
 10 | import System.Random
 11 |
 12 | import public Statistics.Distribution
 13 |
 14 | import public Numeric.Log
 15 |
 16 | %default total
 17 |
 18 | public export
 19 | interface Monad m => MonadSample m where
 20 |   ||| Draws a random value from Uniform(0,1)
 21 |   random : m Double
 22 |
 23 |   ||| Uniform(min, max)
 24 |   uniform : (min, max : Double) -> m Double
 25 |   uniform min max = map (gsl_uniform_cdf_inv min max) random
 26 |
 27 |   ||| Normal(mean, sd)
 28 |   normal : (mean, sd : Double) -> m Double
 29 |   normal m s      = map (gsl_normal_cdf_inv m s) random
 30 |
 31 |   ||| Gamma(shape, scale) -> m Double
 32 |   gamma : (a, b : Double) -> m Double
 33 |   gamma a b       = map (gsl_gamma_cdf_inv a b) random
 34 |
 35 |   ||| Beta(alpha, beta) -> m Double
 36 |   beta : (a, b : Double) -> m Double
 37 |   beta a b        = map (gsl_beta_cdf_inv a b) random
 38 |
 39 |   ||| Bernoulli(prob)
 40 |   bernoulli : (p : Double) -> m Bool
 41 |   bernoulli p     = map (< p) random
 42 |
 43 |   ||| Binomial(num trials, prob of each trial)
 44 |   binomial : (n : Nat) -> (p : Double) -> m Nat
 45 |   binomial n p = (pure . length . List.filter (== True)) !(sequence . replicate n $ bernoulli p)
 46 |
 47 |   ||| Categorical(probs)
 48 |   categorical : {n : Nat} -> Vect n Double -> m (Fin n)
 49 |   categorical ps = do
 50 |     r <- random
 51 |     let normalised_ps = map (/(sum ps)) ps
 52 |
 53 |         cmf : Double -> Nat -> List Double -> Maybe (Fin n)
 54 |         cmf acc idx (x :: xs) = let acc' = acc + x
 55 |                                 in  if acc' > r then natToFin idx n else cmf acc' (S idx) xs
 56 |         cmf acc idx []        = Nothing
 57 |
 58 |     case cmf 0 0 (toList normalised_ps) of
 59 |       Just i  => pure i
 60 |       Nothing => assert_total $ idris_crash $ "categorical: bad weights!" ++ show ps
 61 |
 62 |   ||| Log-categorical(log-probs)
 63 |   logCategorical : {n : _} -> Vect n (Log Double) -> m (Fin n)
 64 |   logCategorical logps = categorical (map (exp . ln) logps)
 65 |
 66 |   ||| Uniform-Discrete(values)
 67 |   uniformD : {n : Nat} -> Vect (S n) a -> m a
 68 |   uniformD xs = do
 69 |     idx <- categorical $ replicate (S n) (1 / cast n)
 70 |     pure (index idx xs)
 71 |
 72 |   ||| Dirichlet(concentrations)
 73 |   dirichlet : Vect n Double -> m (Vect n Double)
 74 |   dirichlet as = do
 75 |     xs <- sequence $ map (`gamma` 1) as
 76 |     let s  = sum xs
 77 |         ys = map (/ s) xs
 78 |     pure ys
 79 |
 80 |   ||| DiscUniform(range); should return Nat from 0 to (range - 1)
 81 |   discreteUniform : (range : Nat) -> m Nat
 82 |   discreteUniform range = do
 83 |         r <- random
 84 |         pure $ cast (floor (cast range * r))
 85 |
 86 |   ||| Draw from a discrete distribution using the probability mass function and a sequence of draws from Bernoulli.
 87 |   fromPMF : (pmf : Nat -> Double) -> m Nat
 88 |   fromPMF pmf = f 0 1
 89 |     where
 90 |       f : Nat -> Double -> m Nat
 91 |       f n marginal with (marginal < 0)
 92 |        _ | False = do
 93 |                 let prob = pmf n
 94 |                 b <- bernoulli (prob / marginal)
 95 |                 if b then pure n else assert_total f (n + 1) (marginal - prob)
 96 |        _ | True = assert_total $ idris_crash "fromPMF: total PMF above 1"
 97 |
 98 |   ||| Geometric(prob)
 99 |   geometric : (p : Double) -> m Nat
100 |   geometric p = fromPMF (gsl_geometric_pdf p)
101 |
102 |   ||| Hypergeometric(num elements of "type 1", num elements of "type 2", num samples)
103 |   hypergeometric : (n1, n2, t : Nat) -> m Nat
104 |   hypergeometric n1 n2 t = fromPMF (gsl_hypergeometric_pdf n1 n2 t)
105 |
106 |   ||| Poisson(λ)
107 |   poisson : (p : Double) -> m Nat
108 |   poisson p = fromPMF (gsl_poisson_pdf p)
109 |
110 | public export
111 | interface Monad m => MonadCond m where
112 |   ||| Record a likelihood. Note: when calling `score (Exp p)`, p must already be in the log-domain.
113 |   score : Log Double -> m ()
114 |
115 | export
116 | condition : MonadCond m => Bool -> m ()
117 | condition b = score $ if b then 1 else 0
118 |
119 | public export
120 | interface (MonadSample mMonadCond m) => MonadInfer m where
121 |
122 | -- MaybeT
123 | export
124 | MonadSample m => MonadSample (MaybeT m) where
125 |   random = lift random
126 |   bernoulli = lift . bernoulli
127 | export
128 | MonadCond m => MonadCond (MaybeT m) where
129 |   score = lift . score
130 | export
131 | MonadInfer m => MonadInfer (MaybeT m) where
132 |
133 | -- ReaderT
134 | export
135 | MonadSample m => MonadSample (ReaderT r m) where
136 |   random = lift random
137 |   bernoulli = lift . bernoulli
138 | export
139 | MonadCond m => MonadCond (ReaderT r m) where
140 |   score = lift . score
141 | export
142 | MonadInfer m => MonadInfer (ReaderT r m) where
143 |
144 | -- WriterT
145 | export
146 | MonadSample m => MonadSample (WriterT w m) where
147 |   random = lift random
148 |   bernoulli = lift . bernoulli
149 |   categorical = lift . categorical
150 | export
151 | MonadCond m => MonadCond (WriterT w m) where
152 |   score = lift . score
153 | export
154 | MonadInfer m => MonadInfer (WriterT w m) where
155 |
156 | -- StateT
157 | export
158 | MonadSample m => MonadSample (StateT s m) where
159 |   random = lift random
160 |   bernoulli = lift . bernoulli
161 |   categorical = lift . categorical
162 | export
163 | MonadCond m => MonadCond (StateT s m) where
164 |   score = lift . score
165 | export
166 | MonadInfer m => MonadInfer (StateT s m) where
167 |
168 | -- RWST
169 | export
170 | MonadSample m => MonadSample (RWST r w s m) where
171 |   random = lift random
172 |   bernoulli = lift . bernoulli
173 | export
174 | MonadCond m => MonadCond (RWST r w s m) where
175 |   score = lift . score
176 | export
177 | MonadInfer m => MonadInfer (RWST r w s m) where
178 |