0 | module Control.Monad.Sample.Instances
 1 |
 2 | import Control.Monad.Identity
 3 | import System.Random
 4 |
 5 | import Data.Tensor
 6 | import Control.Monad.Distribution
 7 | import Control.Monad.Sample.Definition
 8 | import NN.Architectures.Softargmax
 9 |
10 | ||| Trivial sampler, always picks the first element
11 | public export
12 | [pickFirst] MonadSample Identity where
13 |   sample {i = (S k)} (MkDist xs) = Id FZ
14 |
15 | ||| Max sampler, always picks the element with the highest logit
16 | public export
17 | [pickMax] MonadSample Identity where
18 |   sample {i = (S k)} (MkDist xs) = Id (argmax xs)
19 |
20 | ||| Min sampler, always picks the element with the lowest logit
21 | public export
22 | [pickMin] MonadSample Identity where
23 |   sample {i = (S k)} (MkDist xs) = Id (argmin xs)
24 |
25 |
26 | ||| Computes the cumulative distribution, samples randomly, finds the right bin
27 | public export
28 | MonadSample IO where
29 |   sample {i = S j} (MkDist xs) = do
30 |     let dist : Tensor ["dist" ~~> S j] Double
31 |         dist = softargmaxImpl {i="dist" ~~> S j} (># xs)
32 |         cumSum : Tensor ["dist" ~~> S j] Double
33 |         cumSum = cumulativeSum dist
34 |     r <- randomRIO (0.0, 1.0)
35 |     case findBin (#> cumSum) r of
36 |       Nothing => pure FZ -- should never happen!
37 |       Just i => pure i
38 |
39 |
40 |
41 | testIO : IO ()
42 | testIO = do
43 |   let logits = MkDist [-(1.099), 1.099] -- this produes the dist [0.1, 0.9]
44 |   is <- sequence (replicate 1000 (sample logits))
45 |   -- printLn is
46 |   printLn (count (== 0) is) -- should be ~100
47 |   printLn (count (== 1) is) -- should be ~900
48 |
49 | public export
50 | testDirac : IO ()
51 | testDirac = do
52 |   let index = 4
53 |   let logits = diracDelta {i=10} index
54 |   inds <- sequence (replicate 1000 (sample logits))
55 |   printLn (take 10 inds)
56 |   printLn (count (== index) inds)