分层转换器是更有效的语言模型

这是论文《分层转换器是更有效的语言模型》的 PyTorch 实现。

本文介绍了一种分层变压器架构,可以有效地处理长序列。变压器层的前半部分向下采样令牌,后半部分在相同分辨率的层之间使用直接跳过连接向上取样。这与用于视觉任务的 U-Net 有点相似。

他们尝试不同的向上采样和向下采样技术,并使用性能最佳的向上和向下采样技术构建模型,他们称之为沙漏模型。

为简单起@@

见,我们在这里实现了最简单的上采样和向下采样技术。稍后我们会考虑添加更复杂(性能更好)的实现。

这是沙漏模型的训练代码

28from typing import List
29
30import torch
31from torch import nn
32
33from labml_helpers.module import Module
34from labml_nn.transformers import MultiHeadAttention, TransformerLayer
35from labml_nn.transformers.feed_forward import FeedForward
36from labml_nn.transformers.utils import subsequent_mask

沙漏型号

该模型递归地在中间添加图层,同时通过缩减采样来缩短序列。由另一个沙漏模型处理的缩短序列夹在两个普通的变压器层之间。(变压器层具有自注意力层位置前馈层)。

39class HourGlass(Module):
49    def __init__(self, n_heads: int, d_model: int, dropout: float, d_ff: int, shortening_factors: List[int]):
57        super().__init__()

下采样前的变压器层

60        self.pre = TransformerLayer(d_model=d_model,
62                                    self_attn=MultiHeadAttention(n_heads, d_model, dropout),
64                                    feed_forward=FeedForward(d_model, d_ff, dropout),

66                                    dropout_prob=dropout)

自动回归掩码

68        self.mask = AutoregressiveMask()

缩短系数(或缩减采样率)

71        k = shortening_factors[0]

我们通过步骤将令牌向右移动,以确保信息不会因为缩减采样和上采样而从未来的代币泄漏到过去的代币上

76        self.shift_right = ShiftRight(k - 1)

缩短或缩减采样层。我们使用最简单的形式——平均汇集。该论文表明,基于注意力的向下采样效果最好,但我们尚未实施。

79        self.shortening = AvgPoolShortening(k)

如果没有更多的缩短(沙漏的中间)

82        if len(shortening_factors) == 1:

中心层是另一个变压器层

84            self.shortened = TransformerLayer(d_model=d_model,
85                                              self_attn=MultiHeadAttention(n_heads, d_model, dropout),
86                                              feed_forward=FeedForward(d_model, d_ff, dropout),
87                                              dropout_prob=dropout)

自回归遮罩

89            self.mask_short = AutoregressiveMask()
90            self.hour_glass = None
91        else:

递归地插入另一个沙漏模型

93            self.hour_glass = HourGlass(n_heads, d_model, dropout, d_ff, shortening_factors[1:])

向上采样图层。为了简单起见,我们使用天真的向上采样,本文显示基于注意力的采样效果更好。

97        self.up_sampling = NaiveUpSampling(k)

上采样后的最终变压器层

100        self.post = TransformerLayer(d_model=d_model,
101                                     self_attn=MultiHeadAttention(n_heads, d_model, dropout),
102                                     feed_forward=FeedForward(d_model, d_ff, dropout),
103                                     dropout_prob=dropout)
105    def forward(self, x: torch.Tensor):

初始变压器层

108        x = self.pre(x=x, mask=self.mask(x))

移位和缩短

111        x_short = self.shortening(self.shift_right(x))

如果我们在沙漏的中心

115        if self.hour_glass is None:

中心变压器层

118            x_short = self.shortened(x=x_short, mask=self.mask_short(x_short))

120        else:

122            x_short = self.hour_glass(x_short)

对缩短的序列进行向上采样并添加跳过连接

126        x = x + self.up_sampling(x, x_short)

最终的变压器层

129        x = self.post(x=x, mask=self.mask(x))

132        return x

向右移操作

这会将序列向右移动给定步数

135class ShiftRight(Module):
  • shift 是要移位的步数
142    def __init__(self, shift: int):
146        super().__init__()

不能为负数

148        assert shift >= 0

150        self.shift = shift
  • x 是形状张量[seq_len, ...]
152    def forward(self, x: torch.Tensor):

如果移位是返回原来的

157        if self.shift == 0:
158            return x

要追加到左边的零

160        prefix = x.new_zeros([self.shift, *x.shape[1:]])

连接零并截断右边

162        return torch.cat([prefix, x[:-self.shift]])

池平均缩短

这会按给定因子向下采样,并使用平均汇集

165class AvgPoolShortening(Module):
  • k 是缩短系数
172    def __init__(self, k: int):
176        super().__init__()

平均池层

178        self.pool = nn.AvgPool1d(k, ceil_mode=True)
  • x 形状不错[seq_len, batch_size, d_model]
180    def forward(self, x: torch.Tensor):

池化层接受形状[batch_size, d_model, seq_len] ,所以我们排列轴。

186        return self.pool(x.permute(1, 2, 0)).permute(2, 0, 1)

朴素的向上采样

这通过重复向上采样

189class NaiveUpSampling(Module):
  • k 是缩短系数
196    def __init__(self, k: int):
200        super().__init__()
201        self.k = k
  • x 是向下采样之前有嵌入的张量
  • x_short 是较高密度(待向上采样)表示的张量
203    def forward(self, x: torch.Tensor, x_short: torch.Tensor):

在序列维度上重复

209        expanded = torch.repeat_interleave(x_short, self.k, dim=0)

在最后截断多余的嵌入

211        expanded = expanded[:x.shape[0]]

214        return expanded

生成自动回归掩码

217class AutoregressiveMask(Module):
222    def __init__(self):
223        super().__init__()
224        self.mask = None
226    def forward(self, x: torch.Tensor):

如果我们尚未创建或大小已更改,请创建蒙版

228        if self.mask is None or self.mask.size(0) != len(x):

后续的掩码,将掩盖令牌以免看到未来的代币

230            self.mask = subsequent_mask(len(x)).to(x.device)

233        return self.mask

🚧 用于缩减采样的线性池

这将需要合并的连续令牌嵌入连接起来,并进行线性变换以将其映射到单个令牌嵌入的大小。

236class LinearPoolingShortening(Module):
244    def __init__(self):
245        super().__init__()
246        raise NotImplementedError

🚧 注意向下采样

其中是平均池化或线性池。

249class AttentionBasedShortening(Module):
261    def __init__(self):
262        super().__init__()
263        raise NotImplementedError

🚧 用于向上采样的线性投影

将@@

密集令牌嵌入进行线性投影,使其大小为

266class LinearUpSampling(Module):
273    def __init__(self):
274        super().__init__()
275        raise NotImplementedError

🚧 基于注意力的向上采样

在哪里

278class AttentionBasedUpSampling(Module):
290    def __init__(self):
291        super().__init__()
292        raise NotImplementedError