import math

import torch
from torch import nn
from labml_nn.transformers import MultiHeadAttention


class SpatialDepthWiseConvolution(nn.Module):

d_k is the number of channels in each head.

    def __init__(self, d_k: int, kernel_size: int = 3):
        super().__init__()
        self.kernel_size = kernel_size

This is equivalent to PyTorch's Conv1d module with the number of groups set equal to the number of channels, so that a separate convolution (with a different kernel) is done for each channel, with padding added to both sides and the right-most kernel_size - 1 results cropped to keep the convolution causal. Here we instead keep the kernels as an explicit parameter and apply the convolution with shifted element-wise multiplications.
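Concretely, for a kernel of size K, the output at position t is res[t] = kernels[0] * x[t] + kernels[1] * x[t-1] + ... + kernels[K-1] * x[t-K+1], where each product is element-wise over the d_k channels and terms with a negative time index are dropped.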
        rng = 1 / math.sqrt(kernel_size)
        self.kernels = nn.Parameter(torch.zeros((kernel_size, d_k)).uniform_(-rng, rng))

x has shape [seq_len, batch_size, heads, d_k].
    def forward(self, x: torch.Tensor):
        # Multiply the whole sequence by the first kernel weight
        res = x * self.kernels[0].view(1, 1, 1, -1)

        # Shift the sequence to the right by i steps for the i-th kernel
        # weight and accumulate, so that position t only sees positions
        # t - kernel_size + 1 ... t (a causal convolution)
        for i in range(1, len(self.kernels)):
            res[i:] += x[:-i] * self.kernels[i].view(1, 1, 1, -1)
        return res
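To see that this loop really computes a causal depth-wise convolution, here is a small sanity check. It is a sketch, not part of the original implementation: it compares the module against torch.nn.functional.conv1d with groups equal to the number of channels; the flip accounts for Conv1d computing a cross-correlation rather than a convolution, and the sizes are illustrative.

import torch.nn.functional as F

seq_len, batch_size, heads, d_k, kernel_size = 10, 2, 4, 8, 3
conv = SpatialDepthWiseConvolution(d_k, kernel_size=kernel_size)
x = torch.randn(seq_len, batch_size, heads, d_k)

with torch.no_grad():
    # Reverse the kernels along the time axis and reshape to
    # [d_k, 1, kernel_size], as expected by a grouped Conv1d
    weight = conv.kernels.flip(0).t().unsqueeze(1)
    # Fold batch and heads together and move channels before time
    x_flat = x.permute(1, 2, 3, 0).reshape(batch_size * heads, d_k, seq_len)
    # Pad on the left only, so the convolution is causal
    x_pad = F.pad(x_flat, (kernel_size - 1, 0))
    ref = F.conv1d(x_pad, weight, groups=d_k)
    ref = ref.reshape(batch_size, heads, d_k, seq_len).permute(3, 0, 1, 2)

    assert torch.allclose(conv(x), ref, atol=1e-6)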
We extend our original implementation of multi-head attention and add the spatial depth-wise convolution to the query, key, and value projections.

class MultiDConvHeadAttention(MultiHeadAttention):

    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
        super().__init__(heads, d_model, dropout_prob)

Multi-head attention creates the query, key, and value projection modules self.query, self.key, and self.value. We append a spatial depth-wise convolution layer to each of them and replace self.query, self.key, and self.value.
        self.query = nn.Sequential(self.query, SpatialDepthWiseConvolution(self.d_k))
        self.key = nn.Sequential(self.key, SpatialDepthWiseConvolution(self.d_k))
        self.value = nn.Sequential(self.value, SpatialDepthWiseConvolution(self.d_k))
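A minimal usage sketch, assuming labml_nn's MultiHeadAttention interface where forward takes keyword-only query, key, and value tensors of shape [seq_len, batch_size, d_model]; the shapes below are illustrative.

seq_len, batch_size, d_model, heads = 16, 4, 512, 8
mha = MultiDConvHeadAttention(heads=heads, d_model=d_model)

x = torch.randn(seq_len, batch_size, d_model)
# Self-attention: the convolution inside each projection mixes each
# position with the previous kernel_size - 1 positions, per channel
out = mha(query=x, key=x, value=x)
assert out.shape == (seq_len, batch_size, d_model)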