プライマー EZ バリエーション

いくつかのバリエーションを試して、Primer EZのどの変更が最もメリットがあるかを確認しました。

12import torch
13from torch import nn
14
15from labml_helpers.module import Module
16from labml_nn.transformers import MultiHeadAttention

空間奥行き共有コンボリューション

すべてのチャネルで同じカーネルを共有します。

19class SpatialDepthWiseSharedConvolution(Module):
26    def __init__(self, kernel_size: int = 3):
27        super().__init__()
28        self.kernel_size = kernel_size

Conv1d PyTorchのモジュールを使用しています。両側にパディングを追加し、kernel_size - 1 後で一番適切な結果になるようにトリミングします

33        self.conv = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=(kernel_size,), padding=(kernel_size - 1,))

x 形がある [seq_len, batch_size, heads, d_k]

35    def forward(self, x: torch.Tensor):

形を取得

41        seq_len, batch_size, heads, d_k = x.shape

に並べ替え [batch_size, heads, d_k, seq_len]

43        x = x.permute(1, 2, 3, 0)

形状を次のように変更 [batch_size * heads * d_k, seq_len]

45        x = x.view(batch_size * heads * d_k, 1, seq_len)

1次元の畳み込みは次の形式の入力を受け付けます [N, channels, sequence]

48        x = self.conv(x)

両側をパディングしたので、kernel_size - 1 最も適切な結果が得られるようにトリミングします

50        x = x[:, :, :-(self.kernel_size - 1)]

形状を次の形式に変更 [batch_size, heads, d_k, seq_len]

52        x = x.view(batch_size, heads, d_k, seq_len)

に並べ替え [seq_len, batch_size, heads, d_k]

54        x = x.permute(3, 0, 1, 2)

57        return x

多層共有型頭脳アテンション

Multi-Head Attentionの当初の実装を拡張し、クエリ、キー、バリュープロジェクションに空間深度共有コンボリューションを追加します。

60class MultiDSharedConvHeadAttention(MultiHeadAttention):
68    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
69        super().__init__(heads, d_model, dropout_prob)

Multi-Head Attention は、クエリ、キー、バリュープロジェクションモジュールself.query self.key 、およびを作成します。self.value

それぞれに空間深度共有畳み込み層を組み合わせて、、、を置き換えますself.queryself.key self.value

76        self.query = nn.Sequential(self.query, SpatialDepthWiseSharedConvolution())
77        self.key = nn.Sequential(self.key, SpatialDepthWiseSharedConvolution())
78        self.value = nn.Sequential(self.value, SpatialDepthWiseSharedConvolution())

頭あたりの空間深度コンボリューション

81class SpatialDepthWisePerHeadConvolution(Module):
  • heads は頭の数です
  • d_k は各ヘッドのチャンネル数
86    def __init__(self, heads: int, d_k: int, kernel_size: int = 3):
91        super().__init__()
92        self.kernel_size = kernel_size

Conv1d PyTorchのモジュールを使用しています。グループの数を各ヘッドのチャネル数と同じになるように設定し、チャネルとヘッドごとに個別の畳み込みを (異なるカーネルで) 行います。両側にパディングを追加し、kernel_size - 1 後で一番適切な結果になるようにトリミングします

98        self.conv = nn.Conv1d(in_channels=d_k * heads, out_channels=d_k * heads,
99                              kernel_size=(kernel_size,), padding=(kernel_size - 1,), groups=d_k * heads)

x 形がある [seq_len, batch_size, heads, d_k]

101    def forward(self, x: torch.Tensor):

形を取得

107        seq_len, batch_size, heads, d_k = x.shape

に並べ替え [batch_size, heads, d_k, seq_len]

109        x = x.permute(1, 2, 3, 0)

形状を次のように変更 [batch_size heads * d_k, seq_len]

111        x = x.view(batch_size, heads * d_k, seq_len)

1次元の畳み込みは次の形式の入力を受け付けます [N, channels, sequence]

114        x = self.conv(x)

両側をパディングしたので、kernel_size - 1 最も適切な結果が得られるようにトリミングします

116        x = x[:, :, :-(self.kernel_size - 1)]

形状を次の形式に変更 [batch_size, heads, d_k, seq_len]

118        x = x.view(batch_size, heads, d_k, seq_len)

に並べ替え [seq_len, batch_size, heads, d_k]

120        x = x.permute(3, 0, 1, 2)

123        return x

頭部ごとに複数の頭部深さ方向への注意喚起

Multi-Head Attentionの当初の実装を拡張し、クエリ、キー、バリュープロジェクションに空間深度方向のコンボリューションを追加します。

126class MultiDPHConvHeadAttention(MultiHeadAttention):
134    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
135        super().__init__(heads, d_model, dropout_prob)

Multi-Head Attention は、クエリ、キー、バリュープロジェクションモジュールself.query self.key 、およびを作成します。self.value

それぞれに頭ごとの深度方向の空間畳み込み層を組み合わせて、、、を置き換えます。self.query self.key self.value

142        self.query = nn.Sequential(self.query, SpatialDepthWisePerHeadConvolution(heads, self.d_k))
143        self.key = nn.Sequential(self.key, SpatialDepthWisePerHeadConvolution(heads, self.d_k))
144        self.value = nn.Sequential(self.value, SpatialDepthWisePerHeadConvolution(heads, self.d_k))