0 | module NN.Architectures.Softargmax
4 | import Control.Monad.Distribution
12 | logSumExp : {i : Axis} -> Exp a => Ord a => Neg a =>
13 | Foldable (Tensor [i]) =>
14 | (allAlg : AllAlgebra [i] a) =>
15 | Tensor [i] a -> Maybe a
18 | pure $
c + log (reduce (t <&> (\x => exp $
x - c)))
24 | logSoftargmax : {i : Axis} -> Exp a => Ord a => Neg a =>
25 | Foldable (Tensor [i]) =>
26 | (allAlg : AllAlgebra [i] a) =>
27 | Tensor [i] a -> Tensor [i] a
28 | logSoftargmax t = case logSumExp t of
29 | Just lse => t <&> (\x => x - lse)
||| Softargmax over a single-axis tensor, with an optional temperature.
|||
||| Every entry of `t` is first divided by `temperature` (default `1`).
||| `logSoftargmax` is applied to the scaled tensor, and the result is
||| exponentiated element-wise. The function works in log space rather than
||| exponentiating the raw inputs. NOTE(review): presumably this is for
||| numerical stability; it depends on `logSoftargmax`, which is not fully
||| visible here.
|||
||| NOTE(review): a `temperature` of 0 divides by zero. The outcome depends on
||| `/` for the element type `a`; callers should pass a non-zero temperature.
softargmaxImpl : {i : Axis} -> Fractional a => Exp a => Ord a => Neg a =>
  IsFoldable i .cont =>
  (allAlg : AllAlgebra [i] a) =>
  {default 1 temperature : a} ->
  Tensor [i] a -> Tensor [i] a
softargmaxImpl {temperature} t =
  let scaled = map (/ temperature) t
  in map exp (logSoftargmax scaled)
48 | softargmax : {i : Axis} ->
49 | {a : Type} -> Fractional a => Exp a => Ord a => Neg a =>
50 | IsFoldable i.cont =>
51 | (allAlg : AllAlgebra [i] a) =>
52 | Tensor [i] a -\-> Tensor [i] a
55 | (\(
t ** temperature)
=> softargmaxImpl {temperature} t)
62 | {i : Nat} -> Show (Dist i) where
63 | show (MkDist xs) = assert_total $
64 | show @{(?todoTensorShow
)} (softargmaxImpl {i="softmaxTemp" ~~> i} (># xs))
||| Sample input: three large, closely spaced values along axis "ieva".
|||
||| `exp 1000` exceeds the largest finite `Double` (about `exp 709.78`), so
||| exponentiating these entries directly would overflow.
||| NOTE(review): presumably this is meant to exercise the shifted,
||| log-space softargmax path; confirm against its call sites.
inpp : Tensor ["ieva" ~~> 3] Double
inpp = ># [1000.0, 999.0, 998.0]