这是 PyTorch 中 Transfor mer-XL:超越固定长度上下文的专心语言模型的实现。
Transformer 的注意力跨度有限,等于并行训练序列的长度。所有这些位置都有固定的位置编码。Transformer XL 通过让每个位置关注过去预先计算的嵌入次数,从而延长了这种注意力跨度。例如,如果上下文长度为,它将保留前一批长度的所有层的嵌入并将其馈送到当前步骤。如果我们使用固定位置编码,这些预先计算的嵌入将与当前上下文具有相同的位置。它们引入了相对位置编码,其中位置编码是在注意力计算时引入的。
相对多头注意力的带注释的实现已经开始relative_mha.py
了。
35from typing import List, Optional
36
37import torch
38import torch.nn as nn
39
40from labml_helpers.module import Module
41from labml_nn.utils import clone_module_list
42from .relative_mha import RelativeMultiHeadAttention
43from ..feed_forward import FeedForward
46class TransformerXLLayer(Module):
52 def __init__(self, *,
53 d_model: int,
54 self_attn: RelativeMultiHeadAttention,
55 feed_forward: FeedForward,
56 dropout_prob: float):
63 super().__init__()
64 self.size = d_model
65 self.self_attn = self_attn
66 self.feed_forward = feed_forward
67 self.dropout = nn.Dropout(dropout_prob)
68 self.norm_self_attn = nn.LayerNorm([d_model])
69 self.norm_ff = nn.LayerNorm([d_model])
x
是令牌级特征形状向量的张量[seq_len, batch_size, d_model]
mem
是过去令牌级别特征形状向量的张量[mem_len, batch_size, d_model]
mask
是形状的矩阵[seq_len, mem_len + seq_len, batch_size]
或[seq_len, mem_len + seq_len, 1]
。mask[i, j]
如果 tokeni
可以在处看到令牌,则为 truej
。71 def forward(self, *,
72 x: torch.Tensor,
73 mem: Optional[torch.Tensor],
74 mask: torch.Tensor):
在进行自我注意之前对向量进行归一化
82 z = self.norm_self_attn(x)
如果有记忆
84 if mem is not None:
规范化它
86 mem = self.norm_self_attn(mem)
连接与z
88 m_z = torch.cat((mem, z), dim=0)
如果没有内存,则忽略
90 else:
91 m_z = z
注意
93 self_attn = self.self_attn(query=z, key=m_z, value=m_z, mask=mask)
添加关注结果
95 x = x + self.dropout(self_attn)
标准化以进行前馈
98 z = self.norm_ff(x)
通过前馈网络
100 ff = self.feed_forward(z)
将前馈结果添加回来
102 x = x + self.dropout(ff)
105 return x
108class TransformerXL(Module):
115 def __init__(self, layer: TransformerXLLayer, n_layers: int):
116 super().__init__()
制作变压器层的副本
118 self.layers = clone_module_list(layer, n_layers)
最终归一化层
120 self.norm = nn.LayerNorm([layer.size])
x
是嵌入形状向量的令牌的张量[seq_len, batch_size, d_model]
mem
是过去令牌级别的张量列表,每个层的形状[mem_len, batch_size, d_model]
向量特征mask
是掩码矩阵122 def forward(self, x: torch.Tensor, mem: List[torch.Tensor], mask: torch.Tensor):
用于存储令牌级特征向量的列表,这些向量将成为下一个连续批次的记忆。
131 new_mem = []
穿过每个变压器层
133 for i, layer in enumerate(self.layers):
添加到特征向量列表中
135 new_mem.append(x.detach())
记忆
137 m = mem[i] if mem else None
穿过变压器 XL 层
139 x = layer(x=x, mem=m, mask=mask)
最后,对向量进行归一化
141 return self.norm(x), new_mem