We tried some variations to see which of the changes in Primer EZ gives the most benefit.
```python
import torch
from torch import nn

from labml_nn.transformers import MultiHeadAttention
```
```python
class SpatialDepthWiseSharedConvolution(nn.Module):
    """
    Spatial depth-wise convolution with a single kernel that is shared
    across all heads and channels.
    """

    def __init__(self, kernel_size: int = 3):
        super().__init__()
        self.kernel_size = kernel_size
        # We use PyTorch's `Conv1d` module with a single channel.
        # We add padding to both sides and later crop the rightmost
        # `kernel_size - 1` results.
        self.conv = nn.Conv1d(in_channels=1, out_channels=1,
                              kernel_size=(kernel_size,), padding=(kernel_size - 1,))

    def forward(self, x: torch.Tensor):
        """
        `x` has shape `[seq_len, batch_size, heads, d_k]`
        """
        # Get the shape
        seq_len, batch_size, heads, d_k = x.shape
        # Permute to `[batch_size, heads, d_k, seq_len]`
        x = x.permute(1, 2, 3, 0)
        # Change the shape to `[batch_size * heads * d_k, 1, seq_len]`,
        # since 1D convolution accepts input of the form `[N, channels, sequence]`
        x = x.view(batch_size * heads * d_k, 1, seq_len)
        # Apply the convolution
        x = self.conv(x)
        # Crop the rightmost `kernel_size - 1` results since we padded both sides
        x = x[:, :, :-(self.kernel_size - 1)]
        # Reshape to `[batch_size, heads, d_k, seq_len]`
        x = x.view(batch_size, heads, d_k, seq_len)
        # Permute back to `[seq_len, batch_size, heads, d_k]`
        x = x.permute(3, 0, 1, 2)

        return x
```
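A minimal sanity check (our own sketch, with illustrative sizes) confirms that the module preserves the input shape and stays causal: because we crop the rightmost `kernel_size - 1` results, each position only sees itself and earlier positions.

```python
# Sanity-check sketch: shape preservation and causality (illustrative sizes).
conv = SpatialDepthWiseSharedConvolution(kernel_size=3)
x = torch.randn(10, 2, 4, 16)  # [seq_len, batch_size, heads, d_k]
y = conv(x)
assert y.shape == x.shape

# Perturbing the last position must leave earlier outputs unchanged.
x2 = x.clone()
x2[-1] += 1.0
assert torch.allclose(conv(x2)[:-1], y[:-1])
```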
We extend our original implementation of Multi-Head Attention and add the spatial depth-wise shared convolution to the query, key, and value projections.
```python
class MultiDSharedConvHeadAttention(MultiHeadAttention):
    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
        super().__init__(heads, d_model, dropout_prob)

        # Multi-Head Attention creates the query, key and value projection
        # modules `self.query`, `self.key`, and `self.value`.
        # We append a spatial depth-wise shared convolution layer to each of
        # them and replace `self.query`, `self.key`, and `self.value`.
        self.query = nn.Sequential(self.query, SpatialDepthWiseSharedConvolution())
        self.key = nn.Sequential(self.key, SpatialDepthWiseSharedConvolution())
        self.value = nn.Sequential(self.value, SpatialDepthWiseSharedConvolution())
```
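A usage sketch, assuming the labml_nn `MultiHeadAttention` forward signature (keyword tensors of shape `[seq_len, batch_size, d_model]`); the sizes here are illustrative:

```python
# Illustrative sizes; the forward signature follows labml_nn's MultiHeadAttention.
mha = MultiDSharedConvHeadAttention(heads=8, d_model=512)
x = torch.randn(10, 2, 512)  # [seq_len, batch_size, d_model]
# Self-attention: query, key and value all come from `x`.
out = mha(query=x, key=x, value=x)
assert out.shape == x.shape
```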
```python
class SpatialDepthWisePerHeadConvolution(nn.Module):
    """
    Spatial depth-wise convolution with a separate kernel for each
    channel of each head.
    """

    def __init__(self, heads: int, d_k: int, kernel_size: int = 3):
        """
        * `heads` is the number of heads
        * `d_k` is the number of channels in each head
        """
        super().__init__()
        self.kernel_size = kernel_size
        # We use PyTorch's `Conv1d` module. We set the number of groups equal
        # to the total number of channels (`d_k * heads`) so that it does a
        # separate convolution (with a different kernel) for each channel of
        # each head. We add padding to both sides and later crop the rightmost
        # `kernel_size - 1` results.
        self.conv = nn.Conv1d(in_channels=d_k * heads, out_channels=d_k * heads,
                              kernel_size=(kernel_size,), padding=(kernel_size - 1,),
                              groups=d_k * heads)

    def forward(self, x: torch.Tensor):
        """
        `x` has shape `[seq_len, batch_size, heads, d_k]`
        """
        # Get the shape
        seq_len, batch_size, heads, d_k = x.shape
        # Permute to `[batch_size, heads, d_k, seq_len]`
        x = x.permute(1, 2, 3, 0)
        # Change the shape to `[batch_size, heads * d_k, seq_len]`,
        # since 1D convolution accepts input of the form `[N, channels, sequence]`
        x = x.view(batch_size, heads * d_k, seq_len)
        # Apply the convolution
        x = self.conv(x)
        # Crop the rightmost `kernel_size - 1` results since we padded both sides
        x = x[:, :, :-(self.kernel_size - 1)]
        # Reshape to `[batch_size, heads, d_k, seq_len]`
        x = x.view(batch_size, heads, d_k, seq_len)
        # Permute back to `[seq_len, batch_size, heads, d_k]`
        x = x.permute(3, 0, 1, 2)

        return x
```
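Unlike the shared variant, this module learns one kernel per channel of every head. A short sketch (our own, with illustrative sizes) showing the grouped `Conv1d` weight layout:

```python
conv = SpatialDepthWisePerHeadConvolution(heads=8, d_k=64, kernel_size=3)
# With `groups = heads * d_k`, Conv1d keeps one kernel per channel:
# the weight has shape [heads * d_k, 1, kernel_size].
assert conv.conv.weight.shape == (8 * 64, 1, 3)

x = torch.randn(10, 2, 8, 64)  # [seq_len, batch_size, heads, d_k]
assert conv(x).shape == x.shape
```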
We extend our original implementation of Multi-Head Attention and add the spatial depth-wise per-head convolution to the query, key, and value projections.
```python
class MultiDPHConvHeadAttention(MultiHeadAttention):
    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
        super().__init__(heads, d_model, dropout_prob)

        # Multi-Head Attention creates the query, key and value projection
        # modules `self.query`, `self.key`, and `self.value`.
        # We append a spatial depth-wise per-head convolution layer to each of
        # them and replace `self.query`, `self.key`, and `self.value`.
        self.query = nn.Sequential(self.query, SpatialDepthWisePerHeadConvolution(heads, self.d_k))
        self.key = nn.Sequential(self.key, SpatialDepthWisePerHeadConvolution(heads, self.d_k))
        self.value = nn.Sequential(self.value, SpatialDepthWisePerHeadConvolution(heads, self.d_k))
```
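We can compare the extra cost of the two variants with a small sketch (the `n_params` helper is our own, and the sizes are illustrative): each per-head convolution adds `kernel_size + 1` parameters per channel, while each shared convolution adds only `kernel_size + 1` in total.

```python
# Hypothetical helper for counting parameters; not part of the implementation.
def n_params(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())

shared = MultiDSharedConvHeadAttention(heads=8, d_model=512)
per_head = MultiDPHConvHeadAttention(heads=8, d_model=512)
# The difference is three grouped convolutions versus three shared ones.
print(n_params(per_head) - n_params(shared))
```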