0 | module NN.Training.Examples.LinearRegression
5 | import Data.Container.Additive
6 | import NN.Optimisers.Definition
7 | import NN.Optimisers.Instances
8 | import NN.Training.Training
9 | import NN.Training.DataLoader
11 | import Data.Autodiff.Ops
12 | import Control.Monad.Identity
13 | import NN.Architectures.LossFunctions
||| The five training inputs used by the linear-regression example,
||| namely the points 1.0 through 5.0.
exampleInputs : Vect 5 Double
exampleInputs = [1.0, 2.0, 3.0, 4.0, 5.0]
||| The target function the example tries to recover: the straight line
||| with slope 2 and intercept 1.
groundTruth : Double -> Double
groundTruth = \x => 1 + 2 * x
||| A data loader that pairs each element of `exampleInputs` with its
||| `groundTruth` label; each label is returned via `pure` in the ambient
||| monad `m`.
linearRegressionDataLoader : Monad m => m (DataLoader Double Double)
linearRegressionDataLoader =
  makeDataLoader exampleInputs (\input => pure (groundTruth input))
||| Train the parametrised lens `f` on the linear-regression example data and
||| print its average loss on a set of test inputs not used for training.
|||
||| Training data comes from `linearRegressionDataLoader` (inputs 1..5 labelled
||| by `groundTruth`); the loss is `SquaredDifference` and the optimiser is
||| `GDMomentum` over `f`'s parameter shape. After training, the model is
||| evaluated on the inputs [20, 50, 100] (also labelled by `groundTruth`) and
||| the average loss over that test set is printed.
|||
||| NOTE(review): this excerpt does not show the end of the type signature
||| (the type of `numSteps` and the result type) or every argument passed to
||| `optimise`; presumably `numSteps` controls the number of optimisation
||| steps — confirm against the full source.
31 | linearRegression : (f : ParaAddDLens (Const Double) (Const Double)) ->
-- NOTE(review): presumably these constraints on the parameter shape are what
-- `GDMomentum` / `optimise` / `show` require below — verify.
32 | Neg (GetParam f).Shp => Fractional (GetParam f).Shp =>
33 | Sqrt (GetParam f).Shp =>
34 | Random (GetParam f).Shp =>
35 | FromDouble (GetParam f).Shp => Show (GetParam f).Shp =>
36 | (isFlat : IsFlat (GetParam f)) =>
-- Destructure `f` to bind its parameter shape `p`, and the `IsFlat` witness to
-- bind `mon`. NOTE(review): neither `p` nor `mon` is used in the lines visible
-- here; check whether the omitted lines need them.
39 | linearRegression f@(MkPara (MkAddCont (Const p)) _)
40 | {isFlat = MkIsFlat p @{mon}} numSteps = do
-- Training and test loaders; both label their inputs with `groundTruth`.
41 | trainData <- linearRegressionDataLoader
42 | testDataLoader <- makeDataLoader [20, 50, 100] (pure . groundTruth)
-- Run the optimiser and keep only the first component of its result, which is
-- bound as the trained parameters. The loss container `l` and example container
-- `e` are given explicitly (presumably because they cannot be inferred here).
43 | pTrained <- fst <$> optimise
44 | {l=Const Double, e=pushDown (Const Double >< Const Double)}
45 | (buildSupervisedLearningSystem f SquaredDifference)
46 | (handleData trainData)
47 | (GDMomentum {pType=(GetParam f).Shp})
-- Evaluate the trained model on the test inputs.
-- NOTE(review): the result of this statement is discarded. In Idris 2 a
-- non-final `do` statement must have type `m ()`, so confirm `fromCostate`
-- returns unit here, or bind it with `_ <-` / wrap it in `ignore`.
49 | fromCostate (eval f pTrained) (snd $
inputs testDataLoader)
-- Average `SquaredDifference` loss of the trained parameters over the test set.
50 | avgLoss <- fromCostate (averageLoss f SquaredDifference pTrained) (dataset testDataLoader)
51 | putStrLn "Average loss: \{show avgLoss}"