这是论文《 Attention is All You Need 》中多头注意力的PyTorch教程/实现。该实现的灵感来自《带注释的 Transformer 》。
这是使用基础 Transformer 和 MHA 进行 NLP 自回归的训练代码。
这是一个训练简单 Transformer 的代码实现。
24import math
25from typing import Optional, List
26
27import torch
28from torch import nn
29
30from labml import tracker
33class PrepareForMultiHeadAttention(nn.Module):
44 def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
45 super().__init__()
线性层用于线性变换
47 self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
注意力头数
49 self.heads = heads
每个头部中向量的维度数量
51 self.d_k = d_k
53 def forward(self, x: torch.Tensor):
输入的形状为[seq_len, batch_size, d_model]
或[batch_size, d_model]
。我们对最后一维应用线性变换,并将其分为多个头。
57 head_shape = x.shape[:-1]
线性变换
60 x = self.linear(x)
将最后一个维度分成多个头部
63 x = x.view(*head_shape, self.heads, self.d_k)
输出具有形状[seq_len, batch_size, heads, d_k]
或[batch_size, heads, d_model]
66 return x
这将计算给出的key
、value
和query
向量缩放后的多头注意力。
简单来说,它会找到与查询 (Query) 匹配的键 (key),并获取这些键 (Key) 的值 (Value) 。
它使用查询和键的点积作为衡量它们之间匹配程度的指标。在进行之前,点积会乘以。这样做是为了避免当较大时,大的点积值导致 Softmax 操作输出非常小的梯度。
Softmax 是沿序列(或时间)轴计算的。
69class MultiHeadAttention(nn.Module):
heads
是注意力头的数量。d_model
是向量query
、key
和value
中的特征数量。90 def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):
96 super().__init__()
每个头部的特征数量
99 self.d_k = d_model // heads
注意力头数
101 self.heads = heads
这些将对多头注意力的向量query
、key
和value
进行转换。
104 self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
105 self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
106 self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)
在键( Key )的时间维度上进行注意力 Softmaxkey
109 self.softmax = nn.Softmax(dim=1)
输出层
112 self.output = nn.Linear(d_model, d_model)
Dropout
114 self.dropout = nn.Dropout(dropout_prob)
Softmax 之前的缩放系数
116 self.scale = 1 / math.sqrt(self.d_k)
存储注意力信息,以便在需要时用于记录或其他计算。
119 self.attn = None
121 def get_scores(self, query: torch.Tensor, key: torch.Tensor):
计算或
129 return torch.einsum('ibhd,jbhd->ijbh', query, key)
mask
的形状为[seq_len_q, seq_len_k, batch_size]
,其中第一维是查询维度。如果查询维度等于,则会进行广播。
131 def prepare_mask(self, mask: torch.Tensor, query_shape: List[int], key_shape: List[int]):
137 assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
138 assert mask.shape[1] == key_shape[0]
139 assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]
所有的头部使用相同的掩码。
142 mask = mask.unsqueeze(-1)
生成的掩码形状为[seq_len_q, seq_len_k, batch_size, heads]
145 return mask
query
、key
和value
是存储查询、键和值向量集合的张量。它们的形状为[seq_len, batch_size, d_model]
。
mask
的形状为[seq_len, seq_len, batch_size]
,mask[i, j, b]
表示批次b
,在位置i
处查询是否有权访问位置j
处的键值对。
147 def forward(self, *,
148 query: torch.Tensor,
149 key: torch.Tensor,
150 value: torch.Tensor,
151 mask: Optional[torch.Tensor] = None):
query
,key
和value
的形状为[seq_len, batch_size, d_model]
163 seq_len, batch_size, _ = query.shape
164
165 if mask is not None:
166 mask = self.prepare_mask(mask, query.shape, key.shape)
为注意力计算准备向量query
,key
并value
它们的形状将变为[seq_len, batch_size, heads, d_k]
。
170 query = self.query(query)
171 key = self.key(key)
172 value = self.value(value)
计算注意力分数,这将得到一个形状为[seq_len, seq_len, batch_size, heads]
的张量。
176 scores = self.get_scores(query, key)
缩放分数
179 scores *= self.scale
应用掩码
182 if mask is not None:
183 scores = scores.masked_fill(mask == 0, float('-inf'))
对 Key 序列维度上的注意力进行操作,
187 attn = self.softmax(scores)
调试时保存注意力信息
190 tracker.debug('attn', attn)
应用 Dropout
193 attn = self.dropout(attn)
乘以数值
197 x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
为其他计算保存注意力信息
200 self.attn = attn.detach()
连接多个头
203 x = x.reshape(seq_len, batch_size, -1)
输出层
206 return self.output(x)