0 | module NN.Architectures.Affine
 1 |
 2 | import Data.Tensor
 3 | import Data.Para
 4 |
 5 | -- This is often called a 'linear layer', but really it is affine because of the bias
 6 |
 7 | public export
 8 | record AffineLayerParams
 9 |   (x, y : Axis)
10 |   {auto ac : NewAxisConsistent y [x]}
11 |   (a : Type) where
12 |   constructor MkParams
13 |   weights : Tensor [y, x] a
14 |   bias : Tensor [y] a
15 |
16 | public export
17 | affineImpl : {x, y : Axis} ->
18 |   NewAxisConsistent y [x] =>
19 |   Num a =>
20 |   AllAlgebra [x] a =>
21 |   TensorMonoid x.cont => TensorMonoid y.cont =>
22 |   DPair (Tensor [x] a) (const (AffineLayerParams x y a)) -> Tensor [y] a
23 | affineImpl (input ** (MkParams weights bias))
24 |   = matrixVectorProduct weights input + bias
25 |
26 | public export
27 | affinePara : {x, y : Axis} -> {a : Type} -> Num a =>
28 |   NewAxisConsistent y [x] =>
29 |   AllAlgebra [x] a =>
30 |   TensorMonoid x.cont => TensorMonoid y.cont =>
31 |   Tensor [x] a -\-> Tensor [y] a
32 | affinePara = MkPara
33 |   (const (AffineLayerParams x y a))
34 |   affineImpl
35 |