0 | module NN.Training.Training
  1 |
  2 | import Data.Tensor
  3 | import Data.Container.Additive as Additive
  4 | import NN.Optimisers.Definition
  5 | import NN.Optimisers.Instances
  6 | import NN.Architectures.LossFunctions
  7 |
  8 | import NN.Utils
  9 | import Data.Para
 10 |
 11 | {-------------------------------------------------------------------------------
 12 | {-------------------------------------------------------------------------------
 13 | TODO update this in light of effects
 14 |
 15 | This file defines functions which perform pure optimisation:
 16 | optimisation of a differentiable function `f : p -> x` (here modelled as a lens `f : p =%> x` via some optimiser such as gradient descent. 
 17 | Here no "loss function" or "input-output pairs" are needed, just a function to optimise.
 18 |
 19 | This file provides functionality for creating, and turning supervised learning problems into pure optimisation problems, via a function which takes:
 20 | a) a parametric lens `f : x >< p =%>> y`
 21 | b) a loss function `loss : (y, y) =%> l`
 22 | c) input-output pairs `IO (x.Shp, y.Shp)`
 23 | and composes them to produce an optimisation problem `f : p =%> l` which the above described functions can consume.
 24 |
 25 | Notably, only a *non-dependent* supervised-learning problem can be turned into a pure optimisation one. If the parameter space depends on the input, then learning becomes its own thing.
 26 |
 27 |
 28 | todo using Hom-version of optimisation becomes problematic if we either
 29 | a) have the dependency of the parameter on the input
 30 | b) use monadic lenses
 31 |
 32 | -------------------------------------------------------------------------------}
 33 | -------------------------------------------------------------------------------}
 34 |
 35 | ||| Performs a single step of optimisation of some differentiable function
 36 | ||| `f : p -> x`, where we additionally need to handle some effect `e`
 37 | ||| Uses a potentially stateful optimiser, and returns an update of its
 38 | ||| parameters and state
 39 | public export
 40 | optimiseStep : {p, l, e : AddCont} -> InterfaceOnPositions l Num =>
 41 |   (f : p =%> (e >@ l)) ->
 42 |   (handleEffect : Costate (IO <!> e)) ->
 43 |   (optimiser : Optimiser p stateTy) ->
 44 |   Costate (IO <!> (Const (p.Shp, stateTy)))
 45 | optimiseStep f handleEffect (MkOptimiser opt _ _) = 
 46 |   let closeFunction : p =%> e 
 47 |       closeFunction = f %>> (id >@ constantOne) %>> rightUnit
 48 |   in (IO <!> opt) %>> (IO <!> closeFunction) %>> handleEffect
 49 |
 50 | ||| Evaluates a the forward pass of some effectful lens
 51 | public export
 52 | evalFw : {0 e : Cont} ->
 53 |   (f : a -> Ext e b) ->
 54 |   (handleEffect : Costate (IO <!> e)) ->
 55 |   Costate (IO <!> (Const2 a b))
 56 | evalFw f handleEffect = toCostate $ \ps => do
 57 |   let (eInp <| outGivenEffect) = f ps 
 58 |   e <- fromCostate handleEffect eInp
 59 |   pure $ outGivenEffect e
 60 |
 61 | ||| Iterates `optimiseStep` `numSteps` times, and logs the progress to the 
 62 | ||| console
 63 | public export
 64 | optimise : {p, l, e : AddCont} -> InterfaceOnPositions l Num =>
 65 |   {default 100 printEvery : Nat} ->
 66 |   {default Nothing customInitParam : Maybe p.Shp} ->
 67 |   Show p.Shp => Show l.Shp => Show stateTy =>
 68 |   (f : p =%> e >@ l) ->
 69 |   (handleEffect : Costate (IO <!> e)) ->
 70 |   (opt : Optimiser p stateTy) ->
 71 |   (numSteps : Nat) ->
 72 |   IO (p.Shp, stateTy)
 73 | optimise f handleEffect opt numSteps = do
 74 |   currentValue : p.Shp <- case customInitParam of
 75 |     Just p => pure p
 76 |     Nothing => opt.initParam
 77 |   currentState <- opt.initState
 78 |   runActionUntilMaxSteps
 79 |     {l=l.Shp}
 80 |     {printEvery=printEvery}
 81 |     (fromCostate $ optimiseStep f handleEffect opt)
 82 |     numSteps
 83 |     0
 84 |     (currentValue, currentState)
 85 |     (fromCostate $ evalFw (f.fwd . opt.fwd) handleEffect)
 86 |
 87 | ||| Given
 88 | ||| a) a parametric lens `f : x >< p =%> y`
 89 | ||| b) a loss function `loss : y >< y =%> l`
 90 | ||| builds an effectful lens `p =%> l`
 91 | public export
 92 | buildSupervisedLearningSystem : Num l.Shp => IsFlat l =>
 93 |   (f : ParaAddDLens x y) ->
 94 |   (loss : Loss y {l=l}) ->
 95 |   ((GetParam f) =%> (pushDown (x >< y)) >@ l)
 96 | buildSupervisedLearningSystem (MkPara p f) loss =
 97 |   let rebracket : ((x >< y) >< p) =%> ((x >< p) >< y)
 98 |       rebracket = assocL %>> (id >< swap) %>> assocR
 99 |   in pushIntoContinuation {d=x><y} (rebracket %>> (f >< id) %>> loss)
100 |
101 |
102 | namespace WithEffect
103 |   ||| When it comes to effects which involve sampling, where the 'correct' answer
104 |   ||| is stored in the test data, there are different ways of evaluating the loss
105 |   ||| One is to use that correct label to force the correct branch to run, but
106 |   ||| that is impossible with the current type signature
107 |   ||| Instead, we opt out for the more accurate method of sampling during loss
108 |   ||| evaluation
109 |   public export
110 |   totalLoss : Show l.Shp => Num l.Shp =>
111 |     (f : ParaAddDLens x (e >@ y)) ->
112 |     (loss : Loss y {l=l}) ->
113 |     (p : (GetParam f).Shp) ->
114 |     (handleEffect : Costate (IO <!> e)) ->
115 |     Costate (IO <!> (Const2 (Vect n (x.Shp, y.Shp)) l.Shp))
116 |   totalLoss (MkPara pCont f) loss p handleEffect = toCostate $ \testData => do
117 |     let evalFWithLoss : (x.Shp, y.Shp) -> IO l.Shp
118 |         evalFWithLoss (x, yTrue) = do
119 |           yPred <- fromCostate (evalFw f.fwd handleEffect) (x, p)
120 |           pure $ loss.fwd (yPred, yTrue)
121 |           -- putStrLn "Input: \{show x}, Predicted: \{show yPred}, Loss: \{show lossVal}"
122 |     losses <- traverse evalFWithLoss testData
123 |     pure $ Prelude.sum losses
124 |   
125 |   public export
126 |   averageLoss :  {n : Nat} ->
127 |     Show l.Shp => Num l.Shp => Fractional l.Shp => Cast Nat l.Shp =>
128 |     (f : ParaAddDLens x (e >@ y)) ->
129 |     (loss : Loss y {l=l}) ->
130 |     (p : (GetParam f).Shp) ->
131 |     (handleEffect : Costate (IO <!> e)) ->
132 |     Costate (IO <!> (Const2 (Vect n (x.Shp, y.Shp)) l.Shp))
133 |   averageLoss f loss p handleEffect = toCostate $ \testData => do
134 |     lossSum <- fromCostate (totalLoss f loss p handleEffect) testData
135 |     pure (lossSum / cast n)
136 |   
137 |   ||| Eval a model and loss with specific parameters, in the presence of an effect
138 |   public export
139 |   evalWithLoss : Show x.Shp => Show y.Shp => Show l.Shp =>
140 |     (f : ParaAddDLens x (e >@ y)) ->
141 |     (loss : Loss y {l=l}) ->
142 |     (p : (GetParam f).Shp) ->
143 |     (handleEffect : Costate (IO <!> e)) ->
144 |     Costate (IO <!> (Const2 (Vect n (x.Shp, y.Shp)) Unit))
145 |   evalWithLoss (MkPara pCont f) loss p handleEffect = toCostate $ \testData => do 
146 |     let evalFWithLoss : (x.Shp, y.Shp) -> IO ()
147 |         evalFWithLoss (x, yTrue) = do
148 |           yPred <- fromCostate (evalFw f.fwd handleEffect) (x, p)
149 |           let lossVal = loss.fwd (yPred, yTrue)
150 |           putStrLn "Input: \{show x}, Predicted: \{show yPred}, Loss: \{show lossVal}"
151 |     _ <- traverse evalFWithLoss testData
152 |     pure ()
153 |   
154 |   ||| Eval a model with specific parameters, in the presence of an effect
155 |   public export
156 |   eval : Show x.Shp => Show y.Shp =>
157 |     (f : ParaAddDLens x (e >@ y)) ->
158 |     (p : (GetParam f).Shp) ->
159 |     (handleEffect : Costate (IO <!> e)) ->
160 |     Costate (IO <!> (Const2 (Vect n x.Shp) Unit))
161 |   eval (MkPara _ f) p handleEffect = toCostate $ \testData => do
162 |     let evalF : x.Shp -> IO ()
163 |         evalF x = do
164 |           yPred <- fromCostate (evalFw f.fwd handleEffect) (x, p)
165 |           putStrLn "Input: \{show x}, Predicted: \{show yPred}"
166 |     _ <- traverse evalF testData
167 |     pure ()
168 |
169 | namespace WithoutEffect
170 |   public export
171 |   trivialEffect : {y : AddCont} ->
172 |     ParaAddDLens x y -> ParaAddDLens x (Scalar >@ y)
173 |   trivialEffect (MkPara p f) = MkPara p
174 |     (f %>> leftUnitInv)
175 |
176 |   public export
177 |   handleTrivial : Costate (IO <!> Additive.Object.Instances.Scalar)
178 |   handleTrivial = toCostate $ \() => pure ()
179 |
180 |
181 |   public export
182 |   eval : {y : AddCont} -> Show x.Shp => Show y.Shp =>
183 |     (f : ParaAddDLens x y) ->
184 |     (p : (GetParam f).Shp) ->
185 |     Costate (IO <!> (Const2 (Vect n x.Shp) Unit))
186 |   eval (MkPara pCont f) p
187 |     = eval {e=Scalar} (MkPara pCont (f %>> leftUnitInv)) p handleTrivial
188 |
189 |   public export
190 |   averageLoss :  {y : AddCont} -> {n : Nat} ->
191 |     Show l.Shp => Num l.Shp => Fractional l.Shp => Cast Nat l.Shp =>
192 |     (f : ParaAddDLens x y) ->
193 |     (loss : Loss y {l=l}) ->
194 |     (p : (GetParam f).Shp) ->
195 |     Costate (IO <!> (Const2 (Vect n (x.Shp, y.Shp)) l.Shp))
196 |   averageLoss (MkPara pCont f) loss p = averageLoss {e=Scalar}
197 |     (MkPara pCont (f %>> leftUnitInv))
198 |     loss
199 |     p
200 |     handleTrivial
201 |   
202 |
203 | {-
204 | public export
205 | train : {x, y, l : AddCont} -> InterfaceOnPositions l Num => IsFlat l =>
206 |   {default 100 printEvery : Nat} ->
207 |   (f : ParaAddDLens x y) ->
208 |   Show (GetParam f).Shp => Num l.Shp =>
209 |   Show x.Shp => Show y.Shp => Show stateTy => Show l.Shp =>
210 |   {default Nothing initParam : Maybe (GetParam f).Shp} ->
211 |   (loss : (y >< y) =%> l) ->
212 |   (handleData : Costate (IO <!> (pushDown (x >< y)))) ->
213 |   (opt : Optimiser (GetParam f) stateTy) ->
214 |   (numSteps : Nat) ->
215 |   IO ((GetParam f).Shp, stateTy)
216 | train f loss handleData = optimise
217 |   {e=pushDown (x><y)}
218 |   {printEvery=printEvery}
219 |   {initParam=initParam}
220 |   (buildSupervisedLearningSystem f loss)
221 |   handleData
222 | -}
223 |
224 | {-
225 | -- todo write a variant of this with effects?
226 | public export
227 | debugPrint : {x, y : AddCont} ->
228 |   Show x.Shp => Show y.Shp =>
229 |   (name : String) ->
230 |   (f : ParaAddMLens {m=IO} x y) ->
231 |   Show (GetParam f).Shp =>
232 |   ParaAddMLens {m=IO} x y
233 | debugPrint name (MkPara pCont f) = MkPara
234 |   pCont
235 |   (!%%+ \(x, p) => do
236 |     putStrLn "--------------------------------"
237 |     putStrLn "\{name} input: \{show x}"
238 |     putStrLn "\{name} parameter: \{show p}"
239 |     (y ** ky) <- (%%!+ f) (x, p)
240 |     putStrLn "\{name} output: \{show y}"
241 |     putStrLn "--------------------------------"
242 |     pure (y ** ky))
243 |
244 | -- namespace Additive
245 | --   ||| Evaluates a the forward pass of some effectful lens
246 | --   public export
247 | --   evalFw : {0 a, e, b : AddCont} ->
248 | --     (f : a =%> (e >@ b)) ->
249 | --     (handleEffect : Costate (IO <!> e)) ->
250 | --     Costate (IO <!> (Const2 a.Shp b.Shp))
251 | --   evalFw f handleEffect = toCostate $ \ps => do
252 | --     let (eInp <| outGivenEffect) = f.fwd ps 
253 | --     e <- fromCostate handleEffect eInp
254 | --     pure $ outGivenEffect e
255 |