这是论文《分层转换器是更有效的语言模型》的 PyTorch 实现。
本文介绍了一种分层变压器架构,可以有效地处理长序列。变压器层的前半部分向下采样令牌,后半部分在相同分辨率的层之间使用直接跳过连接向上取样。这与用于视觉任务的 U-Net 有点相似。
他们尝试不同的向上采样和向下采样技术,并使用性能最佳的向上和向下采样技术构建模型,他们称之为沙漏模型。
为简单起@@见,我们在这里实现了最简单的上采样和向下采样技术。稍后我们会考虑添加更复杂(性能更好)的实现。
这是沙漏模型的训练代码。
28from typing import List
29
30import torch
31from torch import nn
32
33from labml_helpers.module import Module
34from labml_nn.transformers import MultiHeadAttention, TransformerLayer
35from labml_nn.transformers.feed_forward import FeedForward
36from labml_nn.transformers.utils import subsequent_mask
39class HourGlass(Module):
n_heads
是多头注意层中的头部数量d_model
是令牌嵌入的大小dropout
是辍学概率d_ff
是位置前馈层中隐藏层的维度shortening_factors
是缩短因子清单49 def __init__(self, n_heads: int, d_model: int, dropout: float, d_ff: int, shortening_factors: List[int]):
57 super().__init__()
下采样前的变压器层
60 self.pre = TransformerLayer(d_model=d_model,
66 dropout_prob=dropout)
自动回归掩码
68 self.mask = AutoregressiveMask()
缩短系数(或缩减采样率)
71 k = shortening_factors[0]
我们通过步骤将令牌向右移动,以确保信息不会因为缩减采样和上采样而从未来的代币泄漏到过去的代币上
76 self.shift_right = ShiftRight(k - 1)
缩短或缩减采样层。我们使用最简单的形式——平均汇集。该论文表明,基于注意力的向下采样效果最好,但我们尚未实施。
79 self.shortening = AvgPoolShortening(k)
如果没有更多的缩短(沙漏的中间)
82 if len(shortening_factors) == 1:
中心层是另一个变压器层
84 self.shortened = TransformerLayer(d_model=d_model,
85 self_attn=MultiHeadAttention(n_heads, d_model, dropout),
86 feed_forward=FeedForward(d_model, d_ff, dropout),
87 dropout_prob=dropout)
自回归遮罩
89 self.mask_short = AutoregressiveMask()
90 self.hour_glass = None
91 else:
递归地插入另一个沙漏模型
93 self.hour_glass = HourGlass(n_heads, d_model, dropout, d_ff, shortening_factors[1:])
向上采样图层。为了简单起见,我们使用天真的向上采样,本文显示基于注意力的采样效果更好。
97 self.up_sampling = NaiveUpSampling(k)
上采样后的最终变压器层
100 self.post = TransformerLayer(d_model=d_model,
101 self_attn=MultiHeadAttention(n_heads, d_model, dropout),
102 feed_forward=FeedForward(d_model, d_ff, dropout),
103 dropout_prob=dropout)
105 def forward(self, x: torch.Tensor):
初始变压器层
108 x = self.pre(x=x, mask=self.mask(x))
移位和缩短
111 x_short = self.shortening(self.shift_right(x))
如果我们在沙漏的中心
115 if self.hour_glass is None:
中心变压器层
118 x_short = self.shortened(x=x_short, mask=self.mask_short(x_short))
120 else:
122 x_short = self.hour_glass(x_short)
对缩短的序列进行向上采样并添加跳过连接
126 x = x + self.up_sampling(x, x_short)
最终的变压器层
129 x = self.post(x=x, mask=self.mask(x))
132 return x
135class ShiftRight(Module):
shift
是要移位的步数142 def __init__(self, shift: int):
146 super().__init__()
不能为负数
148 assert shift >= 0
150 self.shift = shift
x
是形状张量[seq_len, ...]
152 def forward(self, x: torch.Tensor):
如果移位是返回原来的
157 if self.shift == 0:
158 return x
要追加到左边的零
160 prefix = x.new_zeros([self.shift, *x.shape[1:]])
连接零并截断右边
162 return torch.cat([prefix, x[:-self.shift]])
165class AvgPoolShortening(Module):
k
是缩短系数172 def __init__(self, k: int):
176 super().__init__()
平均池层
178 self.pool = nn.AvgPool1d(k, ceil_mode=True)
x
形状不错[seq_len, batch_size, d_model]
180 def forward(self, x: torch.Tensor):
池化层接受形状[batch_size, d_model, seq_len]
,所以我们排列轴。
186 return self.pool(x.permute(1, 2, 0)).permute(2, 0, 1)
189class NaiveUpSampling(Module):
k
是缩短系数196 def __init__(self, k: int):
200 super().__init__()
201 self.k = k
x
是向下采样之前有嵌入的张量x_short
是较高密度(待向上采样)表示的张量203 def forward(self, x: torch.Tensor, x_short: torch.Tensor):
在序列维度上重复
209 expanded = torch.repeat_interleave(x_short, self.k, dim=0)
在最后截断多余的嵌入
211 expanded = expanded[:x.shape[0]]
214 return expanded
217class AutoregressiveMask(Module):
222 def __init__(self):
223 super().__init__()
224 self.mask = None
226 def forward(self, x: torch.Tensor):
如果我们尚未创建或大小已更改,请创建蒙版
228 if self.mask is None or self.mask.size(0) != len(x):
233 return self.mask
236class LinearPoolingShortening(Module):
244 def __init__(self):
245 super().__init__()
246 raise NotImplementedError
249class AttentionBasedShortening(Module):
261 def __init__(self):
262 super().__init__()
263 raise NotImplementedError
266class LinearUpSampling(Module):
273 def __init__(self):
274 super().__init__()
275 raise NotImplementedError
278class AttentionBasedUpSampling(Module):
290 def __init__(self):
291 super().__init__()
292 raise NotImplementedError