0 | module NN.Architectures.Softargmax
 1 |
 2 | import Data.Tensor
 3 | import Data.Para
 4 | import Control.Monad.Distribution
 5 |
 6 |
 7 |
 8 | ||| Numerically stable log-sum-exp operation
 9 | ||| LSE(x) = max(x) + log(Σᵢ exp(xᵢ - max(x)))
10 | ||| See https://gregorygundersen.com/blog/2020/02/09/log-sum-exp/
11 | public export
12 | logSumExp : {i : Axis} -> Exp a => Ord a => Neg a =>
13 |   Foldable (Tensor [i]) =>
14 |   (allAlg : AllAlgebra [i] a) =>
15 |   Tensor [i] a -> Maybe a
16 | logSumExp t = do
17 |   c <- max t
18 |   pure $ c + log (reduce (t <&> (\x => exp $ x - c)))
19 |
20 | ||| Log(softargmax(x)), but computationally efficient and numerically stable
21 | ||| Used for computing cross-entropy loss
22 | ||| Returns empty tensor for empty input
23 | public export
24 | logSoftargmax : {i : Axis} -> Exp a => Ord a => Neg a =>
25 |   Foldable (Tensor [i]) =>
26 |   (allAlg : AllAlgebra [i] a) =>
27 |   Tensor [i] a -> Tensor [i] a
28 | logSoftargmax t = case logSumExp t of
29 |   Just lse => t <&> (\x => x - lse) -- Non-empty: subtract LSE from each element
30 |   Nothing  => t                     -- t is empty
31 |
32 | ||| Commonly known as 'softmax'
33 | ||| When `temperature=0` it reduces to `argmax`
34 | public export
35 | softargmaxImpl : {i : Axis} -> Fractional a => Exp a => Ord a => Neg a =>
36 |   Foldable (Tensor [i]) =>
37 |   (allAlg : AllAlgebra [i] a) =>
38 |   {default 1 temperature : a} ->
39 |   Tensor [i] a -> Tensor [i] a
40 | softargmaxImpl {temperature} t
41 |   = exp <$> logSoftargmax (t <&> (/ temperature))
42 |
43 | ||| Softargmax as a parametric map, with temperature as a parameter
44 | ||| TODO the output type should be a distribution tensor, since distributions
45 | ||| are applicative? https://glaive-research.org/2025/02/11/Generalized-Transformers-from-Applicative-Functors.html
46 | public export
47 | softargmax : {i : Axis} ->
48 |   {a : Type} -> Fractional a => Exp a => Ord a => Neg a =>
49 |   Foldable (Tensor [i]) =>
50 |   (allAlg : AllAlgebra [i] a) =>
51 |   Tensor [i] a -\-> Tensor [i] a
52 | softargmax = MkPara 
53 |   (\_ => a) -- temperature is the parameter
54 |   (\(t ** temperature=> softargmaxImpl {temperature} t)
55 |
56 |
57 | -- `Control.Monad.Distribution` and softargmax should probably be merged?
58 | public export
59 | {i : Nat} -> Show (Dist i) where
60 |   show (MkDist xs) = show (softargmaxImpl {i="softmaxTemp" ~~> i} (># xs))
61 |
62 | inpp : Tensor ["ieva" ~~> 3] Double
63 | inpp = ># [1000, 999, 998]
64 |
65 | -- TODO namedSoftargmax
66 | -- namedSoftmax : {axis : Type -> Type}
67 | --   -> {shape : Vect n ApplF} -> {a : Type}
68 | --   -> Functor axis
69 | --   => Elem axis shape
70 | --   -> TensorA shape a
71 | --   -> TensorA shape a
72 | -- namedSoftmax {shape = []} axis t impossible -- can't be in vector if vector empty
73 | -- namedSoftmax {shape = (axis :: ss)} Here (GTS x) = GTS (?sm <$> x)
74 | -- namedSoftmax {shape = (s :: ss)} (There later) (GTS x) = GTS ?namedSoftmax_rhs_4
75 |