0 | module Control.Monad.Sample.Instances
 1 |
 2 | import Control.Monad.Identity
 3 | import System.Random
 4 |
 5 | import Data.Tensor
 6 | import Control.Monad.Distribution
 7 | import Control.Monad.Sample.Definition
 8 | import NN.Architectures.Softargmax
 9 |
10 | ||| Trivial sampler, always picks the first element
11 | public export
12 | [pickFirst] MonadSample Identity where
13 |   sample {i = (S k)} (MkDist xs) = Id FZ
14 |
15 | ||| Max sampler, always picks the element with the highest logit
16 | public export
17 | [pickMax] MonadSample Identity where
18 |   sample {i = (S k)} (MkDist xs) = Id (argmax xs)
19 |
20 | ||| Min sampler, always picks the element with the lowest logit
21 | public export
22 | [pickMin] MonadSample Identity where
23 |   sample {i = (S k)} (MkDist xs) = Id (argmin xs)
24 |
25 |
26 | ||| Computes the cumulative distribution, samples randomly, finds the right bin
27 | public export
28 | MonadSample IO where
29 |   sample @{ItIsSucc} (MkDist xs) = do
30 |     let dist : Tensor ["dist" ~~> i] Double := (softargmaxImpl {i="dist" ~~> i}) (># xs)
31 |         cumSum : Tensor ["dist" ~~> i] Double := cumulativeSum dist
32 |     r <- randomRIO (0.0, 1.0)
33 |     case findBin (#> cumSum) r of
34 |       Nothing => pure FZ -- should never happen!
35 |       Just i => pure i
36 |
37 |
38 |
39 | testIO : IO ()
40 | testIO = do
41 |   let logits = MkDist [-(1.099), 1.099] -- this produes the dist [0.1, 0.9]
42 |   is <- sequence (replicate 1000 (sample logits))
43 |   -- printLn is
44 |   printLn (count (== 0) is) -- should be ~100
45 |   printLn (count (== 1) is) -- should be ~900
46 |
47 | public export
48 | testDirac : IO ()
49 | testDirac = do
50 |   let index = 4
51 |   let logits = diracDelta {i=10} index
52 |   inds <- sequence (replicate 1000 (sample logits))
53 |   printLn (take 10 inds)
54 |   printLn (count (== index) inds)