0 | module NN.Architectures.Transformer.Attention
  1 |
  2 | import Data.Tensor
  3 | import Data.Para
  4 | import NN.Architectures.Softargmax
  5 |
  6 | ||| Generalised form of attention
  7 | public export
  8 | crossAttention : {a : Type} -> Num a =>
  9 |   {inputStructure, crossStructure, features : Axis} ->
 10 |   (acif : NewAxisConsistent inputStructure [features]) =>
 11 |   (accf : NewAxisConsistent crossStructure [features]) =>
 12 |   (acci : NewAxisConsistent crossStructure [inputStructure]) =>
 13 |   TensorMonoid inputStructure.cont => TensorMonoid features.cont =>
 14 |   (allAlg : AllAlgebra [inputStructure, features] a) =>
 15 |   {default id causalMask : Tensor [crossStructure, inputStructure] a ->
 16 |                            Tensor [crossStructure, inputStructure] a} ->
 17 |   (softargmax : Tensor [inputStructure] a ->
 18 |                 Tensor [inputStructure] a) ->
 19 |   (q, v : Tensor [inputStructure, features] a) ->
 20 |   (k : Tensor [crossStructure, features] a) ->
 21 |   Tensor [crossStructure, features] a
 22 | crossAttention {allAlg=Cons {rest=xx}, causalMask} softargmax q v k =
 23 |   let attentionMatrix : Tensor [crossStructure, inputStructure] a
 24 |       attentionMatrix = q `matrixMatrixProduct` k
 25 |   in (softargmax <-$> (causalMask attentionMatrix)) `matMul` v
 26 |
 27 | ||| Self-attention is cross-attention where inputStructure = crossStructure
 28 | public export
 29 | selfAttention : {a : Type} -> Num a =>
 30 |   {inputStructure, features : Axis} ->
 31 |   NewAxisConsistent inputStructure [features] =>
 32 |   (TensorMonoid inputStructure.cont) =>
 33 |   (TensorMonoid features.cont) =>
 34 |   (allAlg : AllAlgebra [inputStructure, features] a) =>
 35 |   {default id causalMask : Tensor [inputStructure, inputStructure] a ->
 36 |                            Tensor [inputStructure, inputStructure] a} ->
 37 |   (softargmax : Tensor [inputStructure] a -> Tensor [inputStructure] a) ->
 38 |   (q, v, k : Tensor [inputStructure, features] a) ->
 39 |   Tensor [inputStructure, features] a
 40 | selfAttention {causalMask} = crossAttention {causalMask}
 41 |
 42 | ||| Data structure for holding parameters of self-attention
 43 | public export
 44 | record SelfAttentionParams (features : Axis) (a : Type) where
 45 |   constructor MkSAParams
 46 |   queryMatParam : Tensor [features, features] a
 47 |   valueMatParam : Tensor [features, features] a
 48 |   keyMatParam : Tensor [features, features] a
 49 |
 50 | ||| Forward pass of self-attention, from input
 51 | public export
 52 | SAImpl : {a : Type} -> Num a =>
 53 |   {inputStructure, features : Axis} ->
 54 |   (ac : NewAxisConsistent inputStructure [features]) =>
 55 |   (TensorMonoid inputStructure.cont) =>
 56 |   (TensorMonoid features.cont) =>
 57 |   (allAlg : AllAlgebra [inputStructure, features] a) =>
 58 |   {default id causalMask : Tensor [inputStructure, inputStructure] a ->
 59 |                            Tensor [inputStructure, inputStructure] a} ->
 60 |   (softargmax : Tensor [inputStructure] a -> Tensor [inputStructure] a) ->
 61 |   DPair (Tensor [inputStructure, features] a)
 62 |         (const (SelfAttentionParams features a)) ->
 63 |   Tensor [inputStructure, features] a
 64 | SAImpl {allAlg = Cons} {causalMask} softargmax (input ** (MkSAParams queryMat valueMat keyMat))
 65 |   = let queries = queryMat `matrixMatrixProduct` input
 66 |         keys = keyMat `matrixMatrixProduct` input
 67 |         values = valueMat `matrixMatrixProduct` input
 68 |     in selfAttention {causalMask} softargmax queries values keys
 69 |
 70 | ||| Self-attention as a parametric map
 71 | public export
 72 | SelfAttention : {a : Type} -> Num a =>
 73 |   {inputStructure, features : Axis} ->
 74 |   NewAxisConsistent inputStructure [features] =>
 75 |   (TensorMonoid inputStructure.cont) => (TensorMonoid features.cont) =>
 76 |   (allAlg : AllAlgebra [inputStructure, features] a) =>
 77 |   {default id causalMask : Tensor [inputStructure, inputStructure] a ->
 78 |                            Tensor [inputStructure, inputStructure] a} ->
 79 |   (softargmax : Tensor [inputStructure] a ->
 80 |                 Tensor [inputStructure] a) ->
 81 |   Tensor [inputStructure, features] a -\-> Tensor [inputStructure, features] a
 82 | SelfAttention {causalMask} softargmax = MkPara
 83 |   (const (SelfAttentionParams features a))
 84 |   (SAImpl {causalMask} softargmax)
 85 |
 86 | -- public export
 87 | -- SelfAttentionParams : (features : Nat) -> (a : Type) -> Type
 88 | -- SelfAttentionParams features a = CSelfAttentionParams (Vect features) a
 89 |
 90 | public export
 91 | causalMask : {a : Type} -> Num a =>
 92 |   {c : Axis} -> Exp a =>
 93 |   InterfaceOnPositions c.cont MOrd =>
 94 |   TensorMonoid c.cont =>
 95 |   Tensor [c, c] a -> Tensor [c, c] a
 96 | causalMask attentionMatrix =
 97 |   let contShape : c.cont.Shp
 98 |       contShape = shapeExt (shapeExt (GetT attentionMatrix))
 99 |   in maskedFill attentionMatrix (not <$> cTriBool contShape) minusInfinity