0 | module Control.Monad.Sample.Instances
2 | import Control.Monad.Identity
6 | import Control.Monad.Distribution
7 | import Control.Monad.Sample.Definition
8 | import NN.Architectures.Softargmax
12 | [pickFirst] MonadSample Identity where -- Named, deterministic sampler in Identity: always returns the first outcome.
13 | sample {i = (S k)} (MkDist xs) = Id FZ -- xs is ignored; matching i = S k (non-empty support) is what makes FZ a valid index.
17 | [pickMax] MonadSample Identity where -- Named, deterministic sampler in Identity: returns argmax of the distribution's scores.
18 | sample {i = (S k)} (MkDist xs) = Id (argmax xs) -- presumably the index of the largest entry of xs (the mode, if xs are logits) -- TODO confirm argmax tie-breaking.
22 | [pickMin] MonadSample Identity where -- Named, deterministic sampler in Identity: returns argmin of the distribution's scores.
23 | sample {i = (S k)} (MkDist xs) = Id (argmin xs) -- presumably the index of the smallest entry of xs -- TODO confirm argmin tie-breaking.
28 | MonadSample IO where
29 | sample @{ItIsSucc} (MkDist xs) = do
30 | let dist : Tensor ["dist" ~~> i] Double := (softargmaxImpl {i="dist" ~~> i}) (># xs)
31 | cumSum : Tensor ["dist" ~~> i] Double := cumulativeSum dist
32 | r <- randomRIO (0.0, 1.0)
33 | case findBin (#> cumSum) r of
41 | let logits = MkDist [-(1.099), 1.099]
42 | is <- sequence (replicate 1000 (sample logits))
44 | printLn (count (== 0) is)
45 | printLn (count (== 1) is)
51 | let logits = diracDelta {i=10} index
52 | inds <- sequence (replicate 1000 (sample logits))
53 | printLn (take 10 inds)
54 | printLn (count (== index) inds)