0 | module NN.Architectures.Transformer.Definition
 1 |
 2 | import Data.Tensor
 3 | import Data.Para
 4 |
 5 | import NN.Architectures.Affine
 6 | import NN.Architectures.Residual
 7 | import NN.Architectures.MLP
 8 | import NN.Architectures.Activations
 9 | import NN.Architectures.Transformer.Attention
10 | import NN.Architectures.Utils
11 |
12 | ||| Single-head transformer layer
13 | ||| Only missing layernorm, otherwise a complete definition
14 | public export
15 | Transformer : {a : Type} -> Num a => Ord a =>
16 |   {inputStructure, features : Axis} ->
17 |   (ac : NewAxisConsistent inputStructure [features]) =>
18 |   (TensorMonoid inputStructure.cont) =>
19 |   (TensorMonoid features.cont) =>
20 |   (allAlg : AllAlgebra [inputStructure, features] a) =>
21 |   {default id causalMask : Tensor [inputStructure, inputStructure] a ->
22 |                            Tensor [inputStructure, inputStructure] a} ->
23 |   (softargmax : Tensor [inputStructure] a -> Tensor [inputStructure] a) ->
24 |   Tensor [inputStructure, features] a -\-> Tensor [inputStructure, features] a
25 | Transformer {allAlg = Cons} softargmax
26 |   = composePara (addResidual (SelfAttention softargmax)) (addResidual ffnet)
27 |     where
28 |       ffnet : Tensor [inputStructure, features] a -\-> Tensor [inputStructure, features] a
29 |       ffnet = paraMapFirstAxis (multiLayerPerceptron {a=a} {ieva=features} 2 (trivialParam relu) {lastLayerActivation=False})
30 |