# 多头注意力 (MHA)

import math
from typing import Optional, List

import torch
from torch import nn

from labml import tracker

## 为多头注意力做准备

class PrepareForMultiHeadAttention(nn.Module):
    """Apply a linear projection and split the result into attention heads.

    The last dimension of the input is projected from ``d_model`` to
    ``heads * d_k`` features and then reshaped into two trailing dimensions
    ``(heads, d_k)``. All leading dimensions are left untouched.
    """

    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
        """
        Args:
            d_model: size of the last dimension of the input.
            heads: number of heads to split the projection into.
            d_k: number of features in each head.
            bias: whether the linear projection has a bias term.
        """
        super().__init__()
        # Single linear layer producing the features of all heads at once.
        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
        # Number of heads.
        self.heads = heads
        # Number of features per head.
        self.d_k = d_k

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` and split its last dimension into heads.

        Args:
            x: tensor of shape ``[..., d_model]``.

        Returns:
            Tensor of shape ``[..., heads, d_k]``.
        """
        # Everything except the feature dimension is carried over as-is.
        head_shape = x.shape[:-1]

        # Linear transform of the last dimension.
        x = self.linear(x)

        # Split the last dimension into (heads, d_k).
        return x.view(*head_shape, self.heads, self.d_k)

## 多头注意力模块

Softmax 是沿序列（或时间）的轴计算的。

class MultiHeadAttention(nn.Module):
    """Multi-head attention.

    Queries, keys and values are each projected and split into ``heads``
    heads. Attention weights are the softmax of the scaled query-key dot
    products, computed along the sequence (time) axis of the keys. The
    weighted values are then merged across heads and passed through an
    output linear layer.
    """

    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):
        """
        Args:
            heads: number of attention heads.
            d_model: number of features in the ``query``, ``key`` and
                ``value`` vectors.
            dropout_prob: dropout probability applied to the attention
                weights.
            bias: whether the query and key projections have a bias term.
        """
        super().__init__()

        # Number of features per head.
        self.d_k = d_model // heads
        # Number of heads.
        self.heads = heads

        # Projections that prepare query, key and value for multi-head attention.
        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        # `forward` calls `self.key`, so the key projection must exist; it
        # mirrors the query projection.
        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        # NOTE: the value projection is always built with a bias, regardless of `bias`.
        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)

        # Softmax for attention along the time dimension of `key`
        # (scores are laid out as [seq_len_q, seq_len_k, batch_size, heads]).
        self.softmax = nn.Softmax(dim=1)

        # Output layer.
        self.output = nn.Linear(d_model, d_model)
        # Dropout on the attention weights.
        self.dropout = nn.Dropout(dropout_prob)
        # Scaling factor applied before the softmax.
        self.scale = 1 / math.sqrt(self.d_k)

        # Attention weights of the most recent forward pass (detached), kept
        # so they can be inspected from outside the module.
        self.attn = None

    def get_scores(self, query: torch.Tensor, key: torch.Tensor) -> torch.Tensor:
        """Calculate scores between queries and keys.

        Args:
            query: tensor indexed as ``[i, b, h, d]``.
            key: tensor indexed as ``[j, b, h, d]``.

        Returns:
            Dot products over ``d``, indexed as ``[i, j, b, h]``.
        """
        return torch.einsum('ibhd,jbhd->ijbh', query, key)

    def prepare_mask(self, mask: torch.Tensor, query_shape: List[int], key_shape: List[int]) -> torch.Tensor:
        """Validate the mask and add a trailing dimension for the heads.

        ``mask`` has shape ``[seq_len_q, seq_len_k, batch_size]``, where the
        first dimension is the query dimension. If the query dimension (or
        the batch dimension) is ``1`` it will be broadcast.

        Args:
            mask: the attention mask.
            query_shape: shape of the query, ``[seq_len_q, batch_size, d_model]``.
            key_shape: shape of the key, ``[seq_len_k, batch_size, d_model]``.

        Returns:
            The mask with shape ``[seq_len_q, seq_len_k, batch_size, 1]``
            (with any size-1 dimensions of the input preserved).

        Raises:
            AssertionError: if the mask dimensions do not match the query
                and key shapes.
        """
        assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
        # The key dimension must match the key sequence length; checking it
        # here fails early instead of inside `masked_fill`.
        assert mask.shape[1] == key_shape[0]
        assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]

        # Same mask applied to all heads.
        mask = mask.unsqueeze(-1)

        return mask

    def forward(self, *,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute multi-head attention.

        Args:
            query: tensor of shape ``[seq_len, batch_size, d_model]``.
            key: tensor of shape ``[seq_len, batch_size, d_model]``.
            value: tensor of shape ``[seq_len, batch_size, d_model]``.
            mask: optional tensor of shape ``[seq_len, seq_len, batch_size]``;
                ``mask[i, j, b]`` indicates whether, for batch ``b``, the
                query at position ``i`` has access to the key-value at
                position ``j``. Positions where the mask equals ``0`` are
                excluded from attention.

        Returns:
            The output of the final linear layer applied to the merged heads,
            with leading dimensions ``[seq_len, batch_size]``.
        """
        seq_len, batch_size, _ = query.shape

        if mask is not None:
            mask = self.prepare_mask(mask, query.shape, key.shape)

        # Project and split into heads: [seq_len, batch_size, heads, d_k].
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)

        # Attention scores, indexed [seq_len_q, seq_len_k, batch_size, heads].
        scores = self.get_scores(query, key)

        # Scale the scores before the softmax.
        scores *= self.scale

        # Masked-out positions get -inf so the softmax gives them zero weight.
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # Softmax along the key sequence dimension.
        attn = self.softmax(scores)

        # Log the attention weights when debugging.
        tracker.debug('attn', attn)

        attn = self.dropout(attn)

        # Weighted sum of the values.
        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)

        # Keep the (post-dropout) attention weights for other calculations.
        self.attn = attn.detach()

        # Concatenate the heads.
        x = x.reshape(seq_len, batch_size, -1)

        return self.output(x)