0 | module NN.Architectures.LossFunctions
  1 |
  2 | import Data.List
  3 | import Data.Zippable
  4 |
  5 | import Data.Tensor
  6 | import Data.Tensor.Utils
  7 | import Data.Container.Additive
  8 | import NN.Architectures.Softargmax
  9 | import Control.Monad.Distribution
 10 |
 11 | import Data.Container.Additive.Quantifiers
 12 |
 13 | import Data.Para
 14 |
 15 | %hide Data.Container.Base.Morphism.Definition.DependentLenses.(=%>)
 16 |
 17 | ||| Loss function alias
 18 | public export
 19 | Loss : AddCont -> {default (Const Double) l : AddCont} -> Type
 20 | Loss c = (c >< c) =%> l
 21 |
 22 | namespace Combinators
 23 |   ||| Combinator for pairing up loss functions
 24 |   public export
 25 |   pairLossFunctions : {y, z : AddCont} ->
 26 |     {l : Type} -> Num l =>
 27 |     Loss y {l=Const l} -> Loss z {l=Const l} -> Loss (y >< z) {l=Const l}
 28 |   pairLossFunctions loss1 loss2 = swapMiddle %>> (loss1 >< loss2) %>> Sum
 29 |
 30 |   public export
 31 |   lossSame : {a, b : AddCont} ->
 32 |     ((a >+< b) >< (a >+< b)) =%> ((a >< a) >+< (b >< b))
 33 |   lossSame = !%+ \case
 34 |     (Left x1, Left x2) => (Left (x1, x2) ** id)
 35 |     (Right y1, Right y2) => (Right (y1, y2) ** id)
 36 |     (_, _) => believe_me "Should not happen"
 37 |
 38 |   public export
 39 |   pairLossCoproduct : {y, z : AddCont} ->
 40 |     {l : Type} -> Num l =>
 41 |     Loss y {l=Const l} -> Loss z {l=Const l} -> Loss (y >+< z) {l=Const l}
 42 |   pairLossCoproduct l1 l2 = lossSame %>> (l1 >+< l2) %>> elim
 43 |
 44 |   public export
 45 |   composeLossFunctions : {y, z, l : AddCont} ->
 46 |     Loss y {l} -> Loss z {l} -> Loss (y >@ z) {l}
 47 |   composeLossFunctions loss1 loss2 = let tt = loss1 >@ loss2
 48 |                                      in ?composeLossFunctions_rhs
 49 |
 50 |   public export
 51 |   sequenceLossFunctions : {y, z : AddCont} ->
 52 |     Loss y {l} -> Loss z {l} -> Loss (y >@ z) {l}
 53 |   sequenceLossFunctions loss1 loss2 = !%+ \(x1, x2) => ?asdf
 54 |
 55 |   public export
 56 |   zipListsBwd : {y : AddCont} ->
 57 |     (l1, l2 : List y.Shp) ->
 58 |     All (y >< y).Pos (zip l1 l2) -> (All y.Pos l1, All y.Pos l2)
 59 |   zipListsBwd [] l2 [] = ([], allIsComonoidNeutral l2)
 60 |   zipListsBwd (s1 :: ss1) [] [] = (allIsComonoidNeutral (s1 :: ss1), [])
 61 |   zipListsBwd (s1 :: ss1) (s2 :: ss2) ((p1, p2) :: rest) =
 62 |     let (ls, rs) = zipListsBwd ss1 ss2 rest
 63 |     in (p1 :: ls, p2 :: rs)
 64 |
 65 |   public export
 66 |   zipLists : {y : AddCont} -> (List y) >< (List y) =%> List (y >< y)
 67 |   zipLists = !%+ \(l1, l2) => (zip l1 l2 ** zipListsBwd l1 l2)
 68 |
 69 |   -- TODO here it can be that we pair different types together!
 70 |   -- so if there is a mismath we have the ability to short circuit?
 71 |   public export
 72 |   UniversalMapOutOfCoproduct : Num d =>
 73 |     {n : Nat} -> IsSucc n =>
 74 |     {cs : Vect n AddCont} ->
 75 |     ((i : Fin n) -> Loss (index i cs) {l=Const d}) ->
 76 |     Loss (Any cs) {l=Const d}
 77 |   UniversalMapOutOfCoproduct {n = 1} {cs = [c]} s = !%+ \(Here l, Here r) =>
 78 |     ((s 0).fwd (l, r) ** \d' =>
 79 |     let (l', r') = (s 0).bwd (l, r) d'
 80 |     in (Here l', Here r'))
 81 |   UniversalMapOutOfCoproduct {n = (S (S k))} {cs = (c :: cs)} s
 82 |     = !%+ \case 
 83 |       (Here l, Here r) => ((s 0).fwd (l, r) ** \d' =>
 84 |         let (l', r') = (s 0).bwd (l, r) d'
 85 |         in (Here l', Here r'))
 86 |       (There l, (There r)) =>
 87 |          let restLens = UniversalMapOutOfCoproduct {cs=cs} (\i => s (FS i))
 88 |              d = restLens.fwd (l, r)
 89 |          in (d ** \d' => let (l', r') = restLens.bwd (l, r) d'
 90 |                          in (There l', There r'))
 91 |       -- if branches mismatch we shouldn't be asked this question
 92 |       -- using zeros for now
 93 |       (Here l, (There r)) => (0  ** \_ =>
 94 |         (Here (c.Zero l), There ((Any cs).Zero r)))
 95 |       (There l, (Here r)) => (0 ** \_ =>
 96 |         (There ((Any cs).Zero l), (Here (c.Zero r))))
 97 |
 98 | ||| Squared error
 99 | public export
100 | SquaredError : {a : Type} -> Num a => Neg a => Loss (Const a) {l=Const a}
101 | SquaredError = Additive.Morphism.Instances.SquaredDifference
102 |
103 | public export
104 | Sum : {n : Axis} -> IsCubical n => Num a =>
105 |   TensorMonoid n.cont => 
106 |   (Const (Tensor [n] a)) =%> (Const (Tensor [] a))
107 | Sum @{MkIsCubical _ n} = !%+ \t => (># reduce t ** \a' => fill (#> a'))
108 |
109 | public export
110 | Div : {a : Type} -> Num a => Fractional a =>
111 |   (divBy : a) ->
112 |   (Const (Tensor [] a)) =%> (Const (Tensor [] a))
113 | Div divBy = !%+ \x => (x <&> (/ divBy) ** \x' => x' <&> (/ divBy))
114 |
115 | public export
116 | MeanSquaredError : {n : Axis} -> IsCubical n => TensorMonoid n.cont =>
117 |   {a : Type} -> Num a => Neg a => Fractional a => Cast Nat a =>
118 |   Loss (Const (Tensor [n] a)) {l=Const (Tensor [] a)}
119 | MeanSquaredError @{MkIsCubical _ n} = SquaredError %>> Sum %>> Div (cast n)
120 |
121 | public export
122 | SoftargmaxCrossEntropyLogits : {n : Nat} -> Loss (Simplex n)
123 | SoftargmaxCrossEntropyLogits = !%+ \(logits, labels) =>
124 |   let logSoftargmaxLogits = logSoftargmax (># (toVect logits))
125 |       targetProbs = softargmaxImpl {i="softargmaxTemp" ~~> n} (># (toVect labels))
126 |       out = - dot logSoftargmaxLogits targetProbs
127 |   in (extract out ** \l' =>
128 |     (#> (((l' *) <$> softargmaxImpl (># (toVect logits)) - targetProbs)),
129 |       replicate n 0)-- zeros for now
130 |