多头注意力 (MHA)

Open In Colab

这是论文《 Attention is All You Need 》中多头注意力的PyTorch教程/实现。该实现的灵感来自《带注释的 Transformer 》

这是使用基础 Transformer 和 MHA 进行 NLP 自回归的训练代码

这是一个训练简单 Transformer 的代码实现

24import math
25from typing import Optional, List
26
27import torch
28from torch import nn
29
30from labml import tracker

准备多头注意力

该部分执行线性变换,并将向量分割成给定数量的头以获得多头注意力。这用于查询向量。

33class PrepareForMultiHeadAttention(nn.Module):
44    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
45        super().__init__()

线性层用于线性变换

47        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)

注意力头数

49        self.heads = heads

每个头部中向量的维度数量

51        self.d_k = d_k
53    def forward(self, x: torch.Tensor):

输入的形状为[seq_len, batch_size, d_model][batch_size, d_model] 。我们对最后一维应用线性变换,并将其分为多个头。

57        head_shape = x.shape[:-1]

线性变换

60        x = self.linear(x)

将最后一个维度分成多个头部

63        x = x.view(*head_shape, self.heads, self.d_k)

输出具有形状[seq_len, batch_size, heads, d_k][batch_size, heads, d_model]

66        return x

多头注意力模块

这将计算给出的keyvaluequery 向量缩放后的多头注意力。

简单来说,它会找到与查询 (Query) 匹配的键 (key),并获取这些键 (Key) 的值 (Value) 。

它使用查询和键的点积作为衡量它们之间匹配程度的指标。在进行之前,点积会乘以。这样做是为了避免当较大时,大的点积值导致 Softmax 操作输出非常小的梯度。

Softmax 是沿序列(或时间)轴计算的。

69class MultiHeadAttention(nn.Module):
  • heads 是注意力头的数量。
  • d_model 是向量querykeyvalue 中的特征数量。
90    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):
96        super().__init__()

每个头部的特征数量

99        self.d_k = d_model // heads

注意力头数

101        self.heads = heads

这些将对多头注意力的向量querykeyvalue 进行转换。

104        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
105        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
106        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)

在键( Key )的时间维度上进行注意力 Softmaxkey

109        self.softmax = nn.Softmax(dim=1)

输出层

112        self.output = nn.Linear(d_model, d_model)

Dropout

114        self.dropout = nn.Dropout(dropout_prob)

Softmax 之前的缩放系数

116        self.scale = 1 / math.sqrt(self.d_k)

存储注意力信息,以便在需要时用于记录或其他计算。

119        self.attn = None

计算 Qurey 和 Key 之间的分数

这种方法可以同样适用于其他变体,如相对注意力。

121    def get_scores(self, query: torch.Tensor, key: torch.Tensor):

计算

129        return torch.einsum('ibhd,jbhd->ijbh', query, key)

mask 的形状为[seq_len_q, seq_len_k, batch_size] ,其中第一维是查询维度。如果查询维度等于,则会进行广播。

131    def prepare_mask(self, mask: torch.Tensor, query_shape: List[int], key_shape: List[int]):
137        assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
138        assert mask.shape[1] == key_shape[0]
139        assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]

所有的头部使用相同的掩码。

142        mask = mask.unsqueeze(-1)

生成的掩码形状为[seq_len_q, seq_len_k, batch_size, heads]

145        return mask

querykeyvalue 是存储查询向量集合的张量。它们的形状为[seq_len, batch_size, d_model]

mask 的形状为[seq_len, seq_len, batch_size]mask[i, j, b] 表示批次b ,在位置i 处查询是否有权访问位置j 处的键值对。

147    def forward(self, *,
148                query: torch.Tensor,
149                key: torch.Tensor,
150                value: torch.Tensor,
151                mask: Optional[torch.Tensor] = None):

querykeyvalue 的形状为[seq_len, batch_size, d_model]

163        seq_len, batch_size, _ = query.shape
164
165        if mask is not None:
166            mask = self.prepare_mask(mask, query.shape, key.shape)

为注意力计算准备向量querykeyvalue 它们的形状将变为[seq_len, batch_size, heads, d_k]

170        query = self.query(query)
171        key = self.key(key)
172        value = self.value(value)

计算注意力分数,这将得到一个形状为[seq_len, seq_len, batch_size, heads] 的张量。

176        scores = self.get_scores(query, key)

缩放分数

179        scores *= self.scale

应用掩码

182        if mask is not None:
183            scores = scores.masked_fill(mask == 0, float('-inf'))

对 Key 序列维度上的注意力进行操作,

187        attn = self.softmax(scores)

调试时保存注意力信息

190        tracker.debug('attn', attn)

应用 Dropout

193        attn = self.dropout(attn)

乘以数值

197        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)

为其他计算保存注意力信息

200        self.attn = attn.detach()

连接多个头

203        x = x.reshape(seq_len, batch_size, -1)

输出层

206        return self.output(x)