We tried some variations to see which of the changes in Primer EZ gives the most benefit.

```
import torch
from torch import nn

from labml_helpers.module import Module
from labml_nn.transformers import MultiHeadAttention
```

This applies a depth-wise convolution along the sequence dimension, with a single kernel shared across all heads and channels.

```
class SpatialDepthWiseSharedConvolution(Module):
    def __init__(self, kernel_size: int = 3):
        super().__init__()
        self.kernel_size = kernel_size
```

We use PyTorch's `Conv1d` module. We add padding to both sides and later crop the rightmost `kernel_size - 1` results.

```
        self.conv = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=(kernel_size,), padding=(kernel_size - 1,))
```
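As a sanity check, here is a small hypothetical snippet (not part of the original code, reusing the imports above) showing that padding by `kernel_size - 1` on both sides and cropping the rightmost `kernel_size - 1` outputs keeps the sequence length unchanged and makes the convolution causal:

```
kernel_size = 3
conv = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=(kernel_size,), padding=(kernel_size - 1,))

x = torch.randn(1, 1, 10)         # [N, channels, sequence]
y = conv(x)                       # padded output has length 10 + (kernel_size - 1) = 12
y = y[:, :, :-(kernel_size - 1)]  # crop back to the original length of 10
assert y.shape == x.shape         # output at position t only sees inputs up to t
```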

`x` has shape `[seq_len, batch_size, heads, d_k]`.

```
    def forward(self, x: torch.Tensor):
```

Get the shape.

```
        seq_len, batch_size, heads, d_k = x.shape
```

Permute to `[batch_size, heads, d_k, seq_len]`.

```
        x = x.permute(1, 2, 3, 0)
```

Change the shape to `[batch_size * heads * d_k, 1, seq_len]`, treating every channel of every head as a separate single-channel sequence.

```
        x = x.view(batch_size * heads * d_k, 1, seq_len)
```

1D convolution accepts input of the form `[N, channels, sequence]`.

```
        x = self.conv(x)
```

Crop the rightmost `kernel_size - 1` results since we padded both sides.

```
        x = x[:, :, :-(self.kernel_size - 1)]
```

Reshape to `[batch_size, heads, d_k, seq_len]`.

```
        x = x.view(batch_size, heads, d_k, seq_len)
```

Permute back to `[seq_len, batch_size, heads, d_k]`.

```
        x = x.permute(3, 0, 1, 2)

        return x
```
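A minimal shape check, with made-up dimensions, confirms that the module maps a tensor back to its original shape:

```
conv = SpatialDepthWiseSharedConvolution(kernel_size=3)
x = torch.randn(16, 2, 4, 8)  # [seq_len=16, batch_size=2, heads=4, d_k=8]
assert conv(x).shape == (16, 2, 4, 8)
```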

We extend our original implementation of Multi-Head Attention and add the spatial depth-wise shared convolution to query, key and value projections.

```
class MultiDSharedConvHeadAttention(MultiHeadAttention):
    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
        super().__init__(heads, d_model, dropout_prob)
```

Multi-Head Attention will create query, key and value projection modules `self.query`, `self.key`, and `self.value`.

We combine a spatial depth-wise shared convolution layer with each of them and replace `self.query`, `self.key`, and `self.value`.

```
        self.query = nn.Sequential(self.query, SpatialDepthWiseSharedConvolution())
        self.key = nn.Sequential(self.key, SpatialDepthWiseSharedConvolution())
        self.value = nn.Sequential(self.value, SpatialDepthWiseSharedConvolution())
```
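A hypothetical usage sketch, assuming the keyword-only `query`/`key`/`value` interface of `MultiHeadAttention.forward` in `labml_nn`:

```
mha = MultiDSharedConvHeadAttention(heads=8, d_model=512)
x = torch.randn(100, 32, 512)  # [seq_len, batch_size, d_model]
out = mha(query=x, key=x, value=x)
assert out.shape == (100, 32, 512)
```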

This applies a separate depth-wise convolution for each channel of each head.

```
class SpatialDepthWisePerHeadConvolution(Module):
```

* `heads` is the number of heads
* `d_k` is the number of channels in each head

```
    def __init__(self, heads: int, d_k: int, kernel_size: int = 3):
        super().__init__()
        self.kernel_size = kernel_size
```

We use PyTorch's `Conv1d` module. We set the number of groups equal to the total number of channels, `d_k * heads`, so that it does a separate convolution (with a different kernel) for each channel of each head. We add padding to both sides and later crop the rightmost `kernel_size - 1` results.

```
        self.conv = nn.Conv1d(in_channels=d_k * heads, out_channels=d_k * heads,
                              kernel_size=(kernel_size,), padding=(kernel_size - 1,), groups=d_k * heads)
```
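Since `groups` equals the number of input channels, each channel gets its own single-channel kernel. A quick check with made-up dimensions shows the weight tensor has shape `[d_k * heads, 1, kernel_size]`:

```
heads, d_k, kernel_size = 4, 8, 3
conv = nn.Conv1d(in_channels=d_k * heads, out_channels=d_k * heads,
                 kernel_size=(kernel_size,), padding=(kernel_size - 1,), groups=d_k * heads)
assert conv.weight.shape == (d_k * heads, 1, kernel_size)
```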

`x` has shape `[seq_len, batch_size, heads, d_k]`.

```
    def forward(self, x: torch.Tensor):
```

Get the shape.

```
        seq_len, batch_size, heads, d_k = x.shape
```

Permute to `[batch_size, heads, d_k, seq_len]`.

```
        x = x.permute(1, 2, 3, 0)
```

Change the shape to `[batch_size, heads * d_k, seq_len]`.

```
        x = x.view(batch_size, heads * d_k, seq_len)
```

1D convolution accepts input of the form `[N, channels, sequence]`.

```
        x = self.conv(x)
```

Crop the rightmost `kernel_size - 1` results since we padded both sides.

```
        x = x[:, :, :-(self.kernel_size - 1)]
```

Reshape to `[batch_size, heads, d_k, seq_len]`.

```
        x = x.view(batch_size, heads, d_k, seq_len)
```

Permute back to `[seq_len, batch_size, heads, d_k]`.

```
        x = x.permute(3, 0, 1, 2)

        return x
```
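Comparing parameter counts with made-up dimensions highlights the difference between the two variants: the shared variant learns a single kernel for all channels, while the per-head variant learns one kernel per channel of each head:

```
shared = SpatialDepthWiseSharedConvolution(kernel_size=3)
per_head = SpatialDepthWisePerHeadConvolution(heads=4, d_k=8, kernel_size=3)

print(sum(p.numel() for p in shared.parameters()))    # 4: one 3-tap kernel plus one bias
print(sum(p.numel() for p in per_head.parameters()))  # 128: 32 kernels plus 32 biases
```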

We extend our original implementation of Multi-Head Attention and add the spatial depth-wise per-head convolution to query, key and value projections.

```
class MultiDPHConvHeadAttention(MultiHeadAttention):
    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
        super().__init__(heads, d_model, dropout_prob)
```

Multi-Head Attention will create query, key and value projection modules `self.query`, `self.key`, and `self.value`.

We combine a spatial per-head depth-wise convolution layer with each of them and replace `self.query`, `self.key`, and `self.value`.

```
        self.query = nn.Sequential(self.query, SpatialDepthWisePerHeadConvolution(heads, self.d_k))
        self.key = nn.Sequential(self.key, SpatialDepthWisePerHeadConvolution(heads, self.d_k))
        self.value = nn.Sequential(self.value, SpatialDepthWisePerHeadConvolution(heads, self.d_k))
```
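A matching usage sketch for the per-head variant (hypothetical dimensions, same assumed `forward` interface as above):

```
mha = MultiDPHConvHeadAttention(heads=8, d_model=512)
x = torch.randn(100, 32, 512)  # [seq_len, batch_size, d_model]
out = mha(query=x, key=x, value=x)
assert out.shape == (100, 32, 512)
```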