0 | module NN.Training.Training
3 | import Data.Container.Additive as Additive
4 | import NN.Optimisers.Definition
5 | import NN.Optimisers.Instances
6 | import NN.Architectures.LossFunctions
||| A single optimisation step, packaged as a costate on the pair
||| `(current parameters, optimiser state)`.
|||
||| The model `f` yields an effect `e` paired with a loss `l`. The loss side is
||| closed off with `constantOne` and `rightUnit`, leaving a lens `p =%> e`.
||| The optimiser's update lens runs first, then that closed model, and
||| `handleEffect` answers whatever effect remains.
|||
||| @ f            model lens from parameters to effect-with-loss
||| @ handleEffect answers the effect `e` in `IO`
||| @ optimiser    only its first field (the update lens) is used here
optimiseStep : {p, l, e : AddCont} -> InterfaceOnPositions l Num =>
               (f : p =%> (e >@ l)) ->
               (handleEffect : Costate (IO <!> e)) ->
               (optimiser : Optimiser p stateTy) ->
               Costate (IO <!> (Const (p.Shp, stateTy)))
optimiseStep f handleEffect (MkOptimiser update _ _) =
  (IO <!> update) %>> (IO <!> lossClosed) %>> handleEffect
  where
    -- Presumably seeds the loss component with the constant 1 (as a
    -- gradient would be); `rightUnit` then removes the now-trivial factor
    -- so only the effect `e` is left.
    lossClosed : p =%> e
    lossClosed = f %>> (id >@ constantOne) %>> rightUnit
||| Run a forward map that may raise an effect, answering that effect in `IO`.
|||
||| `f` produces a container extension: a request for the effect `e` plus a
||| continuation. The request goes to `handleEffect`, and the answer is
||| passed to the continuation to obtain the final `b`.
evalFw : {0 e : Cont} ->
         (f : a -> Ext e b) ->
         (handleEffect : Costate (IO <!> e)) ->
         Costate (IO <!> (Const2 a b))
evalFw f handleEffect = toCostate $ \input =>
  case f input of
    (request <| resume) => resume <$> fromCostate handleEffect request
||| Train the parameters of `f` by repeatedly running `optimiseStep` through
||| `runActionUntilMaxSteps`.
|||
||| @ printEvery      forwarded to `runActionUntilMaxSteps`; defaults to 100
||| @ customInitParam initial parameters; when `Nothing` (the default) the
|||                   optimiser's `initParam` is used
||| @ f               model lens from parameters to effect-with-loss
||| @ handleEffect    answers the model's effect `e` in `IO`
||| @ opt             optimiser supplying initial parameters/state and the
|||                   update lens
||| NOTE(review): `numSteps` is presumably the step budget handed to
||| `runActionUntilMaxSteps` — confirm.
64 | optimise : {p, l, e : AddCont} -> InterfaceOnPositions l Num =>
65 | {default 100 printEvery : Nat} ->
66 | {default Nothing customInitParam : Maybe p.Shp} ->
67 | Show p.Shp => Show l.Shp => Show stateTy =>
68 | (f : p =%> e >@ l) ->
69 | (handleEffect : Costate (IO <!> e)) ->
70 | (opt : Optimiser p stateTy) ->
73 | optimise f handleEffect opt numSteps = do
-- Starting parameters: falls back to `opt.initParam` when no
-- `customInitParam` was supplied.
74 | currentValue : p.Shp <- case customInitParam of
76 | Nothing => opt.initParam
-- Starting optimiser state.
77 | currentState <- opt.initState
78 | runActionUntilMaxSteps
80 | {printEvery=printEvery}
-- The step action: one `optimiseStep`, unwrapped into a plain function.
81 | (fromCostate $
optimiseStep f handleEffect opt)
84 | (currentValue, currentState)
-- Evaluation: `opt.fwd` followed by the model's forward map, with effects
-- answered by `handleEffect`.
85 | (fromCostate $
evalFw (f.fwd . opt.fwd) handleEffect)
||| Combine a parametrised model `x -> y` with a loss on `y` into a single lens
||| from the model's parameters into a continuation over labelled pairs
||| `(x, y)` that ends in the loss type `l`.
|||
||| The incoming `((x, y), p)` is rearranged into `((x, p), y)`. The model then
||| consumes `(x, p)` while the true label `y` is carried alongside into
||| `loss`.
buildSupervisedLearningSystem : Num l.Shp => IsFlat l =>
                                (f : ParaAddDLens x y) ->
                                (loss : Loss y {l=l}) ->
                                ((GetParam f) =%> (pushDown (x >< y)) >@ l)
buildSupervisedLearningSystem (MkPara p model) loss =
  pushIntoContinuation {d = x >< y} (reorder %>> (model >< id) %>> loss)
  where
    -- ((x, y), p)  ~>  (x, (y, p))  ~>  (x, (p, y))  ~>  ((x, p), y)
    reorder : ((x >< y) >< p) =%> ((x >< p) >< y)
    reorder = assocL %>> (id >< swap) %>> assocR
102 | namespace WithEffect
||| Sum of the per-example losses over a labelled test set.
|||
||| Each input is run through the model's forward map with the fixed
||| parameters `p`, and effects are answered by `handleEffect`. The resulting
||| prediction is scored against its true label with `loss.fwd`. Examples are
||| processed in order.
totalLoss : Show l.Shp => Num l.Shp =>
            (f : ParaAddDLens x (e >@ y)) ->
            (loss : Loss y {l=l}) ->
            (p : (GetParam f).Shp) ->
            (handleEffect : Costate (IO <!> e)) ->
            Costate (IO <!> (Const2 (Vect n (x.Shp, y.Shp)) l.Shp))
totalLoss (MkPara _ f) loss p handleEffect = toCostate $ \testData =>
  let scoreExample : (x.Shp, y.Shp) -> IO l.Shp
      scoreExample (input, target) =
        (\yPred => loss.fwd (yPred, target))
          <$> fromCostate (evalFw f.fwd handleEffect) (input, p)
  in Prelude.sum <$> traverse scoreExample testData
||| Mean per-example loss over a labelled test set of size `n`. It is computed
||| as `totalLoss` divided by `cast n`.
|||
||| NOTE(review): an empty test set (`n = 0`) divides by `cast 0`; confirm
||| that outcome is acceptable for `l.Shp`.
averageLoss : {n : Nat} ->
              Show l.Shp => Num l.Shp => Fractional l.Shp => Cast Nat l.Shp =>
              (f : ParaAddDLens x (e >@ y)) ->
              (loss : Loss y {l=l}) ->
              (p : (GetParam f).Shp) ->
              (handleEffect : Costate (IO <!> e)) ->
              Costate (IO <!> (Const2 (Vect n (x.Shp, y.Shp)) l.Shp))
averageLoss f loss p handleEffect = toCostate $ \testData =>
  (\total => total / cast n)
    <$> fromCostate (totalLoss f loss p handleEffect) testData
||| For every labelled example, print the input, the model's prediction under
||| parameters `p`, and the loss of that prediction against the true label.
|||
||| Effects raised by the model are answered by `handleEffect`.
139 | evalWithLoss : Show x.Shp => Show y.Shp => Show l.Shp =>
140 | (f : ParaAddDLens x (e >@ y)) ->
141 | (loss : Loss y {l=l}) ->
142 | (p : (GetParam f).Shp) ->
143 | (handleEffect : Costate (IO <!> e)) ->
144 | Costate (IO <!> (Const2 (Vect n (x.Shp, y.Shp)) Unit))
145 | evalWithLoss (MkPara pCont f) loss p handleEffect = toCostate $
\testData => do
-- Evaluate and report a single labelled example.
146 | let evalFWithLoss : (x.Shp, y.Shp) -> IO ()
-- NOTE: the pattern variable `x` below is the input value; it shadows the
-- container name `x` used in the type signatures.
147 | evalFWithLoss (x, yTrue) = do
148 | yPred <- fromCostate (evalFw f.fwd handleEffect) (x, p)
149 | let lossVal = loss.fwd (yPred, yTrue)
150 | putStrLn "Input: \{show x}, Predicted: \{show yPred}, Loss: \{show lossVal}"
-- Every per-example result is `()`, so the traversal's output is discarded.
151 | _ <- traverse evalFWithLoss testData
||| Print each input of `testData` next to the model's prediction under
||| parameters `p`. Effects raised by the model are answered by
||| `handleEffect`.
156 | eval : Show x.Shp => Show y.Shp =>
157 | (f : ParaAddDLens x (e >@ y)) ->
158 | (p : (GetParam f).Shp) ->
159 | (handleEffect : Costate (IO <!> e)) ->
160 | Costate (IO <!> (Const2 (Vect n x.Shp) Unit))
161 | eval (MkPara _ f) p handleEffect = toCostate $
\testData => do
-- Evaluate and report a single input.
162 | let evalF : x.Shp -> IO ()
-- Forward pass on the input paired with the fixed parameters `p`.
164 | yPred <- fromCostate (evalFw f.fwd handleEffect) (x, p)
165 | putStrLn "Input: \{show x}, Predicted: \{show yPred}"
-- Every per-input result is `()`, so the traversal's output is discarded.
166 | _ <- traverse evalF testData
169 | namespace WithoutEffect
||| Treat an effect-free parametrised lens as one whose output carries the
||| `Scalar` effect. The parameters are unchanged; the lens is post-composed
||| with `leftUnitInv`.
trivialEffect : {y : AddCont} ->
                ParaAddDLens x y -> ParaAddDLens x (Scalar >@ y)
trivialEffect (MkPara params lens) = MkPara params (lens %>> leftUnitInv)
||| Handler for the `Scalar` effect: each request (always `()`) is answered
||| immediately with `()`, and no other IO is performed.
handleTrivial : Costate (IO <!> Additive.Object.Instances.Scalar)
handleTrivial = toCostate (\() => pure ())
||| Print each input next to the prediction of an effect-free model under
||| parameters `p`.
|||
||| The model is lifted into the `Scalar` effect with `trivialEffect` and
||| then run through the effectful `eval`, with `handleTrivial` answering the
||| trivial effect.
eval : {y : AddCont} -> Show x.Shp => Show y.Shp =>
       (f : ParaAddDLens x y) ->
       (p : (GetParam f).Shp) ->
       Costate (IO <!> (Const2 (Vect n x.Shp) Unit))
eval (MkPara pCont f) p =
  -- Passing `trivialEffect` an explicit `MkPara` keeps its result in
  -- constructor form, so `GetParam` of the lifted lens reduces to `pCont`,
  -- which is the type `p` already has.
  eval {e=Scalar} (trivialEffect (MkPara pCont f)) p handleTrivial
||| Mean per-example loss of an effect-free model over a labelled test set.
|||
||| The model's lens is post-composed with `leftUnitInv`, the same lifting
||| that `trivialEffect` performs. The result is passed to the effectful
||| `averageLoss` with `e = Scalar`.
190 | averageLoss : {y : AddCont} -> {n : Nat} ->
191 | Show l.Shp => Num l.Shp => Fractional l.Shp => Cast Nat l.Shp =>
192 | (f : ParaAddDLens x y) ->
193 | (loss : Loss y {l=l}) ->
194 | (p : (GetParam f).Shp) ->
195 | Costate (IO <!> (Const2 (Vect n (x.Shp, y.Shp)) l.Shp))
196 | averageLoss (MkPara pCont f) loss p = averageLoss {e=Scalar}
197 | (MkPara pCont (f %>> leftUnitInv))