0 | module NN.Optimisers.Instances
4 | import Data.Container.Additive
6 | import NN.Optimisers.Definition
-- Plain gradient descent: an 'Optimiser' over 'Const pType' whose optimiser
-- state is 'Unit', so nothing is carried between steps (the lambda matches
-- the state as '()' and returns '()').
--   lr : step size; implicit, defaults to 0.001.
-- The forward part hands the parameter 'p' through unchanged. The backward
-- part takes p' (presumably the gradient w.r.t. p -- confirm against the
-- definition of '!%') and returns the new parameter p - lr * p' together
-- with the unchanged '()' state.
-- NOTE(review): this block is an extraction fragment. The leading "NN |"
-- tokens look like original line numbers, not Idris syntax. Original line 15
-- is absent (presumably 'GD = MkOptimiser', cf. GDMomentum), as are any
-- initialiser arguments after line 16 (cf. Adam). Restore from upstream
-- before compiling.
12 | GD : Num pType => Neg pType => Random pType =>
13 | (mon : ComMonoid pType) => FromDouble pType =>
14 | {default 0.001 lr : pType} -> Optimiser (Const pType) Unit
16 | (!% \(p, ()) => (
p ** \p' => (p - lr * p', ()))
)
-- Plain gradient ascent: identical in shape to gradient descent, but the
-- backward part moves the parameter along p' instead of against it,
-- returning p + lr * p'. The optimiser state is 'Unit' (matched and returned
-- as '()').
--   lr : step size; implicit, defaults to 0.001.
-- The forward part hands 'p' through unchanged. p' is presumably the
-- gradient w.r.t. p -- confirm against the definition of '!%'.
-- NOTE(review): this block is an extraction fragment. The leading "NN |"
-- tokens look like original line numbers. Original line 26 is absent
-- (presumably 'GA = MkOptimiser', cf. GDMomentum), as are any initialiser
-- arguments after line 27 (cf. Adam). Restore from upstream before compiling.
23 | GA : Num pType => Neg pType => Random pType =>
24 | (mon : ComMonoid pType) => FromDouble pType =>
25 | {default 0.001 lr : pType} -> Optimiser (Const pType) Unit
27 | (!% \(p, ()) => (
p ** \p' => (p + lr * p', ()))
)
-- One momentum step.
--   gamma : momentum coefficient
--   p     : current parameter
--   s     : current velocity (the accumulated update)
--   p'    : incoming update signal -- presumably the gradient; confirm
-- Returns (new parameter, new velocity), where
--   s'            = gamma * s + p'
--   new parameter = p - lr * s'
-- 'lr' is not one of the explicit arguments. GDMomentum calls this as
-- 'momentumUpdate {lr} ...', so 'lr' is presumably an implicit declared in
-- the signature lines that are missing from this fragment (original lines
-- 34-39). NOTE(review): the "NN |" tokens are line-number residue.
33 | momentumUpdate : Num pType => Neg pType =>
40 | momentumUpdate gamma p s p' = let s' = gamma * s + p'
41 | in (p - lr * s', s')
-- Look-ahead point p + gamma * s, where s is the momentum velocity.
-- GDMomentum exposes this point instead of 'p' in its forward part when
-- 'nesterov' is True. Presumably this is so that the downstream gradient is
-- computed at the look-ahead point (Nesterov-style momentum).
-- NOTE(review): original line 46 (presumably the result type 'pType') is
-- missing from this fragment, and the "NN |" tokens are line-number residue.
44 | lookAhead : Num pType =>
45 | (gamma, p, s : pType) ->
47 | lookAhead gamma p s = p + gamma * s
-- Gradient descent with momentum. The optimiser state is a single 'pType':
-- the velocity 's' that is threaded through 'momentumUpdate'.
--   nesterov : if True, the forward part exposes 'lookAhead gamma p s'
--              instead of 'p'; defaults to False.
--   lr       : step size; defaults to 0.001.
--   gamma    : momentum coefficient; defaults to 0.9.
-- The backward part is 'momentumUpdate {lr} gamma p s', partially applied and
-- awaiting p'. The update is always applied to the stored 'p', not to the
-- look-ahead point, even when nesterov = True.
-- NOTE(review): this block is an extraction fragment. Original line 53 and
-- any initialiser arguments after line 60 (cf. Adam) are missing. The
-- leading "NN |" tokens are line-number residue. Restore from upstream
-- before compiling.
51 | GDMomentum : Num pType => Neg pType => Random pType =>
52 | (mon : ComMonoid pType) =>
54 | {default False nesterov : Bool} ->
55 | {default 0.001 lr : pType} ->
56 | {default 0.9 gamma : pType} ->
57 | Optimiser (Const pType) pType
58 | GDMomentum = MkOptimiser
59 | (!% \(p, s) => (
if nesterov then lookAhead gamma p s else p
60 | ** momentumUpdate {lr} gamma p s)
)
-- One Adam step.
--   beta1, beta2 : decay rates of the first/second moment moving averages
--                  (see the m' and v' bindings)
--   epsilon      : added to sqrt vHat in the update's denominator
--   p            : current parameter
--   m, v         : current first/second moment estimates
--   b1p, b2p     : running bias-correction factors (presumably beta1^t and
--                  beta2^t -- confirm)
--   g            : presumably the gradient w.r.t. p
-- Returns (new p, m', v', b1p', b2p'), where
-- new p = p - lr * mHat / (sqrt vHat + epsilon).
-- 'lr' is not an explicit argument. Adam calls this as 'adamUpdate {lr} ...',
-- so 'lr' is presumably an implicit declared in the missing signature lines
-- (original lines 70-72 and 74-79).
-- NOTE(review): the "NN |" tokens are line-number residue from extraction.
69 | adamUpdate : Num pType => Neg pType => Fractional pType => Sqrt pType =>
73 | (epsilon : pType) ->
80 | (pType, pType, pType, pType, pType)
81 | adamUpdate beta1 beta2 epsilon p m v b1p b2p g =
82 | let m' = beta1 * m + (1 - beta1) * g
83 | v' = beta2 * v + (1 - beta2) * g * g
-- NOTE(review): original lines 84-85, which bind b1p' and b2p', are missing
-- from this fragment. Both must differ from 1 where they are used below.
-- Adam initialises b1p and b2p to 1, so using them un-updated would divide by
-- zero in mHat/vHat. Presumably b1p' = b1p * beta1 and b2p' = b2p * beta2 --
-- confirm against upstream.
86 | mHat = m' / (1 - b1p')
87 | vHat = v' / (1 - b2p')
88 | in (p - lr * mHat / (sqrt vHat + epsilon), m', v', b1p', b2p')
-- Adam optimiser over a single 'pType' parameter.
-- The optimiser state is (m, v, b1p, b2p): the first/second moment estimates
-- and the running bias-correction factors consumed by 'adamUpdate'.
--   lr      : step size; defaults to 0.001
--   beta1   : first-moment decay; defaults to 0.9
--   beta2   : second-moment decay; defaults to 0.999
--   epsilon : denominator stabiliser; defaults to 1.0e-8
-- The forward part passes 'p' through. The backward part is 'adamUpdate'
-- partially applied, awaiting its final argument (presumably the gradient).
-- NOTE(review): this block is an extraction fragment. Original lines 99 and
-- 106 are missing (106 is presumably 'Adam = MkOptimiser', cf. GDMomentum).
-- The leading "NN |" tokens are line-number residue. Restore from upstream
-- before compiling.
97 | Adam : Num pType => Neg pType => Random pType =>
98 | (mon : ComMonoid pType) =>
100 | Fractional pType => Sqrt pType =>
101 | {default 0.001 lr : pType} ->
102 | {default 0.9 beta1 : pType} ->
103 | {default 1.0e-8 epsilon : pType} ->