0 | module NN.Architectures.Transformer.Attention
4 | import NN.Architectures.Softargmax
||| Generalised cross-attention over arbitrary tensor axes.
|||
||| Builds an attention matrix from `q` and `k`, optionally masks it, row-wise
||| normalises it with the supplied `softargmax`, and contracts it with `v`.
|||
||| NOTE(review): `q` and `v` share the `inputStructure` axis while `k` carries
||| `crossStructure`, and the result is `[crossStructure, features]`. That is
||| the transpose of the common convention (where the query axis determines the
||| output shape) — confirm the intended roles against the callers.
crossAttention : {a : Type} -> Num a =>
  {inputStructure, crossStructure, features : Axis} ->
  -- Evidence that a fresh [features] axis can be adjoined to each structure,
  -- and that the two structure axes can coexist in one tensor.
  (acif : NewAxisConsistent inputStructure [features]) =>
  (accf : NewAxisConsistent crossStructure [features]) =>
  (acci : NewAxisConsistent crossStructure [inputStructure]) =>
  TensorMonoid inputStructure.cont => TensorMonoid features.cont =>
  (allAlg : AllAlgebra [inputStructure, features] a) =>
  -- Optional mask applied to the raw attention matrix before normalisation;
  -- defaults to `id` (no masking).
  {default id causalMask : Tensor [crossStructure, inputStructure] a ->
                           Tensor [crossStructure, inputStructure] a} ->
  -- Normalisation applied along the `inputStructure` axis (one call per row
  -- of the attention matrix, via `<-$>` below).
  (softargmax : Tensor [inputStructure] a ->
                Tensor [inputStructure] a) ->
  (q, v : Tensor [inputStructure, features] a) ->
  (k : Tensor [crossStructure, features] a) ->
  Tensor [crossStructure, features] a
-- Matching `allAlg = Cons {rest=xx}` unpacks the algebra instances that
-- `matrixMatrixProduct` / `matMul` need in scope.
crossAttention {allAlg=Cons {rest=xx}, causalMask} softargmax q v k =
  let attentionMatrix : Tensor [crossStructure, inputStructure] a
      -- presumably contracts the shared `features` axis of q and k
      -- (q · kᵀ up to the library's axis convention) — TODO confirm.
      attentionMatrix = q `matrixMatrixProduct` k
  in (softargmax <-$> (causalMask attentionMatrix)) `matMul` v
||| Self-attention: the special case of `crossAttention` where queries, values
||| and keys all live over the same `inputStructure` axis, so the attention
||| matrix is square and the output shape matches the input shape.
selfAttention : {a : Type} -> Num a =>
  {inputStructure, features : Axis} ->
  NewAxisConsistent inputStructure [features] =>
  (TensorMonoid inputStructure.cont) =>
  (TensorMonoid features.cont) =>
  (allAlg : AllAlgebra [inputStructure, features] a) =>
  -- Optional mask over the square attention matrix; defaults to `id`.
  {default id causalMask : Tensor [inputStructure, inputStructure] a ->
                           Tensor [inputStructure, inputStructure] a} ->
  (softargmax : Tensor [inputStructure] a -> Tensor [inputStructure] a) ->
  (q, v, k : Tensor [inputStructure, features] a) ->
  Tensor [inputStructure, features] a
-- Delegates directly to `crossAttention`, instantiating its `crossStructure`
-- implicit to `inputStructure`; only `causalMask` needs forwarding explicitly.
selfAttention {causalMask} = crossAttention {causalMask}
||| Learnable parameters of a single self-attention layer: one square
||| projection matrix over the `features` axis for each of the query, value
||| and key projections.
record SelfAttentionParams (features : Axis) (a : Type) where
  constructor MkSAParams
  queryMatParam : Tensor [features, features] a  -- query projection (W_Q)
  valueMatParam : Tensor [features, features] a  -- value projection (W_V)
  keyMatParam   : Tensor [features, features] a  -- key projection   (W_K)
||| Forward pass of a parameterised self-attention layer.
|||
||| Takes the input tensor paired (via `DPair` whose second component is a
||| `const` family, i.e. the parameter space does not depend on the input)
||| with the layer's projection parameters, projects the input into queries,
||| keys and values, and runs `selfAttention` on the result.
SAImpl : {a : Type} -> Num a =>
  {inputStructure, features : Axis} ->
  (ac : NewAxisConsistent inputStructure [features]) =>
  (TensorMonoid inputStructure.cont) =>
  (TensorMonoid features.cont) =>
  (allAlg : AllAlgebra [inputStructure, features] a) =>
  -- Forwarded unchanged to `selfAttention`; defaults to `id` (no masking).
  {default id causalMask : Tensor [inputStructure, inputStructure] a ->
                           Tensor [inputStructure, inputStructure] a} ->
  (softargmax : Tensor [inputStructure] a -> Tensor [inputStructure] a) ->
  DPair (Tensor [inputStructure, features] a)
        (const (SelfAttentionParams features a)) ->
  Tensor [inputStructure, features] a
-- Matching `allAlg = Cons` exposes the algebra evidence that the three
-- `matrixMatrixProduct` calls below require.
SAImpl {allAlg = Cons} {causalMask} softargmax
       (input ** (MkSAParams queryMat valueMat keyMat))
  = let queries = queryMat `matrixMatrixProduct` input
        keys    = keyMat   `matrixMatrixProduct` input
        values  = valueMat `matrixMatrixProduct` input
    in selfAttention {causalMask} softargmax queries values keys
||| Self-attention packaged as a parametric morphism (`MkPara`): the parameter
||| family is constantly `SelfAttentionParams features a` (independent of the
||| input point, hence `const`) and the implementation is `SAImpl`.
|||
||| NOTE(review): the arrow `-\->` in the return type is taken to be the
||| library's parametric-morphism arrow constructed by `MkPara`; it may have
||| been mangled in extraction — confirm against the Para module.
SelfAttention : {a : Type} -> Num a =>
  {inputStructure, features : Axis} ->
  NewAxisConsistent inputStructure [features] =>
  (TensorMonoid inputStructure.cont) => (TensorMonoid features.cont) =>
  (allAlg : AllAlgebra [inputStructure, features] a) =>
  -- Optional mask over the square attention matrix; defaults to `id`.
  {default id causalMask : Tensor [inputStructure, inputStructure] a ->
                           Tensor [inputStructure, inputStructure] a} ->
  (softargmax : Tensor [inputStructure] a ->
                Tensor [inputStructure] a) ->
  Tensor [inputStructure, features] a -\-> Tensor [inputStructure, features] a
SelfAttention {causalMask} softargmax = MkPara
  (const (SelfAttentionParams features a))
  (SAImpl {causalMask} softargmax)
||| Causal (autoregressive) mask for a square attention matrix.
|||
||| Entries at positions excluded by the mask are overwritten with
||| `minusInfinity`, so they contribute nothing after a subsequent
||| softargmax-style normalisation (hence the `Exp a` constraint).
||| Position comparability comes from `InterfaceOnPositions c.cont MOrd`.
causalMask : {a : Type} -> Num a =>
  {c : Axis} -> Exp a =>
  InterfaceOnPositions c.cont MOrd =>
  TensorMonoid c.cont =>
  Tensor [c, c] a -> Tensor [c, c] a
causalMask attentionMatrix =
  let contShape : c.cont.Shp
      -- Recover the container shape from the tensor itself (two `shapeExt`
      -- applications peel the two `c` axes off `GetT attentionMatrix`).
      contShape = shapeExt (shapeExt (GetT attentionMatrix))
      -- presumably `cTriBool` builds a triangular Boolean tensor and
      -- `not <$>` selects its complement as the fill region — TODO confirm
      -- which triangle (strict/inclusive, lower/upper) `cTriBool` produces.
  in maskedFill attentionMatrix (not <$> cTriBool contShape) minusInfinity