0 | module Control.Monad.Sample.Instances
2 | import Control.Monad.Identity
6 | import Control.Monad.Distribution
7 | import Control.Monad.Sample.Definition
8 | import NN.Architectures.Softargmax
||| Deterministic sampler in the `Identity` monad: ignores the contents
||| of the distribution and always yields the first outcome, `FZ`.
[pickFirst] MonadSample Identity where
  sample {i = S _} (MkDist _) = pure FZ
||| Deterministic sampler in the `Identity` monad: yields whatever index
||| `argmax` reports for the values carried by the distribution.
||| NOTE(review): the IO sampler feeds these same values through
||| softargmax, so presumably they are logits and this picks the mode —
||| confirm against the `MonadSample` definition.
[pickMax] MonadSample Identity where
  sample {i = S _} (MkDist xs) = pure $ argmax xs
||| Deterministic sampler in the `Identity` monad: yields whatever index
||| `argmin` reports for the values carried by the distribution.
||| NOTE(review): presumably this is the least-likely outcome when the
||| values are logits — confirm against the `MonadSample` definition.
[pickMin] MonadSample Identity where
  sample {i = S _} (MkDist xs) = pure $ argmin xs
28 | MonadSample IO where
29 | sample {i = S j} (MkDist xs) = do
30 | let dist : Tensor ["dist" ~~> S j] Double
31 | dist = softargmaxImpl {i="dist" ~~> S j} (># xs)
32 | cumSum : Tensor ["dist" ~~> S j] Double
33 | cumSum = cumulativeSum dist
34 | r <- randomRIO (0.0, 1.0)
35 | case findBin (#> cumSum) r of
43 | let logits = MkDist [-(1.099), 1.099]
44 | is <- sequence (replicate 1000 (sample logits))
46 | printLn (count (== 0) is)
47 | printLn (count (== 1) is)
53 | let logits = diracDelta {i=10} index
54 | inds <- sequence (replicate 1000 (sample logits))
55 | printLn (take 10 inds)
56 | printLn (count (== index) inds)