这是 PyTorch 对《使用反馈存储器访问序列变压器中的更高层次表示》一文的 PyT orch 实现。
普通的变压器会并行处理代币。每个变压器层都注意前一层的输出。反馈变压器注意前面步骤中所有层的输出。因此,这会增加重复性,我们需要逐个代币进行处理。这会显著减慢训练速度(大约 5 到 10 倍,具体取决于序列长度)。但是,在预测反馈变换器时,速度更快,因为如果你缓存了内存向量,你可以预测下一个标记。
为了加快训练速度,本文讨论了从短序列长度开始并逐渐增加序列长度的问题。他们还讨论了使用预训练的并行变压器作为起点。
原始反馈变压器不保留所有层的输出。相反,它保留所有图层输出的加权总和。这减少了预测期间用于缓存的内存。这个文件的前半部分实现了这一点。
更新后的反馈变压器共享权重,用于计算各层之间的密钥和值。然后,我们只计算每个步骤的键和值一次,并将其缓存。这个文件的后半部分实现了这一点。我们实现了一个自定义 PyTorch 函数来提高性能。
这是训练代码和一本用于在 Tiny Shakespeare 数据集上训练反馈转换器的笔记本。
42import math
43from typing import Optional
44
45import torch
46from torch import nn
47
48from labml_helpers.module import Module
49from labml_nn.transformers.feed_forward import FeedForward
50from labml_nn.transformers.mha import PrepareForMultiHeadAttention
51from labml_nn.utils import clone_module_list
54class FeedbackAttention(Module):
d_model
是变压器中的特征数dropout_prob
注意力丢失的概率是多少is_kv_precomputed
是键值、值张量是否已经计算过65 def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, *,
66 is_kv_precomputed: bool = False):
74 super().__init__()
每头特征数
77 self.d_k = d_model // heads
79 self.heads = heads
这些改变了query
多头注意力。
82 self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
这些改变了多value
头注意力的key
和。
84 if not is_kv_precomputed:
85 self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
86 self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)
键和值已计算
88 else:
89 self.key = None
90 self.value = None
输出层
93 self.output = nn.Linear(d_model, d_model)
辍学
95 self.dropout = nn.Dropout(dropout_prob)
softmax 之前的缩放系数
97 self.scale = 1 / math.sqrt(self.d_k)
Softmax 在时间维度上引起人们的注意key
100 self.softmax = nn.Softmax(dim=0)
相对位置的数量
103 self.P = 2 ** 12
键相对于查询的相对位置嵌入。
106 self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P, heads, self.d_k)), requires_grad=True)
键相对于查询的相对位置嵌入偏差。
108 self.key_pos_bias = nn.Parameter(torch.zeros((self.P, heads)), requires_grad=True)
查询的位置嵌入与查询的位置无关
110 self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)
我们存储注意事项,以便在需要时将其用于日志记录或进行其他计算
113 self.attn = None
我们使用相对位置编码来表示注意力,类似于 Transf ormer-XL 纸张的相对多头注意力。
从当前步骤的查询到键入步骤(相对于当前步骤)的注意是,
其中,是原始嵌入的线性变换,是位置编码的线性变换。
我们将术语替换为。
115 def get_scores(self, query: torch.Tensor, key: torch.Tensor):
143 key_pos_emb = self.key_pos_embeddings[-key.shape[0]:]
145 query_pos_bias = self.query_pos_bias[None, :, :]
147 key_pos_bias = self.key_pos_bias[-key.shape[0]:]
150 ac = torch.einsum('bhd,jbhd->jbh', query + query_pos_bias, key)
152 bd = torch.einsum('bhd,jhd->jbh', query, key_pos_emb) + key_pos_bias[:, None, :]
155 return ac + bd
query
有形状[batch_size, d_model]
key
而且value
有形状[seq_len, batch_size, d_model]
157 def forward(self, *,
158 query: torch.Tensor,
159 key: torch.Tensor,
160 value: torch.Tensor):
做好准备query
key
,value
进行注意力计算key
,然后value
就会有形状[seq_len, batch_size, heads, d_k]
而且query
会有形状[batch_size, heads, d_k]
169 query = self.query(query)
170 if self.key:
171 key = self.key(key)
172 if self.value:
173 value = self.value(value)
计算注意力分数。结果为形状的张量[seq_len, batch_size, heads]
177 scores = self.get_scores(query, key)
音阶分数
180 scores *= self.scale
软最大
183 attn = self.softmax(scores)
申请退学
186 attn = self.dropout(attn)
乘以值
189 x = torch.einsum("jbh,jbhd->bhd", attn, value)
连接多个头
192 x = x.reshape(x.shape[0], -1)
输出层
195 return self.output(x)
198class FeedbackTransformerLayer(Module):
d_model
是变压器中的特征数attn
是反馈关注模块feed_forward
是位置前馈层dropout_prob
是注意和前馈后辍学层的丢失概率205 def __init__(self, *,
206 d_model: int,
207 attn: FeedbackAttention,
208 feed_forward: FeedForward,
209 dropout_prob: float):
216 super().__init__()
变压器尺寸
218 self.size = d_model
220 self.attn = attn
221 self.feed_forward = feed_forward
222 self.dropout = nn.Dropout(dropout_prob)
归一化层
225 self.norm_self_attn = nn.LayerNorm([d_model])
226 self.norm_ff = nn.LayerNorm([d_model])
228 def forward(self, *,
229 x: torch.Tensor,
230 key: Optional[torch.Tensor],
231 value: Optional[torch.Tensor]):
如果有记忆
233 if key is not None:
在进行自我注意之前对向量进行归一化
235 z = self.norm_self_attn(x)
通过自我关注,即关键和价值来自自我
237 self_attn = self.attn(query=z, key=key, value=value)
添加自我关注的结果
239 x = x + self.dropout(self_attn)
标准化以进行前馈
242 z = self.norm_ff(x)
通过前馈网络
244 ff = self.feed_forward(z)
将前馈结果添加回来
246 x = x + self.dropout(ff)
249 return x
252class FeedbackTransformer(Module):
layer
是反馈变压器层,我们为每层克隆它n_layers
是变压器中的层数257 def __init__(self, layer: FeedbackTransformerLayer, n_layers: int):
263 super().__init__()
制作变压器层的副本
265 self.layers = clone_module_list(layer, n_layers)
最终归一化层
267 self.norm = nn.LayerNorm([layer.size])
内存向量计算为每个图层表示的加权总和。这是该参数的权重参数。
270 self.weights = nn.Parameter(torch.ones(n_layers + 1), requires_grad=True)
Softmax 用于计算加权总和之前的权重
272 self.softmax = nn.Softmax(0)
x_seq
是带形状的输入[seq_len, batch_size, d_model]
274 def forward(self, x_seq: torch.Tensor):
沿序列轴将输入拆分为一个列表
280 x_seq = torch.unbind(x_seq, dim=0)
存储输出的列表
282 res = []
存储记忆向量的列表
284 mem = []
对于每个输入步骤
286 for x in x_seq:
存储图层输出的列表
288 layer_outputs = [x]
如果有内存,则将它们堆叠成一个向量
291 mem_tensor = torch.stack(mem) if mem else None
穿过每一层
294 for layer in self.layers:
获取图层输出
296 x = layer(x=x, key=mem_tensor, value=mem_tensor)
将它们追加到图层输出列表中
298 layer_outputs.append(x)
将层输出堆叠到张量
301 layer_outputs = torch.stack(layer_outputs)
将内存矢量计算为图层输出的加权总和
303 mem.append(torch.einsum('lbd,l->bd', layer_outputs, self.softmax(self.weights)))
将输出追加到结果中
305 res.append(x)
堆叠输出张量
308 res = torch.stack(res)
规范化输出
310 return self.norm(res)
我们实现了一个自定义函数,而不是追加到python列表然后做torch.stack
。与顺序中每个步骤的调用相比torch.stack
,这大大提高了性能。每次调用torch.stack
时,它都会创建一个新的张量,而此方法和随附的类Stack
共享每个步骤的内存。
317class StackFunction(torch.autograd.Function):
ctx
是函数的上下文(它允许我们缓存东西)memory
是共享内存张量,我们在其中堆叠和存储每个步骤的值(键和值)memory_grad
是存储和累积每步梯度的共享内存张量last
是最后一个堆叠的值n
是步数(即堆栈的大小)这将返回步长到的堆叠张量n
。
329 @staticmethod
330 def forward(ctx, memory, memory_grad, last, n):
缓存累积渐变
342 ctx._mem_grad = memory_grad
缓存堆栈的大小
344 ctx._n = n
返回堆栈
346 return memory[:n + 1]
348 @staticmethod
349 def backward(ctx, grad_output):
获取堆栈的当前大小
357 n = ctx._n
获取累积的梯度
359 memory_grad = ctx._mem_grad
添加渐变
361 memory_grad[:n + 1] += grad_output
将 w.r.t 的梯度返回到堆栈中的最后一个值
363 return None, None, memory_grad[n], None
366class Stack:
max_len
是堆栈的最大大小373 def __init__(self, max_len: int):
377 self.max_len = max_len
378 self.memory = None
379 self.memory_grad = None
380 self.last = None
381 self.n = -1
382 self.last_get_n = -1
n
是堆栈的大小value
是需要添加到堆栈的张量384 def append(self, n: int, value: torch.Tensor):
添加值后,你需要获取(使用)堆栈。否则,此实现将失败
392 assert n == 0 or self.last_get_n == n - 1, f"{n}, {self.last_get_n}"
在没有渐变的情况下执行此操作
395 with torch.no_grad():
初始化共享内存张量以保留堆栈
397 if self.memory is None or self.memory.shape[1:] != value.shape:
只有当堆栈为空时才会发生这种情况
399 assert n == 0
为堆栈创建张量
401 self.memory = value.new_zeros(self.max_len, *value.shape, requires_grad=False)
创建张量来累积梯度
403 self.memory_grad = value.new_zeros(self.memory.shape, requires_grad=False)
408 elif n == 0:
重置累积梯度
410 self.memory_grad.fill_(0.)
将值设置在堆栈的正确位置
413 self.memory.data[n] = value.detach()
跟踪堆栈(用于调试)
415 self.n = n
跟踪添加到堆栈的最后一个值。为了让梯度向后传播,我们需要将其传递给。StackFunction
420 self.last = value
返回堆栈
422 def get(self):
跟踪使用时堆栈的大小。这用于健全性检入append
。
429 self.last_get_n = self.n
以赴,StackFunction
这样 PyStackFunction.backwards
Torch 在反向传播期间就会调用它。
432 return StackFunction.apply(self.memory, self.memory_grad, self.last, self.n)
释放内存
434 def free(self):
439 self.memory = None
440 self.memory_grad = None
441 self.last = None
444class FeedbackTransformerKV(Module):
layer
是反馈变压器层,我们为每层克隆它n_layers
是变压器中的层数d_model
是变压器中的特征数451 def __init__(self, layer: FeedbackTransformerLayer, n_layers: int, d_model: int, heads: int):
459 super().__init__()
制作变压器层的副本
461 self.layers = clone_module_list(layer, n_layers)
最终归一化层
463 self.norm = nn.LayerNorm([layer.size])
内存向量计算为每个图层表示的加权总和。这是该参数的权重参数。
466 self.weights = nn.Parameter(torch.ones(n_layers + 1), requires_grad=True)
Softmax 用于计算加权总和之前的权重
468 self.softmax = nn.Softmax(0)
头部特征的数量
471 d_k = d_model // heads
转换嵌入(内存)以获取密钥的模块
473 self.key = PrepareForMultiHeadAttention(d_model, heads, d_k, bias=False)
转换嵌入(内存)以获取密钥的模块
475 self.value = PrepareForMultiHeadAttention(d_model, heads, d_k, bias=False)
堆叠按键的存储空间
478 self.mem_key = Stack(512)
堆叠值的内存
480 self.mem_value = Stack(512)
x_seq
是带形状的输入[seq_len, batch_size, d_model]
482 def forward(self, x_seq: torch.Tensor):
沿序列轴将输入拆分为一个列表
488 x_seq = torch.unbind(x_seq, dim=0)
存储输出的列表
490 res = []
对于每个输入步骤
492 for step, x in enumerate(x_seq):
存储图层输出的列表
494 layer_outputs = [x]
键和值的堆栈
497 key_tensor = None
498 value_tensor = None
如果我们超出了初始步骤,则获取键和值张量
500 if step > 0:
501 key_tensor = self.mem_key.get()
502 value_tensor = self.mem_value.get()
穿过每一层
505 for layer in self.layers:
获取图层输出
507 x = layer(x=x, key=key_tensor, value=value_tensor)
将它们追加到图层输出列表中
509 layer_outputs.append(x)
将层输出堆叠到张量
512 layer_outputs = torch.stack(layer_outputs)
将内存矢量计算为图层输出的加权总和
514 mem = torch.einsum('lbd,l->bd', layer_outputs, self.softmax(self.weights))
从内存中计算密钥并将其添加到堆栈中
516 self.mem_key.append(step, self.key(mem))
计算内存中的值并将其添加到堆栈中
518 self.mem_value.append(step, self.value(mem))
将输出追加到结果中
520 res.append(x)
堆叠输出张量
523 res = torch.stack(res)
规范化输出
525 return self.norm(res)
527 def free(self):
528 self.mem_key.free()
529 self.mem_value.free()