反馈变压器

这是 PyTorch 对《使用反馈存储器访问序列变压器中的更高层次表示》一文的 PyT orch 实现。

普通的变压器会并行处理代币。每个变压器层都注意前一层的输出。反馈变压器注意前面步骤中所有层的输出。因此,这会增加重复性,我们需要逐个代币进行处理。这会显著减慢训练速度(大约 5 到 10 倍,具体取决于序列长度)。但是,在预测反馈变换器时,速度更快,因为如果你缓存了内存向量,你可以预测下一个标记。

为了加快训练速度,本文讨论了从短序列长度开始并逐渐增加序列长度的问题。他们还讨论了使用预训练的并行变压器作为起点。

原始反馈变压器不保留所有层的输出。相反,它保留所有图层输出的加权总和。这减少了预测期间用于缓存的内存。这个文件的前半部分实现了这一点。

更新后的反馈变压器共享权重用于计算各层之间的密钥和值。然后,我们只计算每个步骤的键和值一次,并将其缓存。这个文件的后半部分实现了这一点。我们实现了一个自定义 PyTorch 函数来提高性能。

这是训练代码和一本用于在 Tiny Shakespeare 数据集上训练反馈转换器的笔记本。

Open In Colab

42import math
43from typing import Optional
44
45import torch
46from torch import nn
47
48from labml_helpers.module import Module
49from labml_nn.transformers.feed_forward import FeedForward
50from labml_nn.transformers.mha import PrepareForMultiHeadAttention
51from labml_nn.utils import clone_module_list

反馈关注

本模块计算重复注意力,类似于原版《变形金刚》论文中的注意力。

54class FeedbackAttention(Module):
  • “heads” 是注意头的数量
  • d_model 是变压器中的特征数
  • dropout_prob 注意力丢失的概率是多少
  • is_kv_precomputed 是键值、值张量是否已经计算过
65    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, *,
66                 is_kv_precomputed: bool = False):
74        super().__init__()

每头特征数

77        self.d_k = d_model // heads

79        self.heads = heads

这些改变了query 多头注意力。

82        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)

这些改变了多value 头注意力的key 和。

84        if not is_kv_precomputed:
85            self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
86            self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)

键和值已计算

88        else:
89            self.key = None
90            self.value = None

输出层

93        self.output = nn.Linear(d_model, d_model)

辍学

95        self.dropout = nn.Dropout(dropout_prob)

softmax 之前的缩放系数

97        self.scale = 1 / math.sqrt(self.d_k)

Softmax 在时间维度上引起人们的注意key

100        self.softmax = nn.Softmax(dim=0)

相对位置的数量

103        self.P = 2 ** 12

键相对于查询的相对位置嵌入。

106        self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P, heads, self.d_k)), requires_grad=True)

键相对于查询的相对位置嵌入偏差。

108        self.key_pos_bias = nn.Parameter(torch.zeros((self.P, heads)), requires_grad=True)

查询的位置嵌入与查询的位置无关

110        self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)

我们存储注意事项,以便在需要时将其用于日志记录或进行其他计算

113        self.attn = None

获得注意力分数

我们使用相对位置编码来表示注意力,类似于 Transf ormer-XL 纸张的相对多头注意力

从当前步骤的查询到键入步骤(相对于当前步骤)的注意是,

其中,是原始嵌入的线性变换是位置编码的线性变换

我们将术语替换为

115    def get_scores(self, query: torch.Tensor, key: torch.Tensor):

143        key_pos_emb = self.key_pos_embeddings[-key.shape[0]:]

145        query_pos_bias = self.query_pos_bias[None, :, :]

147        key_pos_bias = self.key_pos_bias[-key.shape[0]:]

150        ac = torch.einsum('bhd,jbhd->jbh', query + query_pos_bias, key)

152        bd = torch.einsum('bhd,jhd->jbh', query, key_pos_emb) + key_pos_bias[:, None, :]

155        return ac + bd
  • query 有形状[batch_size, d_model]
  • key 而且value 有形状[seq_len, batch_size, d_model]
157    def forward(self, *,
158                query: torch.Tensor,
159                key: torch.Tensor,
160                value: torch.Tensor):

做好准备query keyvalue 进行注意力计算key ,然后value 就会有形状[seq_len, batch_size, heads, d_k] 而且query 会有形状[batch_size, heads, d_k]

169        query = self.query(query)
170        if self.key:
171            key = self.key(key)
172        if self.value:
173            value = self.value(value)

计算注意力分数。结果为形状的张量[seq_len, batch_size, heads]

177        scores = self.get_scores(query, key)

音阶分数

180        scores *= self.scale

软最大

183        attn = self.softmax(scores)

申请退学

186        attn = self.dropout(attn)

乘以值

189        x = torch.einsum("jbh,jbhd->bhd", attn, value)

连接多个头

192        x = x.reshape(x.shape[0], -1)

输出层

195        return self.output(x)

反馈变压器层

这在反馈变压器中实现了单个变压器层。

198class FeedbackTransformerLayer(Module):
  • d_model 是变压器中的特征数
  • attn 是反馈关注模块
  • feed_forward 是位置前馈层
  • dropout_prob 是注意和前馈后辍学层的丢失概率
205    def __init__(self, *,
206                 d_model: int,
207                 attn: FeedbackAttention,
208                 feed_forward: FeedForward,
209                 dropout_prob: float):
216        super().__init__()

变压器尺寸

218        self.size = d_model

220        self.attn = attn
221        self.feed_forward = feed_forward
222        self.dropout = nn.Dropout(dropout_prob)

归一化层

225        self.norm_self_attn = nn.LayerNorm([d_model])
226        self.norm_ff = nn.LayerNorm([d_model])
228    def forward(self, *,
229                x: torch.Tensor,
230                key: Optional[torch.Tensor],
231                value: Optional[torch.Tensor]):

如果有记忆

233        if key is not None:

在进行自我注意之前对向量进行归一化

235            z = self.norm_self_attn(x)

通过自我关注,即关键和价值来自自我

237            self_attn = self.attn(query=z, key=key, value=value)

添加自我关注的结果

239            x = x + self.dropout(self_attn)

标准化以进行前馈

242        z = self.norm_ff(x)

通过前馈网络

244        ff = self.feed_forward(z)

将前馈结果添加回来

246        x = x + self.dropout(ff)

249        return x

反馈变压器模块

252class FeedbackTransformer(Module):
  • layer 是反馈变压器层,我们为每层克隆它
  • n_layers 是变压器中的层数
257    def __init__(self, layer: FeedbackTransformerLayer, n_layers: int):
263        super().__init__()

制作变压器层的副本

265        self.layers = clone_module_list(layer, n_layers)

最终归一化层

267        self.norm = nn.LayerNorm([layer.size])

内存向量计算为每个图层表示的加权总和。这是该参数的权重参数。

270        self.weights = nn.Parameter(torch.ones(n_layers + 1), requires_grad=True)

Softmax 用于计算加权总和之前的权重

272        self.softmax = nn.Softmax(0)
  • x_seq 是带形状的输入[seq_len, batch_size, d_model]
274    def forward(self, x_seq: torch.Tensor):

沿序列轴将输入拆分为一个列表

280        x_seq = torch.unbind(x_seq, dim=0)

存储输出的列表

282        res = []

存储记忆向量的列表

284        mem = []

对于每个输入步骤

286        for x in x_seq:

存储图层输出的列表

288            layer_outputs = [x]

如果有内存,则将它们堆叠成一个向量

291            mem_tensor = torch.stack(mem) if mem else None

穿过每一层

294            for layer in self.layers:

获取图层输出

296                x = layer(x=x, key=mem_tensor, value=mem_tensor)

将它们追加到图层输出列表中

298                layer_outputs.append(x)

将层输出堆叠到张量

301            layer_outputs = torch.stack(layer_outputs)

将内存矢量计算为图层输出的加权总和

303            mem.append(torch.einsum('lbd,l->bd', layer_outputs, self.softmax(self.weights)))

将输出追加到结果中

305            res.append(x)

堆叠输出张量

308        res = torch.stack(res)

规范化输出

310        return self.norm(res)

层间共享密钥和值

堆栈函数实现

我们实现了一个自定义函数,而不是追加到python列表然后做torch.stack 。与顺序中每个步骤的调用相比torch.stack ,这大大提高了性能。每次调用torch.stack 时,它都会创建一个新的张量,而此方法和随附的类Stack 共享每个步骤的内存。

317class StackFunction(torch.autograd.Function):
  • ctx 是函数的上下文(它允许我们缓存东西)
  • memory 是共享内存张量,我们在其中堆叠和存储每个步骤的值(键和值)
  • memory_grad 是存储和累积每步梯度的共享内存张量
  • last 是最后一个堆叠的值
  • n 是步数(即堆栈的大小)

这将返回步长到的堆叠张量n

329    @staticmethod
330    def forward(ctx, memory, memory_grad, last, n):

缓存累积渐变

342        ctx._mem_grad = memory_grad

缓存堆栈的大小

344        ctx._n = n

返回堆栈

346        return memory[:n + 1]
  • grad_output 是相对于aboutforward 函数输出的梯度

这会累积共享内存张量中的梯度,并返回相对于堆栈中last 结果的梯度。

348    @staticmethod
349    def backward(ctx, grad_output):

获取堆栈的当前大小

357        n = ctx._n

获取累积的梯度

359        memory_grad = ctx._mem_grad

添加渐变

361        memory_grad[:n + 1] += grad_output

将 w.r.t 的梯度返回到堆栈中的最后一个值

363        return None, None, memory_grad[n], None

堆叠模块

这使用上面定义的堆栈函数,并进行必要的初始化。

366class Stack:
  • max_len 是堆栈的最大大小
373    def __init__(self, max_len: int):
377        self.max_len = max_len
378        self.memory = None
379        self.memory_grad = None
380        self.last = None
381        self.n = -1
382        self.last_get_n = -1
  • n 是堆栈的大小
  • value 是需要添加到堆栈的张量
384    def append(self, n: int, value: torch.Tensor):

添加值后,你需要获取(使用)堆栈。否则,此实现将失败

392        assert n == 0 or self.last_get_n == n - 1, f"{n}, {self.last_get_n}"

在没有渐变的情况下执行此操作

395        with torch.no_grad():

初始化共享内存张量以保留堆栈

397            if self.memory is None or self.memory.shape[1:] != value.shape:

只有当堆栈为空时才会发生这种情况

399                assert n == 0

为堆栈创建张量

401                self.memory = value.new_zeros(self.max_len, *value.shape, requires_grad=False)

创建张量来累积梯度

403                self.memory_grad = value.new_zeros(self.memory.shape, requires_grad=False)

内存已经初始化,但我们正在重置堆栈。

这可能是另一个类似的函数reset ,但我们发现它更容易使用。

408            elif n == 0:

重置累积梯度

410                self.memory_grad.fill_(0.)

将值设置在堆栈的正确位置

413            self.memory.data[n] = value.detach()

跟踪堆栈(用于调试)

415            self.n = n

跟踪添加到堆栈的最后一个值。为了让梯度向后传播,我们需要将其传递给。StackFunction

420        self.last = value

返回堆栈

422    def get(self):

跟踪使用时堆栈的大小。这用于健全性检入append

429        self.last_get_n = self.n
全力@@

以赴,StackFunction 这样 PyStackFunction.backwards Torch 在反向传播期间就会调用它。

432        return StackFunction.apply(self.memory, self.memory_grad, self.last, self.n)

释放内存

434    def free(self):
439        self.memory = None
440        self.memory_grad = None
441        self.last = None

更新了反馈变压器模块

这是更新的反馈转换器模块,用于缓存键和值。

444class FeedbackTransformerKV(Module):
  • layer 是反馈变压器层,我们为每层克隆它
  • n_layers 是变压器中的层数
  • d_model 是变压器中的特征数
  • “heads” 是注意头的数量
451    def __init__(self, layer: FeedbackTransformerLayer, n_layers: int, d_model: int, heads: int):
459        super().__init__()

制作变压器层的副本

461        self.layers = clone_module_list(layer, n_layers)

最终归一化层

463        self.norm = nn.LayerNorm([layer.size])

内存向量计算为每个图层表示的加权总和。这是该参数的权重参数。

466        self.weights = nn.Parameter(torch.ones(n_layers + 1), requires_grad=True)

Softmax 用于计算加权总和之前的权重

468        self.softmax = nn.Softmax(0)

头部特征的数量

471        d_k = d_model // heads

转换嵌入(内存)以获取密钥的模块

473        self.key = PrepareForMultiHeadAttention(d_model, heads, d_k, bias=False)

转换嵌入(内存)以获取密钥的模块

475        self.value = PrepareForMultiHeadAttention(d_model, heads, d_k, bias=False)

堆叠按键的存储空间

478        self.mem_key = Stack(512)

堆叠值的内存

480        self.mem_value = Stack(512)
  • x_seq 是带形状的输入[seq_len, batch_size, d_model]
482    def forward(self, x_seq: torch.Tensor):

沿序列轴将输入拆分为一个列表

488        x_seq = torch.unbind(x_seq, dim=0)

存储输出的列表

490        res = []

对于每个输入步骤

492        for step, x in enumerate(x_seq):

存储图层输出的列表

494            layer_outputs = [x]

键和值的堆栈

497            key_tensor = None
498            value_tensor = None

如果我们超出了初始步骤,则获取键和值张量

500            if step > 0:
501                key_tensor = self.mem_key.get()
502                value_tensor = self.mem_value.get()

穿过每一层

505            for layer in self.layers:

获取图层输出

507                x = layer(x=x, key=key_tensor, value=value_tensor)

将它们追加到图层输出列表中

509                layer_outputs.append(x)

将层输出堆叠到张量

512            layer_outputs = torch.stack(layer_outputs)

将内存矢量计算为图层输出的加权总和

514            mem = torch.einsum('lbd,l->bd', layer_outputs, self.softmax(self.weights))

从内存中计算密钥并将其添加到堆栈中

516            self.mem_key.append(step, self.key(mem))

计算内存中的值并将其添加到堆栈中

518            self.mem_value.append(step, self.value(mem))

将输出追加到结果中

520            res.append(x)

堆叠输出张量

523        res = torch.stack(res)

规范化输出

525        return self.norm(res)
527    def free(self):
528        self.mem_key.free()
529        self.mem_value.free()