12import torch
13from torch import nn
14
15from labml_helpers.module import Module
16from labml_nn.transformers import MultiHeadAttention
19class SpatialDepthWiseSharedConvolution(Module):
26 def __init__(self, kernel_size: int = 3):
27 super().__init__()
28 self.kernel_size = kernel_size
Conv1d
PyTorchのモジュールを使用しています。両側にパディングを追加し、kernel_size - 1
後で一番適切な結果になるようにトリミングします
33 self.conv = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=(kernel_size,), padding=(kernel_size - 1,))
x
形がある [seq_len, batch_size, heads, d_k]
35 def forward(self, x: torch.Tensor):
形を取得
41 seq_len, batch_size, heads, d_k = x.shape
に並べ替え [batch_size, heads, d_k, seq_len]
43 x = x.permute(1, 2, 3, 0)
形状を次のように変更 [batch_size * heads * d_k, seq_len]
45 x = x.view(batch_size * heads * d_k, 1, seq_len)
1次元の畳み込みは次の形式の入力を受け付けます [N, channels, sequence]
48 x = self.conv(x)
両側をパディングしたので、kernel_size - 1
最も適切な結果が得られるようにトリミングします
50 x = x[:, :, :-(self.kernel_size - 1)]
形状を次の形式に変更 [batch_size, heads, d_k, seq_len]
52 x = x.view(batch_size, heads, d_k, seq_len)
に並べ替え [seq_len, batch_size, heads, d_k]
54 x = x.permute(3, 0, 1, 2)
57 return x
60class MultiDSharedConvHeadAttention(MultiHeadAttention):
68 def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
69 super().__init__(heads, d_model, dropout_prob)
Multi-Head Attention は、クエリ、キー、バリュープロジェクションモジュールself.query
self.key
、およびを作成します。self.value
それぞれに空間深度共有畳み込み層を組み合わせて、、、を置き換えますself.query
。self.key
self.value
76 self.query = nn.Sequential(self.query, SpatialDepthWiseSharedConvolution())
77 self.key = nn.Sequential(self.key, SpatialDepthWiseSharedConvolution())
78 self.value = nn.Sequential(self.value, SpatialDepthWiseSharedConvolution())
81class SpatialDepthWisePerHeadConvolution(Module):
heads
は頭の数ですd_k
は各ヘッドのチャンネル数86 def __init__(self, heads: int, d_k: int, kernel_size: int = 3):
91 super().__init__()
92 self.kernel_size = kernel_size
Conv1d
PyTorchのモジュールを使用しています。グループの数を各ヘッドのチャネル数と同じになるように設定し、チャネルとヘッドごとに個別の畳み込みを (異なるカーネルで) 行います。両側にパディングを追加し、kernel_size - 1
後で一番適切な結果になるようにトリミングします
98 self.conv = nn.Conv1d(in_channels=d_k * heads, out_channels=d_k * heads,
99 kernel_size=(kernel_size,), padding=(kernel_size - 1,), groups=d_k * heads)
x
形がある [seq_len, batch_size, heads, d_k]
101 def forward(self, x: torch.Tensor):
形を取得
107 seq_len, batch_size, heads, d_k = x.shape
に並べ替え [batch_size, heads, d_k, seq_len]
109 x = x.permute(1, 2, 3, 0)
形状を次のように変更 [batch_size heads * d_k, seq_len]
111 x = x.view(batch_size, heads * d_k, seq_len)
1次元の畳み込みは次の形式の入力を受け付けます [N, channels, sequence]
114 x = self.conv(x)
両側をパディングしたので、kernel_size - 1
最も適切な結果が得られるようにトリミングします
116 x = x[:, :, :-(self.kernel_size - 1)]
形状を次の形式に変更 [batch_size, heads, d_k, seq_len]
118 x = x.view(batch_size, heads, d_k, seq_len)
に並べ替え [seq_len, batch_size, heads, d_k]
120 x = x.permute(3, 0, 1, 2)
123 return x
126class MultiDPHConvHeadAttention(MultiHeadAttention):
134 def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
135 super().__init__(heads, d_model, dropout_prob)
Multi-Head Attention は、クエリ、キー、バリュープロジェクションモジュールself.query
self.key
、およびを作成します。self.value
それぞれに頭ごとの深度方向の空間畳み込み層を組み合わせて、、、を置き換えます。self.query
self.key
self.value
142 self.query = nn.Sequential(self.query, SpatialDepthWisePerHeadConvolution(heads, self.d_k))
143 self.key = nn.Sequential(self.key, SpatialDepthWisePerHeadConvolution(heads, self.d_k))
144 self.value = nn.Sequential(self.value, SpatialDepthWisePerHeadConvolution(heads, self.d_k))