0 | module Control.Monad.Bayes.Population
  1 |
  2 | import public Control.Monad.Bayes.Weighted
  3 | import Control.Monad.Bayes.Interface
  4 | import Control.Monad.Bayes.Sampler
  5 | import Control.Monad.Trans
  6 | import Numeric.Log
  7 | import Data.List
  8 | import Debug.Trace
  9 |
 10 | ||| List transformer
 11 | public export
 12 | record ListT (m : Type -> Type) (a : Type) where
 13 |   constructor MkListT
 14 |   runListT : m (List a)
 15 |
 16 | mapListT : (m (List a) -> n (List b)) -> ListT m a -> ListT n b
 17 | mapListT f m = MkListT $ f (runListT m)
 18 |
 19 | export
 20 | Functor m => Functor (ListT m) where
 21 |   map f  = mapListT $ map $ map f
 22 |
 23 | export
 24 | Applicative m => Applicative (ListT m) where
 25 |   pure x  = MkListT (pure [x])
 26 |   f <*> v = MkListT $ (<*>) <$> runListT f <*> runListT v
 27 |
 28 | export
 29 | Monad m => Monad (ListT m) where
 30 |   m >>= k  = MkListT $ do
 31 |     a <- runListT m
 32 |     b <- (sequence . map (runListT . k))  a
 33 |     pure (join b)
 34 |
 35 | export
 36 | MonadTrans ListT where
 37 |   lift = MkListT . map List.singleton
 38 |
 39 | export
 40 | MonadSample m => MonadSample (ListT m) where
 41 |   random = lift random
 42 |   bernoulli = lift . bernoulli
 43 |   categorical = lift . categorical
 44 |
 45 | export
 46 | MonadCond m => MonadCond (ListT m) where
 47 |   score = lift . score
 48 |
 49 | export
 50 | MonadInfer m => MonadInfer (ListT m) where
 51 |
 52 | ||| A collection of weighted samples, or particles.
 53 | public export
 54 | record Population (m : Type -> Type) (a : Type) where
 55 |   constructor MkPopulation
 56 |   runPopulation' : Weighted (ListT m) a
 57 |
 58 | export
 59 | Functor m => Functor (Population m) where
 60 |   map f (MkPopulation mx) = MkPopulation (map f mx)
 61 |
 62 | export
 63 | Monad m => Applicative (Population m) where
 64 |   pure = MkPopulation . pure
 65 |   (MkPopulation mf) <*> (MkPopulation ma) = MkPopulation (mf <*> ma)
 66 |
 67 | export
 68 | Monad m => Monad (Population m) where
 69 |   (MkPopulation mx) >>= k = MkPopulation (mx >>= (runPopulation' . k))
 70 |
 71 | export
 72 | MonadTrans Population where
 73 |   lift = MkPopulation . lift . lift
 74 |
 75 | export
 76 | MonadSample m => MonadSample (Population m) where
 77 |   random = lift random
 78 |
 79 | export
 80 | Monad m => MonadCond (Population m) where
 81 |   score w = MkPopulation $ score w -- Call score from Weighted
 82 |
 83 | export
 84 | MonadSample m => MonadInfer (Population m) where
 85 |
 86 | ||| Explicit representation of the weighted sample with weights in the log domain.
 87 | export
 88 | runPopulation : Population m a -> m (List (Log Double, a))
 89 | runPopulation (MkPopulation m) = (runListT . runWeighted) m
 90 |
 91 | ||| Explicit representation of the weighted sample.
 92 | export
 93 | explicitPopulation : Functor m => Population m a -> m (List (Double, a))
 94 | explicitPopulation = map (map (\(log_w, a) => (fromLogDomain log_w, a))) . runPopulation
 95 |
 96 | ||| Initialize 'Population' with a concrete weighted sample.
 97 | export
 98 | fromWeightedList : Monad m => m (List (Log Double, a)) -> Population m a
 99 | fromWeightedList = MkPopulation . withWeight . MkListT
100 |
101 | ||| Applies a transformation to the inner monad.
102 | export
103 | hoist :
104 |   Monad m2 =>
105 |   (forall x. m1 x -> m2 x) ->
106 |   Population m1 a ->
107 |   Population m2 a
108 | hoist f = fromWeightedList . f . runPopulation
109 |
110 | ||| Increase the sample size by a given factor.
111 | ||| The weights are adjusted such that their sum is preserved. It is therefore
112 | ||| safe to use 'spawn' in arbitrary places in the program without introducing bias.
113 | export
114 | spawn : (isMonad : Monad m) => Nat -> Population m ()
115 | spawn n = fromWeightedList $ pure $ replicate n (toLogDomain (1.0 / cast n), ())
116 |
117 | export
118 | resampleGeneric :
119 |   MonadSample m =>
120 |   -- | resampler
121 |   ({k : Nat} -> Vect k Double -> m (List (Fin k))) ->
122 |   Population m a ->
123 |   Population m a
124 | resampleGeneric resampler pop = fromWeightedList $ do
125 |   particles <- runPopulation pop
126 |   let (log_ws, xs)  : (Vect (length particles) (Log Double), Vect (length particles) a)
127 |                     = unzip (fromList particles)
128 |       z : Log Double = Numeric.Log.sum log_ws
129 |   if fromLogDomain z > 0
130 |     then do
131 |             let weights    : Vect (length particles) Double
132 |                            = map (fromLogDomain . (/ z)) log_ws
133 |             ancestors <- resampler weights
134 |             let offsprings : List a
135 |                            = map (\idx => index idx xs) ancestors
136 |             pure $ map (z / (toLogDomain $ length particles), ) offsprings
137 |     else
138 |             pure particles
139 |
140 |
141 | ||| Systematic sampler.
142 | export
143 | systematic : {n : Nat} -> Double -> Vect n Double -> List (Fin n)
144 | systematic {n = Z}   u Nil = Nil
145 | systematic {n = S k} u ws =
146 |   let
147 |           prob : Maybe (Fin (S k)) -> Double
148 |           prob (Just idx) = index idx ws
149 |           prob  Nothing   = index last ws
150 |
151 |           inc : Double
152 |           inc = 1 / cast (S k)
153 |
154 |           f : Nat -> Double -> Nat -> Double -> List Nat -> List Nat
155 |           f i v j q acc =
156 |             if i == S k then acc else
157 |             if v < q
158 |               then f (S i) (v + inc) j    q                             ((minus j 1) :: acc)
159 |               else f  i    v        (S j) (q + prob (natToFin j (S k))) acc
160 |
161 |           particle_idxs : List (Fin (S k))
162 |           particle_idxs = map (\nat => fromMaybe FZ (natToFin nat (S k)))
163 |                               (f Z (u / cast (S k)) Z 0.0 [])
164 |
165 |   in      particle_idxs
166 |
167 | ||| Resample the population using the underlying monad and a systematic resampling scheme.
168 | ||| The total weight is preserved.
169 | export
170 | resampleSystematic :
171 |   (MonadSample m) =>
172 |   Population m a ->
173 |   Population m a
174 | resampleSystematic = resampleGeneric (\ws => (`systematic` ws) <$> random)
175 |
176 | ||| The conditional variance of stratified sampling is always smaller than that of multinomial
177 | ||| sampling and it is also unbiased - see [Comparison of Resampling Schemes for Particle
178 | ||| Filtering](https://arxiv.org/abs/cs/0507025).
179 | export
180 | stratified : MonadSample m => {n : Nat} -> Vect n Double -> m (List (Fin n))
181 | stratified {n = Z}   Nil     = pure Nil
182 | stratified {n = S k} weights = do
183 |   dithers <- sequence (Vect.replicate (S k) (uniform 0.0 1.0))
184 |
185 |   let vect_range : (n : Nat) -> Vect (S n) Nat
186 |       vect_range n = reverse (range_bwd n)
187 |         where range_bwd : (n : Nat) -> Vect (S n) Nat
188 |               range_bwd (S k) = (S k) :: range_bwd k
189 |               range_bwd Z     = Z :: Nil
190 |
191 |       positions : Vect (S k) Double
192 |       positions = map (/cast (S k)) $ zipWith (+) dithers (map cast (vect_range k))
193 |
194 |       cumulativeSum : Vect (S (S k)) Double
195 |       cumulativeSum = scanl (+) 0.0 weights
196 |
197 |       coalg : (Nat, Nat) -> Maybe (Maybe (Fin (S (S k))), (Nat, Nat))
198 |       coalg (i, j) with (natToFin i (S k), natToFin j (S (S k)))
199 |         _ | (Just i_fin, Just j_fin) with (index i_fin positions < index j_fin cumulativeSum)
200 |           _ | True  = Just (Just j_fin, (i + 1, j))
201 |           _ | False = Just (Nothing, (i, j + 1))
202 |         _ | otherwise = Nothing
203 |
204 |       fin_pred : Fin (S (S k)) -> Fin (S k)
205 |       fin_pred (FS k) = k
206 |       fin_pred FZ     = FZ
207 |
208 |       particle_idxs : List (Fin (S k))
209 |       particle_idxs = map (fin_pred) (catMaybes $ unfoldr coalg (0,0))
210 |
211 |   pure particle_idxs
212 |
213 | ||| Resample the population using the underlying monad and a stratified resampling scheme.
214 | ||| The total weight is preserved.
215 | export
216 | resampleStratified :
217 |   (MonadSample m) =>
218 |   Population m a ->
219 |   Population m a
220 | resampleStratified = resampleGeneric stratified
221 |
222 | ||| Multinomial sampler.  Sample from \(0, \ldots, n - 1\) \(n\)
223 | ||| times drawn at random according to the weights where \(n\) is the
224 | ||| length of vector of weights.
225 | export
226 | multinomial : MonadSample m => {n : Nat} -> Vect n Double -> m (List (Fin n))
227 | multinomial ws = sequence $ replicate n (categorical ws)
228 |
229 | ||| Resample the population using the underlying monad and a multinomial resampling scheme.
230 | ||| The total weight is preserved.
231 | export
232 | resampleMultinomial :
233 |   (MonadSample m) =>
234 |   Population m a ->
235 |   Population m a
236 | resampleMultinomial = resampleGeneric multinomial
237 |
238 | ||| Separate the sum of weights into the 'Weighted' transformer.
239 | ||| Weights are normalized after this operation.
240 | export
241 | extractEvidence :
242 |   Monad m =>
243 |   Population m a ->
244 |   Population (Weighted m) a
245 | extractEvidence pop = fromWeightedList $ do
246 |   particles <- lift $ runPopulation pop
247 |
248 |   let (log_ws, xs) = unzip particles
249 |
250 |   let z      : Log Double
251 |              = Numeric.Log.sum log_ws
252 |
253 |   let normalized_log_ws : List (Log Double)
254 |              = map (if fromLogDomain z > 0
255 |                       then (/ z)
256 |                       else const (toLogDomain (1.0 / cast (length log_ws)))) log_ws
257 |   score z
258 |
259 |   pure (zip normalized_log_ws xs)
260 |
261 | ||| Push the evidence estimator as a score to the transformed monad.
262 | ||| Weights are normalized after this operation.
263 | export
264 | pushEvidence :
265 |   MonadCond m =>
266 |   Population m a ->
267 |   Population m a
268 | pushEvidence = hoist applyWeight . extractEvidence
269 |
270 | ||| A properly weighted single sample, that is one picked at random according
271 | ||| to the weights, with the sum of all weights.
272 | export
273 | proper :
274 |   (MonadSample m) =>
275 |   Population m a ->
276 |   Weighted m a
277 | proper pop = do
278 |   particles <- runPopulation $ extractEvidence pop
279 |   let (log_ws_vec, xs_vec) = unzip (fromList particles)
280 |   idx <- the (Weighted m (Fin (length particles))) (logCategorical log_ws_vec)
281 |   pure (index idx xs_vec)
282 |
283 | ||| Model evidence estimator, also known as pseudo-marginal likelihood.
284 | export
285 | evidence : (Monad m) => Population m a -> m (Log Double)
286 | evidence = extractWeight . runPopulation . extractEvidence
287 |
288 | ||| Picks one point from the population and uses model evidence as a 'score' in the transformed monad.
289 | ||| This way a single sample can be selected from a population without introducing bias.
290 | export
291 | collapse :
292 |   (MonadInfer m) =>
293 |   Population m a ->
294 |   m a
295 | collapse = applyWeight . proper