import math

import torch
from torch import nn

from labml_nn.transformers import MultiHeadAttention

Spatial Depth-Wise Convolution

This variation implements the convolution directly instead of using PyTorch's Conv1d module. It is actually slower.

class SpatialDepthWiseConvolution(nn.Module):
  • d_k is the number of channels in each head
    def __init__(self, d_k: int, kernel_size: int = 3):
        super().__init__()
        self.kernel_size = kernel_size

Instead of PyTorch's Conv1d module, this version keeps one kernel per channel as a plain [kernel_size, d_k] parameter and applies the depth-wise convolution manually in forward, so that each channel is convolved with its own kernel. The uniform initialization in [-1/sqrt(kernel_size), 1/sqrt(kernel_size)] matches PyTorch's default convolution weight initialization.

        # One kernel of length kernel_size per channel, initialized the same
        # way as PyTorch's default convolution weights
        rng = 1 / math.sqrt(kernel_size)
        self.kernels = nn.Parameter(torch.zeros((kernel_size, d_k)).uniform_(-rng, rng))

x has shape [seq_len, batch_size, heads, d_k]

    def forward(self, x: torch.Tensor):
        # Multiply the current position by the first kernel weight
        res = x * self.kernels[0].view(1, 1, 1, -1)

        # Kernel weight i is applied to the input shifted i steps back, so each
        # position only sees the current and previous inputs (a causal convolution)
        for i in range(1, len(self.kernels)):
            res[i:] += x[:-i] * self.kernels[i].view(1, 1, 1, -1)

        return res
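
As a sanity check, here is a minimal sketch (not part of the original implementation) showing that the shifted multiply-accumulate above is equivalent to a depth-wise nn.Conv1d with the number of groups equal to the number of channels, padding on both sides, and the right-most kernel_size - 1 results cropped. Since Conv1d computes a cross-correlation, its weights must be the kernels flipped along the kernel dimension.

seq_len, batch_size, heads, d_k, kernel_size = 10, 2, 4, 8, 3
conv_module = SpatialDepthWiseConvolution(d_k, kernel_size)

# Equivalent Conv1d: one group per channel, padded on both sides; copy in the
# kernels flipped along the kernel dimension to turn correlation into convolution
conv1d = nn.Conv1d(d_k, d_k, kernel_size, padding=kernel_size - 1, groups=d_k, bias=False)
with torch.no_grad():
    conv1d.weight.copy_(conv_module.kernels.flip(0).permute(1, 0).unsqueeze(1))

x = torch.randn(seq_len, batch_size, heads, d_k)

# [seq_len, batch_size, heads, d_k] -> [batch_size * heads, d_k, seq_len]
x_c = x.permute(1, 2, 3, 0).reshape(batch_size * heads, d_k, seq_len)
# Crop the right-most kernel_size - 1 results so the convolution stays causal
y = conv1d(x_c)[:, :, :seq_len]
y = y.reshape(batch_size, heads, d_k, seq_len).permute(3, 0, 1, 2)

assert torch.allclose(conv_module(x), y, atol=1e-6)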

Multi-DConv-Head Attention (MDHA)

We extend our original implementation of Multi-Head Attention and add the spatial depth-wise convolution to the query, key and value projections.

class MultiDConvHeadAttention(MultiHeadAttention):
    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
        super().__init__(heads, d_model, dropout_prob)

Multi-Head Attention creates the query, key and value projection modules self.query, self.key, and self.value.

We append a spatial depth-wise convolution layer to each of them and replace self.query, self.key, and self.value.

        self.query = nn.Sequential(self.query, SpatialDepthWiseConvolution(self.d_k))
        self.key = nn.Sequential(self.key, SpatialDepthWiseConvolution(self.d_k))
        self.value = nn.Sequential(self.value, SpatialDepthWiseConvolution(self.d_k))
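
A quick usage sketch, assuming labml_nn's MultiHeadAttention takes query, key and value tensors of shape [seq_len, batch_size, d_model] as keyword arguments (the sizes here are illustrative):

mha = MultiDConvHeadAttention(heads=8, d_model=512)
x = torch.randn(100, 4, 512)  # [seq_len, batch_size, d_model]
out = mha(query=x, key=x, value=x)
assert out.shape == (100, 4, 512)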