0 | module NN.Training.DataLoader
4 | import Data.Container.Additive
6 | import Control.Monad.Distribution
7 | import Control.Monad.Sample.Definition
8 | import Control.Monad.Sample.Instances
||| A finite, non-empty collection of (input, output) pairs that can be
||| sampled from.
record DataLoader (input : Type) (output : Type) where
  constructor MkDataLoader
  ||| Number of stored pairs. Every use of `MkDataLoader` in this module
  ||| passes it as the first explicit argument, and `dataset`'s type
  ||| depends on it, so it must be declared as a field here.
  datasetSize : Nat
  ||| The stored (input, output) pairs.
  dataset : Vect datasetSize (input, output)
  ||| Proof that `datasetSize` is a successor, i.e. the dataset is
  ||| non-empty. It is discharged automatically by auto-implicit search
  ||| when the record is built.
  {auto isSucc : IsSucc datasetSize}
||| Return only the inputs of the loader's dataset. They are packaged
||| with the dataset length as a dependent pair, so callers learn the
||| length together with a vector of exactly that length.
inputs : DataLoader input output -> (n ** Vect n input)
inputs (MkDataLoader len rows) = (len ** map fst rows)
||| Build a loader by labelling every input with `groundTruthFn`.
|||
||| The labelling runs in the monad `m`, one call per input, in `Vect`
||| traversal order. Each input is then zipped with its label. The
||| `IsSucc datasetSize` constraint supplies the record's non-emptiness
||| proof.
|||
||| @ inputs        the inputs to label
||| @ groundTruthFn computes the output for a single input
makeDataLoader : Monad m => {datasetSize : Nat} -> IsSucc datasetSize =>
                 (inputs : Vect datasetSize input) ->
                 (groundTruthFn : input -> m output) ->
                 m (DataLoader input output)
makeDataLoader xs groundTruthFn =
  -- The constructor is applied in full inside the lambda so its trailing
  -- auto-implicit `isSucc` field is filled in at this call site.
  (\labels => MkDataLoader datasetSize (zip xs labels))
    <$> traverse {f=m} {t=Vect datasetSize} groundTruthFn xs
||| Draw one (input, output) pair from the loader.
|||
||| An index is drawn from `uniform {i = len}`, and the pair stored at
||| that position of the dataset is returned.
sample : DataLoader input output -> IO (input, output)
sample (MkDataLoader len rows) =
  -- NOTE(review): the inner `sample` is presumably resolved by type to
  -- the sampler from the imported Control.Monad.Sample modules, not to
  -- this function. Confirm if the overload ever becomes ambiguous.
  (\ix => index ix rows) <$> sample (uniform {i=len})
||| Wrap a loader as a `Costate`. Every `()` request is answered by
||| drawing a pair with `sample` and lifting it with `pure`.
handleData : DataLoader x y ->
             Costate (IO <!> pushDown (x, y))
handleData loader = toCostate (\() => map pure (sample loader))