入門:言語モデリングのための効率的なトランスフォーマーの探求

これは、「入門書:言語モデリングのための効率的なトランスフォーマーの検索という論文をPyTorchで実装したものです

著者らは、変圧器アーキテクチャの進化的研究を行っています。Primer (プリミティブが検索した Transformer) という検索を使って見つかったアーキテクチャに名前を付けます。Primer EZは、オリジナルのトランスフォーマーと比較して、Primerで最も堅牢な2つの変更を加えたアーキテクチャです。Primer EZはバニラトランスフォーマーよりもはるかに速くトレーニングします

二乗リル

検索で見つかった最も効果的な変更は、位置ごとのフィードフォワードモジュールで ReLU の代わりに正方形の ReLU を使用することです。

マルチコンバーチングヘッドアテンション (MDHA)

次に効果的な変更は、クエリ、キー、および値のマルチヘッド投影後の深度方向の畳み込みです。畳み込みは、シーケンス次元に沿って、チャネル単位 (深さ方向) で行われます。はっきりさせておきますが、各ヘッドのチャンネル数がの場合畳み込みはチャンネルごとにカーネルを持つことになります

これがプライマーEZの実験コードです

38import torch
39from torch import nn
40
41from labml_helpers.module import Module
42from labml_nn.transformers import MultiHeadAttention
45class SquaredReLU(Module):
55    def __init__(self):
56        super().__init__()
57        self.relu = nn.ReLU()
59    def forward(self, x: torch.Tensor):

ReLU を適用

61        x = self.relu(x)

スクエア・イット

63        return x * x

空間深度単位の畳み込み

66class SpatialDepthWiseConvolution(Module):
  • d_k は各ヘッドのチャンネル数
71    def __init__(self, d_k: int, kernel_size: int = 3):
75        super().__init__()
76        self.kernel_size = kernel_size

Conv1d PyTorchのモジュールを使用しています。グループの数をチャネル数と同じになるように設定し、チャネルごとに (異なるカーネルで) 個別の畳み込みを行います。両側にパディングを追加し、kernel_size - 1 後で一番適切な結果になるようにトリミングします

81        self.conv = nn.Conv1d(in_channels=d_k, out_channels=d_k,
82                              kernel_size=(kernel_size,), padding=(kernel_size - 1,), groups=d_k)

x 形がある [seq_len, batch_size, heads, d_k]

84    def forward(self, x: torch.Tensor):

形を取得

90        seq_len, batch_size, heads, d_k = x.shape

に並べ替え [batch_size, heads, d_k, seq_len]

92        x = x.permute(1, 2, 3, 0)

形状を次のように変更 [batch_size * heads, d_k, seq_len]

94        x = x.view(batch_size * heads, d_k, seq_len)

1次元の畳み込みは次の形式の入力を受け付けます [N, channels, sequence]

97        x = self.conv(x)

両側をパディングしたので、kernel_size - 1 最も適切な結果が得られるようにトリミングします

99        x = x[:, :, :-(self.kernel_size - 1)]

形状を次の形式に変更 [batch_size, heads, d_k, seq_len]

101        x = x.view(batch_size, heads, d_k, seq_len)

に並べ替え [seq_len, batch_size, heads, d_k]

103        x = x.permute(3, 0, 1, 2)

106        return x

マルチコンバーチングヘッドアテンション (MDHA)

Multi-Head Attentionの当初の実装を拡張し、クエリ、キー、バリュープロジェクションに空間深度方向のコンボリューションを追加します。

109class MultiDConvHeadAttention(MultiHeadAttention):
117    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
118        super().__init__(heads, d_model, dropout_prob)

Multi-Head Attention は、クエリ、キー、バリュープロジェクションモジュールself.query self.key 、およびを作成します。self.value

それぞれに空間深度方向の畳み込み層を組み合わせて、、、を置き換えますself.queryself.key self.value

📝 これとバニラトランスフォーマーのマルチヘッドアテンションの違いがはっきりとわかるので、このよりクリーンな実装の方が理解しやすいと思います

128        self.query = nn.Sequential(self.query, SpatialDepthWiseConvolution(self.d_k))
129        self.key = nn.Sequential(self.key, SpatialDepthWiseConvolution(self.d_k))
130        self.value = nn.Sequential(self.value, SpatialDepthWiseConvolution(self.d_k))