这是 P rimer:为语言建模寻找高效转换器论文的 PyTorch 实现。
作者对变压器架构进行了进化探索。他们使用搜索 Primer(Primitives 搜索 Transformer)命名找到的架构。与原始变压器相比,P@@ rimer EZ 是在 Primer 中进行了两项最强大的修改的架构。Primer EZ 的训练速度比原版变压器快很多。
搜索发现的最有效的修改是在位置前馈模块中使用方形 ReLU 而不是 Re LU。
下一个有效的修改是在查询、键和值的多头投影之后的深度卷积。卷积沿着序列维度和每个通道(深度)进行。需要明确的是,如果每个信头中的通道数为,则卷积将为每个通道都有内核。
38import torch
39from torch import nn
40
41from labml_helpers.module import Module
42from labml_nn.transformers import MultiHeadAttention
55 def __init__(self):
56 super().__init__()
57 self.relu = nn.ReLU()
59 def forward(self, x: torch.Tensor):
申请 ReLU
61 x = self.relu(x)
把它弄平了
63 return x * x
66class SpatialDepthWiseConvolution(Module):
d_k
是每个 head 中的通道数71 def __init__(self, d_k: int, kernel_size: int = 3):
75 super().__init__()
76 self.kernel_size = kernel_size
我们使用 PyTorch 的Conv1d
模块。我们将组的数量设置为等于通道数,以便它对每个通道进行单独的卷积(使用不同的内核)。我们在两边添加填充,然后裁剪最右边的kernel_size - 1
结果
81 self.conv = nn.Conv1d(in_channels=d_k, out_channels=d_k,
82 kernel_size=(kernel_size,), padding=(kernel_size - 1,), groups=d_k)
x
有形状[seq_len, batch_size, heads, d_k]
84 def forward(self, x: torch.Tensor):
得到形状
90 seq_len, batch_size, heads, d_k = x.shape
排列为[batch_size, heads, d_k, seq_len]
92 x = x.permute(1, 2, 3, 0)
将形状改为[batch_size * heads, d_k, seq_len]
94 x = x.view(batch_size * heads, d_k, seq_len)
一维卷积接受以下形式的输入[N, channels, sequence]
97 x = self.conv(x)
裁剪最右边的kernel_size - 1
结果,因为我们填充了两边
99 x = x[:, :, :-(self.kernel_size - 1)]
重塑为[batch_size, heads, d_k, seq_len]
101 x = x.view(batch_size, heads, d_k, seq_len)
排列为[seq_len, batch_size, heads, d_k]
103 x = x.permute(3, 0, 1, 2)
106 return x
109class MultiDConvHeadAttention(MultiHeadAttention):
117 def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
118 super().__init__(heads, d_model, dropout_prob)
Multi-Head Attention 将创建查询、键和价值投影模块self.query
self.key
、和self.value
。
我们将空间深度卷积层组合到每个层上,并替换self.query
self.key
、和self.value
。
📝 我们认为这种更简洁的实现更容易理解,因为它清楚地显示了这与普通变压器多头关注之间的区别。
128 self.query = nn.Sequential(self.query, SpatialDepthWiseConvolution(self.d_k))
129 self.key = nn.Sequential(self.key, SpatialDepthWiseConvolution(self.d_k))
130 self.value = nn.Sequential(self.value, SpatialDepthWiseConvolution(self.d_k))