0 | module Control.Monad.Bayes.Population
2 | import public Control.Monad.Bayes.Weighted
3 | import Control.Monad.Bayes.Interface
4 | import Control.Monad.Bayes.Sampler
5 | import Control.Monad.Trans
-- List monad transformer: a single `m` action (field `runListT`) that yields
-- the whole list of results at once.
-- NOTE(review): original line 13 is missing from this view; the code below
-- uses `MkListT`, presumably the record's default constructor — confirm.
12 | record ListT (m : Type -> Type) (a : Type) where
14 | runListT : m (List a)
||| Rewrap a `ListT` after transforming the list-producing action it wraps.
|||
||| @ f transformation applied to the underlying `m (List a)` action
mapListT : (m (List a) -> n (List b)) -> ListT m a -> ListT n b
mapListT f = MkListT . f . runListT
-- Mapping over `ListT m` applies the function to every element of the inner
-- list, underneath the base functor `m`.
Functor m => Functor (ListT m) where
  map f (MkListT xs) = MkListT (map (map f) xs)
-- `pure` yields a one-element list. `<*>` combines the two base actions with
-- `m`'s `<*>` and merges their result lists with the list `Applicative`
-- (every function applied to every argument).
Applicative m => Applicative (ListT m) where
  pure x = MkListT (pure (x :: Nil))
  (MkListT mfs) <*> (MkListT mxs) = MkListT (map (<*>) mfs <*> mxs)
-- Monadic bind for `ListT`: runs the continuation `k` on every element of the
-- list `a` (presumably obtained from `runListT m` on a line missing from this
-- view) and sequences the resulting base actions into a list of lists `b`.
-- NOTE(review): original lines 31 and 33-35 are missing here; presumably the
-- body finishes by concatenating `b` — confirm against the full file.
-- NOTE(review): `sequence . map f` is equivalent to `traverse f`.
-- NOTE(review): this "list inside m" construction may only satisfy the monad
-- laws when `m` is commutative — confirm this is acceptable for its uses.
29 | Monad m => Monad (ListT m) where
30 | m >>= k = MkListT $
do
32 | b <- (sequence . map (runListT . k)) a
-- Lifting a base action wraps its single result in a one-element list.
MonadTrans ListT where
  lift action = MkListT (map (\x => [x]) action)
-- Sampling in `ListT m` delegates every primitive to the base monad `m`.
MonadSample m => MonadSample (ListT m) where
  random = lift random
  bernoulli p = lift (bernoulli p)
  categorical ws = lift (categorical ws)
-- Conditioning in `ListT m` forwards the score to the base monad `m`.
MonadCond m => MonadCond (ListT m) where
  score w = lift (score w)
-- `ListT m` supports inference whenever `m` does; no methods are defined in
-- the visible part of this implementation.
50 | MonadInfer m => MonadInfer (ListT m) where
-- A population of weighted particles: a `Weighted` computation layered over
-- the list transformer `ListT m`. Running it (`runPopulation`) yields every
-- particle's value paired with its log-domain weight.
54 | record Population (m : Type -> Type) (a : Type) where
55 | constructor MkPopulation
56 | runPopulation' : Weighted (ListT m) a
-- Mapping over a population maps over its underlying weighted computation.
Functor m => Functor (Population m) where
  map f = MkPopulation . map f . runPopulation'
-- Applicative structure is inherited from the wrapped `Weighted (ListT m)`.
Monad m => Applicative (Population m) where
  pure x = MkPopulation (pure x)
  pf <*> px = MkPopulation (runPopulation' pf <*> runPopulation' px)
-- Bind threads the continuation through the wrapped weighted computation.
Monad m => Monad (Population m) where
  pop >>= k = MkPopulation $ do
    x <- runPopulation' pop
    runPopulation' (k x)
-- Lift a base action through both layers: first into `ListT`, then into
-- `Weighted`.
MonadTrans Population where
  lift action = MkPopulation (lift (lift action))
-- Sampling in a population delegates every primitive to the base monad `m`.
-- This mirrors the `ListT` implementation, so any specialised `bernoulli`
-- or `categorical` that `m` provides is used, rather than the interface
-- defaults that would otherwise be rebuilt on top of `random`.
MonadSample m => MonadSample (Population m) where
  random = lift random
  bernoulli p = lift (bernoulli p)
  categorical ws = lift (categorical ws)
-- Conditioning is absorbed by the population's own `Weighted` layer, so only
-- `Monad m` is required of the base monad (not `MonadCond m`).
Monad m => MonadCond (Population m) where
  score = MkPopulation . score
-- A population is an inference monad whenever the base monad can sample.
-- Only `MonadSample m` is required, because conditioning is handled by the
-- population itself (its `MonadCond` implementation needs only `Monad m`).
84 | MonadSample m => MonadInfer (Population m) where
||| Run a population, returning every particle paired with its log-domain
||| weight.
runPopulation : Population m a -> m (List (Log Double, a))
runPopulation = runListT . runWeighted . runPopulation'
||| Like `runPopulation`, but with each weight converted out of the log
||| domain.
explicitPopulation : Functor m => Population m a -> m (List (Double, a))
explicitPopulation pop = map (map toLinear) (runPopulation pop)
  where
    toLinear : (Log Double, b) -> (Double, b)
    toLinear (lw, x) = (fromLogDomain lw, x)
||| Build a population from an action producing weighted particles
||| (log-domain weight paired with value).
fromWeightedList : Monad m => m (List (Log Double, a)) -> Population m a
fromWeightedList particles = MkPopulation (withWeight (MkListT particles))
-- Change the base monad of a population using `f : forall x. m1 x -> m2 x`:
-- run the population, apply `f` to the resulting weighted-list action, and
-- rebuild a population from the result.
-- NOTE(review): the start of the type signature and original lines 106-107
-- are missing from this view — confirm the full type.
105 | (forall x. m1 x -> m2 x) ->
108 | hoist f = fromWeightedList . f . runPopulation
||| Produce `n` particles of unit value, each carrying weight `1 / n`.
|||
||| When sequenced after an existing population, this presumably splits
||| every particle into `n` equally weighted copies.
spawn : (isMonad : Monad m) => Nat -> Population m ()
spawn n =
  let share : Log Double = toLogDomain (1.0 / cast n)
  in fromWeightedList (pure (replicate n (share, ())))
-- Generic resampling driven by an index-selection strategy `resampler`.
-- The population is run and split into log-weights and values, and the total
-- weight `z` is computed. When `z` is positive:
--   * the weights are normalised (each divided by `z`);
--   * `resampler` chooses ancestor indices from them;
--   * the chosen values are copied;
--   * every offspring receives the weight `z / length particles`.
-- If the resampler returns one index per particle, the total weight stays `z`.
-- NOTE(review): the start of the signature and the `then`/`else` lines
-- (130, 137-138) are missing from this view; confirm the behaviour when
-- `z <= 0`.
-- NOTE(review): `toLogDomain $ length particles` passes a `Nat`; check
-- whether an explicit `cast` to `Double` is needed.
121 | ({k : Nat} -> Vect k Double -> m (List (Fin k))) ->
124 | resampleGeneric resampler pop = fromWeightedList $
do
125 | particles <- runPopulation pop
126 | let (log_ws, xs) : (Vect (length particles) (Log Double), Vect (length particles) a)
127 | = unzip (fromList particles)
128 | z : Log Double = Numeric.Log.sum log_ws
129 | if fromLogDomain z > 0
131 | let weights : Vect (length particles) Double
132 | = map (fromLogDomain . (/ z)) log_ws
133 | ancestors <- resampler weights
134 | let offsprings : List a
135 | = map (\idx => index idx xs) ancestors
136 | pure $
map (z / (toLogDomain $
length particles), ) offsprings
-- Systematic resampling. Takes one uniform draw `u` and the weights `ws`
-- (presumably normalised — confirm). It places `S k` points starting at
-- `u / (S k)` with spacing `inc = 1 / (S k)`, and walks them against the
-- running cumulative weight `q`:
--   * it records `j - 1` for a point (the `then` branch);
--   * otherwise it advances `j` and adds the next weight to `q`.
-- Indices are consed onto `acc`, so the result is in reverse point order.
-- `prob` falls back to the last weight when `natToFin` fails, presumably to
-- guard against floating-point overrun.
-- Out-of-range Nat indices are mapped to `FZ` by `fromMaybe FZ`.
-- Note that `minus j 1` truncates at 0.
-- NOTE(review): original lines 146, 150-151, 153, 155, 157, 160 and the end
-- of the definition are missing from this view (including the `if`
-- condition of the `then`/`else` on lines 158-159); confirm against the
-- full file.
143 | systematic : {n : Nat} -> Double -> Vect n Double -> List (Fin n)
144 | systematic {n = Z} u Nil = Nil
145 | systematic {n = S k} u ws =
147 | prob : Maybe (Fin (S k)) -> Double
148 | prob (Just idx) = index idx ws
149 | prob Nothing = index last ws
152 | inc = 1 / cast (S k)
154 | f : Nat -> Double -> Nat -> Double -> List Nat -> List Nat
156 | if i == S k then acc else
158 | then f (S i) (v + inc) j q ((minus j 1) :: acc)
159 | else f i v (S j) (q + prob (natToFin j (S k))) acc
161 | particle_idxs : List (Fin (S k))
162 | particle_idxs = map (\nat => fromMaybe FZ (natToFin nat (S k)))
163 | (f Z (u / cast (S k)) Z 0.0 [])
-- Systematic resampling of a population. A single `random` draw is fed,
-- together with the weights, to `systematic` through `resampleGeneric`.
-- NOTE(review): the body of the type signature (original lines 171-173) is
-- missing from this view.
170 | resampleSystematic :
174 | resampleSystematic = resampleGeneric (\ws => (`systematic` ws) <$> random)
-- Stratified resampling. Draws one independent `uniform 0.0 1.0` dither per
-- stratum and places point i at `(i + dither_i) / (S k)`. It then walks those
-- points against the cumulative weights `scanl (+) 0.0 weights` with the
-- unfold `coalg`:
--   * when point i lies below cumulativeSum[j], it emits `j` and advances `i`;
--   * otherwise it advances `j`;
--   * it stops when either index runs out of range.
-- Emitted cumulative-sum indices are shifted down by one (`fin_pred`) to give
-- particle indices.
-- NOTE(review): several original lines are missing from this view,
-- including the `let` layout, any `FZ` case of `fin_pred`, and the final
-- `pure particle_idxs`. As shown, `fin_pred` is partial on `FZ` — confirm.
180 | stratified : MonadSample m => {n : Nat} -> Vect n Double -> m (List (Fin n))
181 | stratified {n = Z} Nil = pure Nil
182 | stratified {n = S k} weights = do
183 | dithers <- sequence (Vect.replicate (S k) (uniform 0.0 1.0))
185 | let vect_range : (n : Nat) -> Vect (S n) Nat
186 | vect_range n = reverse (range_bwd n)
187 | where range_bwd : (n : Nat) -> Vect (S n) Nat
188 | range_bwd (S k) = (S k) :: range_bwd k
189 | range_bwd Z = Z :: Nil
191 | positions : Vect (S k) Double
192 | positions = map (/cast (S k)) $
zipWith (+) dithers (map cast (vect_range k))
194 | cumulativeSum : Vect (S (S k)) Double
195 | cumulativeSum = scanl (+) 0.0 weights
197 | coalg : (Nat, Nat) -> Maybe (Maybe (Fin (S (S k))), (Nat, Nat))
198 | coalg (i, j) with (natToFin i (S k), natToFin j (S (S k)))
199 | _ | (Just i_fin, Just j_fin) with (index i_fin positions < index j_fin cumulativeSum)
200 | _ | True = Just (Just j_fin, (i + 1, j))
201 | _ | False = Just (Nothing, (i, j + 1))
202 | _ | otherwise = Nothing
204 | fin_pred : Fin (S (S k)) -> Fin (S k)
205 | fin_pred (FS k) = k
208 | particle_idxs : List (Fin (S k))
209 | particle_idxs = map (fin_pred) (catMaybes $
unfoldr coalg (0,0))
-- Stratified resampling of a population: `resampleGeneric` using
-- `stratified` to choose ancestor indices.
-- NOTE(review): the body of the type signature (original lines 217-219) is
-- missing from this view.
216 | resampleStratified :
220 | resampleStratified = resampleGeneric stratified
||| Multinomial resampling: draw `n` ancestor indices independently, each
||| from `categorical ws`, in order.
multinomial : MonadSample m => {n : Nat} -> Vect n Double -> m (List (Fin n))
multinomial ws = traverse (\_ => categorical ws) (replicate n ())
-- Multinomial resampling of a population: `resampleGeneric` using
-- `multinomial` to choose ancestor indices.
-- NOTE(review): the body of the type signature (original lines 233-235) is
-- missing from this view.
232 | resampleMultinomial :
236 | resampleMultinomial = resampleGeneric multinomial
-- Normalise a population, moving into a `Weighted m` base monad.
-- The population is run in the base monad and lifted into `Weighted m`, and
-- the log-weights are summed into `z`. The particles are returned with
-- normalised weights:
--   * presumably `w / z` when `z` is positive (that branch's line is missing);
--   * otherwise a uniform `1 / length` weight for every particle.
-- NOTE(review): original lines 247, 249-250, 252 and 255-258 are missing from
-- this view. Presumably one of them records `z` in the `Weighted` layer
-- (e.g. via `score`) — confirm.
244 | Population (Weighted m) a
245 | extractEvidence pop = fromWeightedList $
do
246 | particles <- lift $
runPopulation pop
248 | let (log_ws, xs) = unzip particles
251 | = Numeric.Log.sum log_ws
253 | let normalized_log_ws : List (Log Double)
254 | = map (if fromLogDomain z > 0
256 | else const (toLogDomain (1.0 / cast (length log_ws)))) log_ws
259 | pure (zip normalized_log_ws xs)
-- Normalise the population with `extractEvidence`, then use `hoist` to apply
-- `applyWeight` to the resulting `Weighted` base monad.
-- NOTE(review): the type signature (original lines 260-267) is missing from
-- this view.
268 | pushEvidence = hoist applyWeight . extractEvidence
-- Body fragment; the signature and opening lines are missing from this view.
-- It is presumably `proper`, which `collapse` uses — confirm.
-- Working in `Weighted m`, it normalises the population with
-- `extractEvidence` and runs it. It then draws one particle index from
-- `logCategorical` over the log-weights and returns that particle's value.
278 | particles <- runPopulation $
extractEvidence pop
279 | let (log_ws_vec, xs_vec) = unzip (fromList particles)
280 | idx <- the (Weighted m (Fin (length particles))) (logCategorical log_ws_vec)
281 | pure (index idx xs_vec)
||| Normalise the population with `extractEvidence`, run it, and return the
||| weight accumulated in the resulting `Weighted` layer (via
||| `extractWeight`).
evidence : (Monad m) => Population m a -> m (Log Double)
evidence pop = extractWeight (runPopulation (extractEvidence pop))
-- Composes `proper` with `applyWeight`: selects a result from the population
-- and then applies the accumulated `Weighted` weight.
-- NOTE(review): the type signature (original lines 287-294) is missing from
-- this view.
295 | collapse = applyWeight . proper