Primer EZ Variations

We try a few variations to see which of the changes in Primer EZ gives the most benefit.

import torch
from torch import nn

from labml_helpers.module import Module
from labml_nn.transformers import MultiHeadAttention

Spatial Depth-Wise Shared Convolution

We share the same kernel across all channels; that is, a single convolution kernel is applied along the sequence dimension for every channel of every head.

class SpatialDepthWiseSharedConvolution(Module):
    def __init__(self, kernel_size: int = 3):
        super().__init__()
        self.kernel_size = kernel_size

We use PyTorch's Conv1d module. We add padding of kernel_size - 1 to both sides and later crop the rightmost kernel_size - 1 results, so that the output at each position depends only on the current and earlier positions (i.e. the convolution is causal).

        self.conv = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=(kernel_size,), padding=(kernel_size - 1,))

x has shape [seq_len, batch_size, heads, d_k]

    def forward(self, x: torch.Tensor):

Get the shape

        seq_len, batch_size, heads, d_k = x.shape

Permute to [batch_size, heads, d_k, seq_len]

        x = x.permute(1, 2, 3, 0)

Change the shape to [batch_size * heads * d_k, 1, seq_len], treating each channel of each head as a separate sequence with a single input channel

        x = x.view(batch_size * heads * d_k, 1, seq_len)

1D convolution accepts input of the form [N, channels, sequence]

        x = self.conv(x)

Crop the rightmost kernel_size - 1 results since we padded both sides.

        x = x[:, :, :-(self.kernel_size - 1)]

Reshape to [batch_size, heads, d_k, seq_len]

        x = x.view(batch_size, heads, d_k, seq_len)

Permute to [seq_len, batch_size, heads, d_k]

        x = x.permute(3, 0, 1, 2)

        return x
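
As a quick sanity check, here is a hypothetical usage sketch (not part of the original implementation). It verifies that the layer preserves the input shape and that the convolution is causal: zeroing out positions after t does not change the outputs up to t.

conv = SpatialDepthWiseSharedConvolution(kernel_size=3)

# A random input of shape [seq_len, batch_size, heads, d_k]
x = torch.randn(10, 2, 4, 16)
y = conv(x)
assert y.shape == x.shape

# Causality check: zeroing the "future" leaves earlier outputs unchanged
x2 = x.clone()
x2[5:] = 0.
y2 = conv(x2)
assert torch.allclose(y[:5], y2[:5])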

Multi-Depth-wise-Shared-Conv-Head Attention

We extend our original implementation of Multi-Head Attention and add the spatial depth-wise shared convolution to the query, key and value projections.

class MultiDSharedConvHeadAttention(MultiHeadAttention):
    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
        super().__init__(heads, d_model, dropout_prob)

Multi-Head Attention will create query, key and value projection modules self.query, self.key, and self.value.

We append a spatial depth-wise shared convolution layer to each of them and replace self.query, self.key, and self.value.

        self.query = nn.Sequential(self.query, SpatialDepthWiseSharedConvolution())
        self.key = nn.Sequential(self.key, SpatialDepthWiseSharedConvolution())
        self.value = nn.Sequential(self.value, SpatialDepthWiseSharedConvolution())
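
A minimal, hypothetical usage sketch; it assumes (as in labml's MultiHeadAttention) that forward takes keyword arguments query, key and value of shape [seq_len, batch_size, d_model].

mha = MultiDSharedConvHeadAttention(heads=8, d_model=512)

# Self-attention over a random batch of sequences
x = torch.randn(10, 2, 512)
out = mha(query=x, key=x, value=x)
assert out.shape == (10, 2, 512)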

Spatial Depth-Wise Per-Head Convolution

class SpatialDepthWisePerHeadConvolution(Module):
  • heads is the number of heads
  • d_k is the number of channels in each head
    def __init__(self, heads: int, d_k: int, kernel_size: int = 3):
        super().__init__()
        self.kernel_size = kernel_size

We use PyTorch's Conv1d module. We set the number of groups equal to the total number of channels, d_k * heads, so that it does a separate convolution (with a different kernel) for each channel of each head. We add padding to both sides and later crop the rightmost kernel_size - 1 results.

        self.conv = nn.Conv1d(in_channels=d_k * heads, out_channels=d_k * heads,
                              kernel_size=(kernel_size,), padding=(kernel_size - 1,), groups=d_k * heads)

x has shape [seq_len, batch_size, heads, d_k]

    def forward(self, x: torch.Tensor):

Get the shape

        seq_len, batch_size, heads, d_k = x.shape

Permute to [batch_size, heads, d_k, seq_len]

        x = x.permute(1, 2, 3, 0)

Change the shape to [batch_size, heads * d_k, seq_len]

        x = x.view(batch_size, heads * d_k, seq_len)

1D convolution accepts input of the form [N, channels, sequence]

        x = self.conv(x)

Crop the rightmost kernel_size - 1 results since we padded both sides.

        x = x[:, :, :-(self.kernel_size - 1)]

Reshape to [batch_size, heads, d_k, seq_len]

        x = x.view(batch_size, heads, d_k, seq_len)

Permute to [seq_len, batch_size, heads, d_k]

        x = x.permute(3, 0, 1, 2)

        return x
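
To see how this differs from the shared variant, here is a hypothetical sketch. With groups equal to the number of channels, Conv1d holds a separate kernel for each channel of each head, so the weight contains heads * d_k kernels instead of one.

heads, d_k = 4, 16
conv = SpatialDepthWisePerHeadConvolution(heads=heads, d_k=d_k, kernel_size=3)

# One kernel per channel of each head: [heads * d_k, 1, kernel_size]
assert conv.conv.weight.shape == (heads * d_k, 1, 3)

# The input shape is preserved, as in the shared variant
x = torch.randn(10, 2, heads, d_k)
assert conv(x).shape == x.shape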

Multi-per-Head-Depth-wise-Conv-Head Attention

We extend our original implementation of Multi-Head Attention and add the spatial per-head depth-wise convolution to the query, key and value projections.

class MultiDPHConvHeadAttention(MultiHeadAttention):
    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
        super().__init__(heads, d_model, dropout_prob)

Multi-Head Attention will create query, key and value projection modules self.query, self.key, and self.value.

We append a spatial per-head depth-wise convolution layer to each of them and replace self.query, self.key, and self.value.

        self.query = nn.Sequential(self.query, SpatialDepthWisePerHeadConvolution(heads, self.d_k))
        self.key = nn.Sequential(self.key, SpatialDepthWisePerHeadConvolution(heads, self.d_k))
        self.value = nn.Sequential(self.value, SpatialDepthWisePerHeadConvolution(heads, self.d_k))
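
And a matching hypothetical sketch for the per-head variant, under the same keyword-argument assumption as above. Compared to MultiDSharedConvHeadAttention, each projection now carries heads * d_k convolution kernels rather than a single shared one.

mha = MultiDPHConvHeadAttention(heads=8, d_model=512)

x = torch.randn(10, 2, 512)
out = mha(query=x, key=x, value=x)
assert out.shape == (10, 2, 512)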