0 | module NN.Optimisers.Definition
3 | import Data.Container.Additive
-- An optimiser whose parameter container is chosen by a value of `inputTy`.
-- NOTE(review): some lines of this record are not visible in this listing
-- (the numbering skips between `paramCont` and the constructor). The `opt`
-- field refers to `stateTy`, which is presumably declared on one of the
-- missing lines -- confirm against the full file.
15 | record DOptimiser {inputTy : Type}
-- Family of parameter containers: one `AddCont` for every input value.
16 | (paramCont : inputTy -> AddCont)
19 | constructor MkDOptimiser
-- For each input `x`, a `=%>` arrow from
-- `Const (paramCont x).Shp >< Const stateTy` to `UC (paramCont x)`.
21 | opt : (x : inputTy) ->
22 | (Const (paramCont x).Shp >< Const stateTy) =%> UC (paramCont x)
-- NOTE(review): the `record` header and the `stateTy` parameter of this
-- declaration are not visible in this listing. The name `Optimiser` and the
-- three-argument constructor shape are taken from how `MkOptimiser` is used
-- further down the file -- confirm against the full source.
-- Container describing the parameters being optimised.
28 | (paramCont : AddCont)
31 | constructor MkOptimiser
-- A `=%>` arrow from `Const paramCont.Shp >< Const stateTy` to `UC paramCont`.
-- Its `fwd` and `bwd` components are what `(.fwd)` and `(.bwd)` expose.
33 | opt : (Const paramCont.Shp >< Const stateTy) =%> UC paramCont
-- IO action yielding a parameter shape (going by the name, the initial
-- parameters).
35 | initParam : IO paramCont.Shp
-- IO action yielding a value of the state type (going by the name, the
-- initial optimiser state).
37 | initState : IO stateTy
-- Forward part of an optimiser: takes a (parameter shape, state) pair and
-- returns a parameter shape.
40 | (.fwd) : Optimiser p s -> (p.Shp, s) -> p.Shp
-- Implemented by projecting `fwd` out of the record's first field; the two
-- `IO` initialiser fields are ignored.
41 | (.fwd) (MkOptimiser opt _ _) = opt.fwd
-- Backward part of an optimiser. Given a (parameter shape, state) pair `ps`
-- and a position of `pCont` over `opt.fwd ps`, returns a
-- (parameter shape, state) pair.
-- The position argument's type mentions `opt.fwd ps`, so this signature
-- depends on the definition of `(.fwd)`.
44 | (.bwd) : (opt : Optimiser pCont stateTy) ->
45 | (ps : (pCont.Shp, stateTy)) ->
46 | (pCont.Pos (opt.fwd ps)) ->
47 | (pCont.Shp, stateTy)
-- In the clause below, `opt` is bound to the record's first field, not to
-- the whole optimiser that the signature calls `opt`. The two `IO`
-- initialiser fields are ignored.
48 | (.bwd) (MkOptimiser opt _ _) = opt.bwd
-- Combines two optimisers into a single optimiser over the product container
-- `pCont >< qCont`, whose state is the pair of the two component states.
-- Each component only ever sees its own parameters, state and position.
53 | composeParallel : Optimiser pCont s ->
54 | Optimiser qCont t ->
55 | Optimiser (pCont >< qCont) (s, t)
56 | composeParallel (MkOptimiser lensP newP newS) (MkOptimiser lensQ newQ newT) = MkOptimiser
-- Forward: run both forward maps on their own (parameter, state) pair and
-- pair the results.
57 | (!% \((pIn, qIn), (sIn, tIn)) => (
(lensP.fwd (pIn, sIn), lensQ.fwd (qIn, tIn)) **
-- Backward: split the incoming pair of positions, run each backward map on
-- its own inputs, then regroup as ((parameters), (states)).
58 | \(posP, posQ) => let (pOut, sOut) = lensP.bwd (pIn, sIn) posP
59 | (qOut, tOut) = lensQ.bwd (qIn, tIn) posQ
60 | in ((pOut, qOut), (sOut, tOut)))
)
-- The initialisers are combined with `pairIO` (defined outside this view;
-- presumably it runs both actions and pairs their results -- confirm).
61 | (pairIO newP newQ)
62 | (pairIO newS newT)