0 | module NN.Training.DataLoader
 1 |
 2 | import Data.Vect
 3 |
 4 | import Data.Container.Additive
 5 |
 6 | import Control.Monad.Distribution
 7 | import Control.Monad.Sample.Definition
 8 | import Control.Monad.Sample.Instances
 9 |
10 |
11 | public export
12 | record DataLoader (input : Type) (output : Type) where
13 |   constructor MkDataLoader
14 |   datasetSize : Nat
15 |   dataset : Vect datasetSize (input, output)
16 |   {auto isSucc : IsSucc datasetSize}
17 |
18 | public export
19 | inputs : DataLoader input output -> (n ** Vect n input)
20 | inputs (MkDataLoader datasetSize dataset) = (datasetSize ** fst <$> dataset)
21 |
22 | public export
23 | makeDataLoader : Monad m => {datasetSize : Nat} -> IsSucc datasetSize =>
24 |   (inputs : Vect datasetSize input) ->
25 |   (groundTruthFn : input -> m output) ->
26 |   m (DataLoader input output)
27 | makeDataLoader xs groundTruthFn = do 
28 |   ys <- traverse {f=m} {t=Vect datasetSize} groundTruthFn xs
29 |   pure $ MkDataLoader datasetSize (zip xs ys)
30 |
31 | ||| Samples a single item from the dataset
32 | public export
33 | sample : DataLoader input output -> IO (input, output)
34 | sample (MkDataLoader datasetSize dataset) = do
35 |   n <- sample (uniform {i=datasetSize})
36 |   pure (index n dataset)
37 |
38 | public export
39 | handleData : DataLoader x y ->
40 |   Costate (IO <!> pushDown (x, y))
41 | handleData dataLoader = toCostate $ \() => sample dataLoader <&> pure
42 |