0 | module NN.Architectures.Softargmax
4 | import Control.Monad.Distribution
-- Log-sum-exp over a single-axis tensor, returned inside `Maybe`.
-- The visible tail of the definition yields
--   c + log (reduce (exp (x - c) for every element x of t))
-- wrapped with `pure`.
-- NOTE(review): the clause head and the binding of `c` are not visible in
-- this excerpt; presumably `c` is the largest element of `t` (hence the
-- `Ord a` constraint and the `Maybe` result) -- confirm.
-- NOTE(review): `reduce` is assumed to sum the elements of the tensor --
-- verify against its definition.
12 | logSumExp : {i : Axis} -> Exp a => Ord a => Neg a =>
13 | Foldable (Tensor [i]) =>
14 | (allAlg : AllAlgebra [i] a) =>
15 | Tensor [i] a -> Maybe a
18 | pure $
-- Every element is shifted by `c` before `exp`, and `c` is added back
-- outside the `log`.
c + log (reduce (t <&> (\x => exp $
x - c)))
-- Element-wise log-softargmax of a single-axis tensor.
-- Calls `logSumExp` on the input and, when it yields `Just lse`, subtracts
-- `lse` from every element of `t`.
-- NOTE(review): only the `Just` alternative of the `case` is visible in this
-- excerpt; confirm how the remaining alternative is handled.
24 | logSoftargmax : {i : Axis} -> Exp a => Ord a => Neg a =>
25 | Foldable (Tensor [i]) =>
26 | (allAlg : AllAlgebra [i] a) =>
27 | Tensor [i] a -> Tensor [i] a
28 | logSoftargmax t = case logSumExp t of
29 | Just lse => t <&> (\x => x - lse)
-- Softargmax of a single-axis tensor with an optional temperature.
-- Each element of `t` is divided by `temperature`, the result is passed
-- through `logSoftargmax`, and `exp` is then applied to every element.
-- `temperature` is an implicit argument that defaults to 1.
35 | softargmaxImpl : {i : Axis} -> Fractional a => Exp a => Ord a => Neg a =>
36 | Foldable (Tensor [i]) =>
37 | (allAlg : AllAlgebra [i] a) =>
38 | {default 1 temperature : a} ->
39 | Tensor [i] a -> Tensor [i] a
40 | softargmaxImpl {temperature} t
41 | = exp <$> logSoftargmax (t <&> (/ temperature))
-- Softargmax exposed through the `-\->` arrow between tensors.
-- The visible lambda takes a dependent pair `(t ** temperature)` and
-- forwards both components to `softargmaxImpl`.
-- NOTE(review): the clause head that wraps this lambda is not visible in
-- this excerpt; presumably `-\->` denotes a map carrying an extra parameter
-- (here the temperature) -- confirm against its definition.
47 | softargmax : {i : Axis} ->
48 | {a : Type} -> Fractional a => Exp a => Ord a => Neg a =>
49 | Foldable (Tensor [i]) =>
50 | (allAlg : AllAlgebra [i] a) =>
51 | Tensor [i] a -\-> Tensor [i] a
54 | (\(
t ** temperature)
=> softargmaxImpl {temperature} t)
-- `Show` implementation for `Dist i`.
-- The payload `xs` of `MkDist` is lifted with `>#`, passed through
-- `softargmaxImpl` (default temperature) over the axis
-- `"softmaxTemp" ~~> i`, and the resulting tensor is shown.
-- NOTE(review): the printed values are the softargmax of `xs`, not `xs`
-- itself; presumably `Dist` stores unnormalised scores -- confirm this is
-- intended.
59 | {i : Nat} -> Show (Dist i) where
60 | show (MkDist xs) = show (softargmaxImpl {i="softmaxTemp" ~~> i} (># xs))
-- Sample `Double` tensor over the axis `"ieva" ~~> 3` holding the values
-- 1000, 999 and 998.
-- NOTE(review): presumably a fixture for checking that the softargmax
-- functions cope with large inputs without overflowing -- confirm.
62 | inpp : Tensor ["ieva" ~~> 3] Double
63 | inpp = ># [1000, 999, 998]