0 | module NN.Optimisers.Definition
 1 |
 2 | import Data.Tensor
 3 | import Data.Container.Additive
 4 | import NN.Utils
 5 |
 6 | -- We can make a choice in optimisation:
 7 | -- a) either the parameter can depend on the input, in which case we can't use the hom representation of a learner
 8 | -- b) or it doesn't, meaning we can curry the learner and treat it as something we optimise
 9 |
10 |
11 | ||| Dependent stateful optimiser, modelled as a dependent lens
12 | ||| Dependent version of section 8.1.3 of https://arxiv.org/abs/2403.13001
13 | ||| Because we use dependent Para, optimiser can depend on the input
14 | public export
15 | record DOptimiser {inputTy : Type}
16 |   (paramCont : inputTy -> AddCont)
17 |   (stateTy : Type)
18 |   where
19 |   constructor MkDOptimiser
20 |   ||| Notably this produces an ordinary dependent lens, not an additive one
21 |   opt : (: inputTy) ->
22 |     (Const (paramCont x).Shp >< Const stateTy) =%> UC (paramCont x)
23 |
24 |
25 | -- todo upgrade to something like HasIO instead of IO?
26 | public export
27 | record Optimiser
28 |   (paramCont : AddCont)
29 |   (stateTy : Type)
30 |   where
31 |   constructor MkOptimiser
32 |   ||| Notably, this is an ordinary dependent lens, not an additive one
33 |   opt : (Const paramCont.Shp >< Const stateTy) =%> UC paramCont
34 |   ||| Procedure for initialising parameters
35 |   initParam : IO paramCont.Shp
36 |   ||| Procedure for initialising state
37 |   initState : IO stateTy
38 |
39 | public export
40 | (.fwd) : Optimiser p s -> (p.Shp, s) -> p.Shp
41 | (.fwd) (MkOptimiser opt _ _) = opt.fwd
42 |
43 | public export
44 | (.bwd) : (opt : Optimiser pCont stateTy) ->
45 |   (ps : (pCont.Shp, stateTy)) ->
46 |   (pCont.Pos (opt.fwd ps)) ->
47 |   (pCont.Shp, stateTy)
48 | (.bwd) (MkOptimiser opt _ _) = opt.bwd
49 |
50 | ||| From 8.1.3. "Can we compose optimisers?" of https://arxiv.org/abs/2403.13001
51 | ||| Not used yet
52 | public export
53 | composeParallel : Optimiser pCont s ->
54 |   Optimiser qCont t -> 
55 |   Optimiser (pCont >< qCont) (s, t)
56 | composeParallel (MkOptimiser o1 initP initS) (MkOptimiser o2 initQ initT) = MkOptimiser
57 |   (!% \((p, q), (s, t)) => ((o1.fwd (p, s), o2.fwd (q, t)) **
58 |     \(p', q') => let (pUpdated, sUpdated) = o1.bwd (p, s) p'
59 |                      (qUpdated, tUpdated) = o2.bwd (q, t) q'
60 |                  in ((pUpdated, qUpdated), (sUpdated, tUpdated))))
61 |   (pairIO initP initQ)
62 |   (pairIO initS initT)