Primer EZ Variations

We tried some variations to see which of the changes in Primer EZ gives the most benefit.

import torch
from torch import nn

from labml_nn.transformers import MultiHeadAttention

Spatial Depth-Wise Shared Convolution

We share the same kernel across all channels (all heads and all d_k dimensions).

class SpatialDepthWiseSharedConvolution(nn.Module):
    def __init__(self, kernel_size: int = 3):
        super().__init__()
        self.kernel_size = kernel_size

We use PyTorch's Conv1d module. We add padding to both sides and later crop the rightmost kernel_size - 1 results, so the output at each position depends only on the current and earlier positions.

        self.conv = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=(kernel_size,), padding=(kernel_size - 1,))

x has shape [seq_len, batch_size, heads, d_k]

    def forward(self, x: torch.Tensor):

Get the shape

        seq_len, batch_size, heads, d_k = x.shape

Permute to [batch_size, heads, d_k, seq_len]

        x = x.permute(1, 2, 3, 0)

Change the shape to [batch_size * heads * d_k, 1, seq_len], with a single channel per sequence

        x = x.view(batch_size * heads * d_k, 1, seq_len)

1D convolution accepts input of the form [N, channels, sequence]

        x = self.conv(x)

Crop the rightmost kernel_size - 1 results since we padded both sides

        x = x[:, :, :-(self.kernel_size - 1)]

Reshape to [batch_size, heads, d_k, seq_len]

        x = x.view(batch_size, heads, d_k, seq_len)

Permute to [seq_len, batch_size, heads, d_k]

        x = x.permute(3, 0, 1, 2)

        return x
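As a quick sanity check (not part of the original module), the sketch below applies the shared convolution to a random tensor with arbitrary sizes. It confirms that the [seq_len, batch_size, heads, d_k] shape is preserved and that cropping the rightmost kernel_size - 1 results makes the convolution causal: perturbing the last position does not change the outputs at earlier positions.

conv = SpatialDepthWiseSharedConvolution(kernel_size=3)

# Arbitrary sizes, chosen only for illustration
seq_len, batch_size, heads, d_k = 10, 2, 4, 16
x = torch.randn(seq_len, batch_size, heads, d_k)

# The output keeps the [seq_len, batch_size, heads, d_k] shape
y = conv(x)
assert y.shape == (seq_len, batch_size, heads, d_k)

# Perturb only the last position; earlier outputs must stay the same
x_perturbed = x.clone()
x_perturbed[-1] += 1.0
y_perturbed = conv(x_perturbed)
assert torch.allclose(y[:-1], y_perturbed[:-1])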

Multi-Depth-wise-Shared-Conv-Head Attention

We extend our original implementation of Multi-Head Attention and add the spatial depth-wise shared convolution to the query, key and value projections.

class MultiDSharedConvHeadAttention(MultiHeadAttention):
    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
        super().__init__(heads, d_model, dropout_prob)

Multi-Head Attention will create query, key and value projection modules self.query, self.key, and self.value.

We append a spatial depth-wise shared convolution layer to each of them and replace self.query, self.key, and self.value.

        self.query = nn.Sequential(self.query, SpatialDepthWiseSharedConvolution())
        self.key = nn.Sequential(self.key, SpatialDepthWiseSharedConvolution())
        self.value = nn.Sequential(self.value, SpatialDepthWiseSharedConvolution())
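A hedged sketch of what this replacement does, assuming (as in labml_nn's MultiHeadAttention implementation) that the original projection maps [seq_len, batch_size, d_model] to [seq_len, batch_size, heads, d_k]: the nn.Sequential above applies that projection and then the shared convolution, which keeps the shape. The sizes below are arbitrary and only for illustration.

attn = MultiDSharedConvHeadAttention(heads=8, d_model=512)

x = torch.randn(10, 2, 512)  # [seq_len, batch_size, d_model]

# Projection followed by the shared depth-wise convolution
q = attn.query(x)
print(q.shape)  # expected: [10, 2, 8, 64], i.e. [seq_len, batch_size, heads, d_k]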

Spatial Depth-Wise Per-Head Convolution

class SpatialDepthWisePerHeadConvolution(nn.Module):
  • heads is the number of heads
  • d_k is the number of channels in each head
    def __init__(self, heads: int, d_k: int, kernel_size: int = 3):
        super().__init__()
        self.kernel_size = kernel_size

We use PyTorch's Conv1d module. We set the number of groups equal to the total number of channels (d_k * heads) so that it does a separate convolution (with a different kernel) for each channel of each head. We add padding to both sides and later crop the rightmost kernel_size - 1 results.

        self.conv = nn.Conv1d(in_channels=d_k * heads, out_channels=d_k * heads,
                              kernel_size=(kernel_size,), padding=(kernel_size - 1,), groups=d_k * heads)

x has shape [seq_len, batch_size, heads, d_k]

    def forward(self, x: torch.Tensor):

Get the shape

        seq_len, batch_size, heads, d_k = x.shape

Permute to [batch_size, heads, d_k, seq_len]

        x = x.permute(1, 2, 3, 0)

Change the shape to [batch_size, heads * d_k, seq_len]

        x = x.view(batch_size, heads * d_k, seq_len)

1D convolution accepts input of the form [N, channels, sequence]

        x = self.conv(x)

Crop the rightmost kernel_size - 1 results since we padded both sides

        x = x[:, :, :-(self.kernel_size - 1)]

Reshape to [batch_size, heads, d_k, seq_len]

        x = x.view(batch_size, heads, d_k, seq_len)

Permute to [seq_len, batch_size, heads, d_k]

        x = x.permute(3, 0, 1, 2)

        return x
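The sketch below (not from the original code) contrasts the two convolution variants: the per-head version keeps a separate kernel and bias for each of the heads * d_k channels, while the shared version has a single kernel, so their parameter counts differ substantially. The sizes are arbitrary.

heads, d_k = 8, 64

shared = SpatialDepthWiseSharedConvolution(kernel_size=3)
per_head = SpatialDepthWisePerHeadConvolution(heads, d_k, kernel_size=3)

# One kernel of size 3 plus one bias
print(sum(p.numel() for p in shared.parameters()))    # 4

# heads * d_k kernels of size 3 plus heads * d_k biases
print(sum(p.numel() for p in per_head.parameters()))  # 8 * 64 * (3 + 1) = 2048

# Both keep the [seq_len, batch_size, heads, d_k] shape
x = torch.randn(10, 2, heads, d_k)
assert per_head(x).shape == x.shape == shared(x).shape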

Multi-per-Head-Depth-wise-Conv-Head Attention

We extend our original implementation of Multi-Head Attention and add the spatial per-head depth-wise convolution to the query, key and value projections.

class MultiDPHConvHeadAttention(MultiHeadAttention):
    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
        super().__init__(heads, d_model, dropout_prob)

Multi-Head Attention will create query, key and value projection modules self.query, self.key, and self.value.

We append a spatial per-head depth-wise convolution layer to each of them and replace self.query, self.key, and self.value.

        self.query = nn.Sequential(self.query, SpatialDepthWisePerHeadConvolution(heads, self.d_k))
        self.key = nn.Sequential(self.key, SpatialDepthWisePerHeadConvolution(heads, self.d_k))
        self.value = nn.Sequential(self.value, SpatialDepthWisePerHeadConvolution(heads, self.d_k))
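Both attention variants are intended as drop-in replacements for Multi-Head Attention. The usage sketch below is a minimal example, assuming labml_nn's MultiHeadAttention keyword-argument interface (query, key, value) and inputs of shape [seq_len, batch_size, d_model]; the sizes are arbitrary.

x = torch.randn(10, 2, 512)  # [seq_len, batch_size, d_model]

# Self-attention with either variant; the assumed interface is that of
# labml_nn's MultiHeadAttention, which these classes extend
for attn in (MultiDSharedConvHeadAttention(heads=8, d_model=512),
             MultiDPHConvHeadAttention(heads=8, d_model=512)):
    out = attn(query=x, key=x, value=x)
    assert out.shape == (10, 2, 512)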