0 | module NN.Optimisers.Instances
  1 |
  2 | import System.Random
  3 |
  4 | import Data.Container.Additive
  5 | import Data.Num
  6 | import NN.Optimisers.Definition
  7 | import NN.Utils
  8 |
  9 | ||| Gradient descent optimiser. Has trivial state
 10 | ||| @lr is the learning rate
 11 | public export
 12 | GD : Num pType => Neg pType => Random pType =>
 13 |   (mon : ComMonoid pType) => FromDouble pType =>
 14 |   {default 0.001 lr : pType} -> Optimiser (Const pType) Unit
 15 | GD = MkOptimiser
 16 |   (!% \(p, ()) => (p ** \p' => (p - lr * p', ())))
 17 |   (randomRIO (-1, 1))
 18 |   (pure ())
 19 |
 20 | ||| Gradient ascent optimiser. Has trivial state
 21 | ||| @lr is the learning rate
 22 | public export
 23 | GA : Num pType => Neg pType => Random pType =>
 24 |   (mon : ComMonoid pType) => FromDouble pType =>
 25 |   {default 0.001 lr : pType} -> Optimiser (Const pType) Unit
 26 | GA = MkOptimiser
 27 |   (!% \(p, ()) => (p ** \p' => (p + lr * p', ())))
 28 |   (randomRIO (-1, 1))
 29 |   (pure ())
 30 |
 31 | namespace Momentum
 32 |   public export
 33 |   momentumUpdate : Num pType => Neg pType =>
 34 |     {lr : pType} ->
 35 |     (gamma : pType) ->
 36 |     (p : pType) ->
 37 |     (s : pType) ->
 38 |     (p' : pType) ->
 39 |     (pType, pType)
 40 |   momentumUpdate gamma p s p' = let s' = gamma * s + p'
 41 |                                 in (p - lr * s', s')
 42 |
 43 |   public export
 44 |   lookAhead : Num pType =>
 45 |     (gamma, p, s : pType) ->
 46 |     pType
 47 |   lookAhead gamma p s = p + gamma * s
 48 |   
 49 |   ||| Gradient Descent with momentum, optionally with Nesterov acceleration
 50 |   public export
 51 |   GDMomentum : Num pType => Neg pType => Random pType =>
 52 |    (mon : ComMonoid pType) =>
 53 |    FromDouble pType =>
 54 |    {default False nesterov : Bool} ->
 55 |    {default 0.001 lr : pType} ->
 56 |    {default 0.9 gamma : pType} ->
 57 |    Optimiser (Const pType) pType
 58 |   GDMomentum = MkOptimiser
 59 |     (!% \(p, s) => (if nesterov then lookAhead gamma p s else p
 60 |                    ** momentumUpdate {lr} gamma p s))
 61 |     (randomRIO (-1, 1))
 62 |     (pure 0)
 63 |   
 64 | namespace Adam
 65 |   ||| Adam update step
 66 |   ||| State is (m, v, beta1^t, beta2^t) where m and v are the first and second
 67 |   ||| moment estimates, and beta1^t, beta2^t are running powers for bias correction
 68 |   public export
 69 |   adamUpdate : Num pType => Neg pType => Fractional pType => Sqrt pType =>
 70 |     {lr : pType} ->
 71 |     (beta1 : pType) ->
 72 |     (beta2 : pType) ->
 73 |     (epsilon : pType) ->
 74 |     (p : pType) ->
 75 |     (m : pType) ->
 76 |     (v : pType) ->
 77 |     (b1p : pType) ->
 78 |     (b2p : pType) ->
 79 |     (g : pType) ->
 80 |     (pType, pType, pType, pType, pType)
 81 |   adamUpdate beta1 beta2 epsilon p m v b1p b2p g =
 82 |     let m' = beta1 * m + (1 - beta1) * g
 83 |         v' = beta2 * v + (1 - beta2) * g * g
 84 |         b1p' = b1p * beta1
 85 |         b2p' = b2p * beta2
 86 |         mHat = m' / (1 - b1p')
 87 |         vHat = v' / (1 - b2p')
 88 |     in (p - lr * mHat / (sqrt vHat + epsilon), m', v', b1p', b2p')
 89 |
 90 |   ||| Adam optimiser (Kingma & Ba, 2014)
 91 |   ||| Using 4 parameters for state for efficiency
 92 |   ||| @lr is the learning rate
 93 |   ||| @beta1 is the exponential decay rate for the first moment estimate
 94 |   ||| @beta2 is the exponential decay rate for the second moment estimate
 95 |   ||| @epsilon is a small constant for numerical stability
 96 |   public export
 97 |   Adam : Num pType => Neg pType => Random pType =>
 98 |    (mon : ComMonoid pType) =>
 99 |    FromDouble pType =>
100 |    Fractional pType => Sqrt pType =>
101 |    {default 0.001 lr : pType} ->
102 |    {default 0.9 beta1 : pType} ->
103 |    {default 0.999 beta2 : pType} ->
104 |    {default 1.0e-8 epsilon : pType} ->
105 |    Optimiser (Const pType) (pType, pType, pType, pType)
106 |   Adam = MkOptimiser
107 |     (!% \(p, (m, v, b1p, b2p)) =>
108 |       (p ** adamUpdate {lr} beta1 beta2 epsilon p m v b1p b2p))
109 |     (randomRIO (-1, 1))
110 |     (pure (0, 0, 1, 1))