多头注意力 (MHA)

Open In Colab

这是 P yTorch 中论文 “注意力就是你所需要的” 多头注意力的教程/实现。该实现的灵感来自带注释的变形金刚

以下是使用带有 MHA 的基本转换器进行 NLP 自动回归的训练代码

这是一个训练简单变压器的实验实现

24import math
25from typing import Optional, List
26
27import torch
28from torch import nn
29
30from labml import tracker

为多头注意做好准备

该模块进行线性变换,并将向量拆分为给定数量的头部,以获得多头注意。这用于转换查询向量。

33class PrepareForMultiHeadAttention(nn.Module):
44    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
45        super().__init__()

线性变换的线性层

47        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)

头数

49        self.heads = heads

每个头部中以向量表示的维度数

51        self.d_k = d_k
53    def forward(self, x: torch.Tensor):

输入的形状[seq_len, batch_size, d_model][batch_size, d_model] 。我们将线性变换应用于最后一个维度,然后将其拆分为头部。

57        head_shape = x.shape[:-1]

线性变换

60        x = self.linear(x)

将最后一个维度拆分成头部

63        x = x.view(*head_shape, self.heads, self.d_k)

输出具有形状[seq_len, batch_size, heads, d_k][batch_size, heads, d_model]

66        return x

多头注意模块

这将计算给定keyvalue 向量的缩放多头注意query 力。

简单来说,它会找到与查询匹配的键,并获取这些键的值。

它使用查询和键的点积作为它们匹配程度的指标。在服用点产品之前,先按比例缩放。这样做是为了避免较大的点积值导致 softmax 在较大时给出非常小的梯度。

Softmax 是沿序列(或时间)的轴计算的。

69class MultiHeadAttention(nn.Module):
  • heads 是头的数量。
  • d_modelquerykeyvalue 向量中的要素数。
90    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):
96        super().__init__()

每头特征数

99        self.d_k = d_model // heads

头数

101        self.heads = heads

这些变换了多头注意力的querykeyvalue 向量。

104        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
105        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
106        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)

Softmax 在时间维度上引起人们的注意key

109        self.softmax = nn.Softmax(dim=1)

输出层

112        self.output = nn.Linear(d_model, d_model)

辍学

114        self.dropout = nn.Dropout(dropout_prob)

softmax 之前的缩放系数

116        self.scale = 1 / math.sqrt(self.d_k)

我们存储注意事项,以便在需要时将其用于日志记录或进行其他计算

119        self.attn = None

计算查询和键之间的分数

对于其他变体,例如相对注意力,可以覆盖此方法。

121    def get_scores(self, query: torch.Tensor, key: torch.Tensor):

计算

129        return torch.einsum('ibhd,jbhd->ijbh', query, key)

mask 有形状[seq_len_q, seq_len_k, batch_size] ,其中第一个维度是查询维度。如果查询维度等于它将被广播。

131    def prepare_mask(self, mask: torch.Tensor, query_shape: List[int], key_shape: List[int]):
137        assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
138        assert mask.shape[1] == key_shape[0]
139        assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]

所有头部都使用相同的面具。

142        mask = mask.unsqueeze(-1)

生成的遮罩有形状[seq_len_q, seq_len_k, batch_size, heads]

145        return mask

query keyvalue 是存储查询向量集合的张量。它们有形状[seq_len, batch_size, d_model]

mask 有形状[seq_len, seq_len, batch_size]mask[i, j, b] 指示是否为批量查询b ,位置处的查询i 有权访问位置处的键值j

147    def forward(self, *,
148                query: torch.Tensor,
149                key: torch.Tensor,
150                value: torch.Tensor,
151                mask: Optional[torch.Tensor] = None):

querykey 并且value 有形状[seq_len, batch_size, d_model]

163        seq_len, batch_size, _ = query.shape
164
165        if mask is not None:
166            mask = self.prepare_mask(mask, query.shape, key.shape)

准备querykeyvalue 进行注意力计算。然后这些就会有形状[seq_len, batch_size, heads, d_k]

170        query = self.query(query)
171        key = self.key(key)
172        value = self.value(value)

计算注意力分数。这给出了形状的张量[seq_len, seq_len, batch_size, heads]

176        scores = self.get_scores(query, key)

音阶分数

179        scores *= self.scale

涂抹面膜

182        if mask is not None:
183            scores = scores.masked_fill(mask == 0, float('-inf'))

关注按键序列维度

187        attn = self.softmax(scores)

调试时省去注意力

190        tracker.debug('attn', attn)

申请退学

193        attn = self.dropout(attn)

乘以值

197        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)

保存任何其他计算的注意力

200        self.attn = attn.detach()

连接多个头

203        x = x.reshape(seq_len, batch_size, -1)

输出层

206        return self.output(x)