0 | module NN.Architectures.Transformer.Definition
5 | import NN.Architectures.Affine
6 | import NN.Architectures.Residual
7 | import NN.Architectures.MLP
8 | import NN.Architectures.Activations
9 | import NN.Architectures.Transformer.Attention
10 | import NN.Architectures.Utils
-- NOTE(review): every line below carries an "N | " prefix and the N sequence
-- has gaps (e.g. 27 is missing between the implementation clause and ffnet).
-- These prefixes look like extraction artifacts (source line numbers fused
-- into the text) — confirm against the original file; as written this will
-- not parse. The missing line 27 is presumably the `where` keyword that
-- scopes ffnet to the Transformer clause — TODO confirm.
--
-- Transformer: one transformer block as a parametric morphism (`-\->`):
-- self-attention followed by a position-wise feed-forward network, each
-- wrapped in a residual connection (addResidual) and composed with
-- composePara. Element type `a` needs Num and Ord; `inputStructure` is the
-- sequence axis and `features` the embedding axis (names suggest this —
-- verify against NN.Architectures.Transformer.Attention).
--
-- Implicit `causalMask` defaults to `id`, i.e. no masking. NOTE(review):
-- causalMask is declared in the signature but is not visibly applied in the
-- body shown here — confirm whether SelfAttention consumes it or it is dead.
15 | Transformer : {a : Type} -> Num a => Ord a =>
16 |   {inputStructure, features : Axis} ->
17 |   (ac : NewAxisConsistent inputStructure [features]) =>
18 |   (TensorMonoid inputStructure.cont) =>
19 |   (TensorMonoid features.cont) =>
20 |   (allAlg : AllAlgebra [inputStructure, features] a) =>
21 |   {default id causalMask : Tensor [inputStructure, inputStructure] a ->
22 |                            Tensor [inputStructure, inputStructure] a} ->
23 |   (softargmax : Tensor [inputStructure] a -> Tensor [inputStructure] a) ->
24 |   Tensor [inputStructure, features] a -\-> Tensor [inputStructure, features] a
-- Pattern-matches the AllAlgebra evidence (Cons) to bring the per-axis
-- algebra instances into scope for SelfAttention and ffnet.
25 | Transformer {allAlg = Cons} softargmax
26 |   = composePara (addResidual (SelfAttention softargmax)) (addResidual ffnet)
-- ffnet: position-wise feed-forward sub-block — a depth-2 MLP with relu
-- (wrapped via trivialParam) and no activation on the last layer, mapped
-- over the first (sequence) axis so it acts independently per position.
28 |     ffnet : Tensor [inputStructure, features] a -\-> Tensor [inputStructure, features] a
29 |     ffnet = paraMapFirstAxis (multiLayerPerceptron {a=a} {ieva=features} 2 (trivialParam relu) {lastLayerActivation=False})