これは、「入門書:言語モデリングのための効率的なトランスフォーマーの検索」という論文をPyTorchで実装したものです。
著者らは、変圧器アーキテクチャの進化的研究を行っています。Primer (プリミティブが検索した Transformer) という検索を使って見つかったアーキテクチャに名前を付けます。Primer EZは、オリジナルのトランスフォーマーと比較して、Primerで最も堅牢な2つの変更を加えたアーキテクチャです。Primer EZはバニラトランスフォーマーよりもはるかに速くトレーニングします
。検索で見つかった最も効果的な変更は、位置ごとのフィードフォワードモジュールで ReLU の代わりに正方形の ReLU を使用することです。
次に効果的な変更は、クエリ、キー、および値のマルチヘッド投影後の深度方向の畳み込みです。畳み込みは、シーケンス次元に沿って、チャネル単位 (深さ方向) で行われます。はっきりさせておきますが、各ヘッドのチャンネル数がの場合、畳み込みはチャンネルごとにカーネルを持つことになります
。38import torch
39from torch import nn
40
41from labml_helpers.module import Module
42from labml_nn.transformers import MultiHeadAttention
45class SquaredReLU(Module):
55 def __init__(self):
56 super().__init__()
57 self.relu = nn.ReLU()
59 def forward(self, x: torch.Tensor):
ReLU を適用
61 x = self.relu(x)
スクエア・イット
63 return x * x
66class SpatialDepthWiseConvolution(Module):
d_k
は各ヘッドのチャンネル数71 def __init__(self, d_k: int, kernel_size: int = 3):
75 super().__init__()
76 self.kernel_size = kernel_size
Conv1d
PyTorchのモジュールを使用しています。グループの数をチャネル数と同じになるように設定し、チャネルごとに (異なるカーネルで) 個別の畳み込みを行います。両側にパディングを追加し、kernel_size - 1
後で一番適切な結果になるようにトリミングします
81 self.conv = nn.Conv1d(in_channels=d_k, out_channels=d_k,
82 kernel_size=(kernel_size,), padding=(kernel_size - 1,), groups=d_k)
x
形がある [seq_len, batch_size, heads, d_k]
84 def forward(self, x: torch.Tensor):
形を取得
90 seq_len, batch_size, heads, d_k = x.shape
に並べ替え [batch_size, heads, d_k, seq_len]
92 x = x.permute(1, 2, 3, 0)
形状を次のように変更 [batch_size * heads, d_k, seq_len]
94 x = x.view(batch_size * heads, d_k, seq_len)
1次元の畳み込みは次の形式の入力を受け付けます [N, channels, sequence]
97 x = self.conv(x)
両側をパディングしたので、kernel_size - 1
最も適切な結果が得られるようにトリミングします
99 x = x[:, :, :-(self.kernel_size - 1)]
形状を次の形式に変更 [batch_size, heads, d_k, seq_len]
101 x = x.view(batch_size, heads, d_k, seq_len)
に並べ替え [seq_len, batch_size, heads, d_k]
103 x = x.permute(3, 0, 1, 2)
106 return x
Multi-Head Attentionの当初の実装を拡張し、クエリ、キー、バリュープロジェクションに空間深度方向のコンボリューションを追加します。
109class MultiDConvHeadAttention(MultiHeadAttention):
117 def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
118 super().__init__(heads, d_model, dropout_prob)
Multi-Head Attention は、クエリ、キー、バリュープロジェクションモジュールself.query
self.key
、およびを作成します。self.value
それぞれに空間深度方向の畳み込み層を組み合わせて、、、を置き換えますself.query
。self.key
self.value
📝 これとバニラトランスフォーマーのマルチヘッドアテンションの違いがはっきりとわかるので、このよりクリーンな実装の方が理解しやすいと思います。
128 self.query = nn.Sequential(self.query, SpatialDepthWiseConvolution(self.d_k))
129 self.key = nn.Sequential(self.key, SpatialDepthWiseConvolution(self.d_k))
130 self.value = nn.Sequential(self.value, SpatialDepthWiseConvolution(self.d_k))