0 | module NN.Architectures.RNN
 1 |
 2 | import Data.Tensor
 3 | import Data.Para
 4 |
 5 | ||| Defines the type of a RNN cell as a parametric map
 6 | ||| @ x the type of the input
 7 | ||| @ s the type of the state
 8 | ||| @ y the type of the output
 9 | public export
10 | RNNCell : (x, s, y : Type) -> Type
11 | RNNCell x s y = (x, s) -\-> (y, s)
12 |
13 | ||| Defines the type of the unrolled RNN as a parametric map
14 | ||| @ n the number of unroll steps
15 | public export
16 | RNN : (x, s, y : Type) -> (n : Nat) ->Type
17 | RNN x s y n = (Vect n x, s) -\-> (Vect n y, s)
18 |
19 | ||| Given a rnn cell, implement the full RNN by iterating that cell
20 | ||| Helper function for `RNNPara`
21 | public export
22 | RNNImpl : (cell : DPair (x, s) (const p) -> (y, s)) ->
23 |   DPair (Vect n x, s) (const p) -> (Vect n y, s)
24 | RNNImpl _ (([], s) ** _= ([], s)
25 | RNNImpl cell (((x :: xs), s) ** p=
26 |   let (y, s') = cell ((x, s) ** p)
27 |       (ys, s'') = RNNImpl cell ((xs, s') ** p)
28 |   in (y :: ys, s'')
29 |
30 | ||| Parametric map for the full RNN
31 | public export
32 | RNNPara : (cell : RNNCell x s y) ->
33 |   IsNotDependent cell =>
34 |   RNN x s y n
35 | RNNPara (MkPara (const p) cell) @{MkNonDep p cell} = MkPara
36 |   (\_ => p)
37 |   (RNNImpl cell)
38 |
39 |
40 | public export
41 | runRNN : (rnn : RNN x s y n) ->
42 |   (xs : Vect n x) ->
43 |   (initialState : s) ->
44 |   (p : Param rnn (xs, initialState)) ->
45 |   Vect n y
46 | runRNN rnn xs initialState p = fst $ Run rnn (xs, initialState) p
47 |
48 | public export
49 | exampleRNN : RNNCell Double Double Double
50 | exampleRNN = MkPara (\_ => ()) (\((x, s) ** ()) => (if s > 4 then x else 0, s + 1))
51 |
52 | public export
53 | exampleInput : Vect 10 Double
54 | exampleInput = [1,2,3,4,5,6,7,8,9,10]
55 |
56 | public export
57 | exampleOutput : Vect 10 Double
58 | exampleOutput = runRNN (RNNPara exampleRNN) exampleInput 0 ()
59 |
60 |
61 |