This is a PyTorch implementation of the paper Primer: Searching for Efficient Transformers for Language Modeling.
The authors run an evolutionary search over transformer architectures and name the architecture found by the search Primer (PRIMitives searched transformER). Primer EZ is the original transformer with only the two most robust of Primer's modifications applied. Primer EZ trains much faster than the vanilla transformer.
The most effective modification found by the search is using squared ReLU, $y = \max(x, 0)^2$, instead of ReLU in the position-wise feed-forward module.
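As a quick standalone illustration (built on `torch.nn.functional.relu`, not taken from the implementation below), squared ReLU keeps ReLU's zeros for negative inputs but squares positive ones, so its gradient is $2\max(x, 0)$ instead of a constant 1 for positive inputs:

```python
import torch
import torch.nn.functional as F


def squared_relu(x: torch.Tensor) -> torch.Tensor:
    # max(x, 0)^2: zero for negative inputs, x^2 for positive inputs
    return F.relu(x) ** 2


x = torch.tensor([-2.0, -0.5, 0.0, 0.5, 2.0], requires_grad=True)
y = squared_relu(x)
print(F.relu(x).tolist())  # [0.0, 0.0, 0.0, 0.5, 2.0]
print(y.tolist())          # [0.0, 0.0, 0.0, 0.25, 4.0]

# The derivative is 2 * max(x, 0): it grows with the input rather than staying at 1
y.sum().backward()
print(x.grad.tolist())     # [0.0, 0.0, 0.0, 1.0, 4.0]
```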
The next most effective modification is a depth-wise convolution after the multi-head projections for queries, keys, and values. The convolution runs along the sequence dimension and separately for each channel (depth-wise). To be clear, if the number of channels in each head is $d_k$, the convolution has a separate kernel for each of the $d_k$ channels.
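The depth-wise property comes from the `groups` argument of `nn.Conv1d`: with `groups` equal to the number of channels, the layer learns one kernel per channel and never mixes channels. A small standalone check (the sizes are arbitrary and chosen only for illustration):

```python
import torch
from torch import nn

torch.manual_seed(0)
d_k, kernel_size, seq_len = 4, 3, 6

# groups=d_k: each of the d_k channels gets its own length-3 kernel
conv = nn.Conv1d(d_k, d_k, kernel_size, groups=d_k)
print(conv.weight.shape)  # torch.Size([4, 1, 3])

x = torch.randn(1, d_k, seq_len)  # [N, channels, sequence]
y = conv(x)

# Perturb a single position of input channel 0 only
x2 = x.clone()
x2[:, 0, 2] += 1.0
y2 = conv(x2)

# Only output channel 0 changes; channels 1..3 are untouched
print(torch.allclose(y2[:, 1:], y[:, 1:]))  # True
print(torch.allclose(y2[:, 0], y[:, 0]))    # False
```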
Here is the experiment code for Primer EZ.
import torch
from torch import nn

from labml_helpers.module import Module
from labml_nn.transformers import MultiHeadAttention
Squared ReLU is used as the activation function in the position-wise feed-forward module.
class SquaredReLU(Module):
    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor):
        # Apply ReLU
        x = self.relu(x)
        # Square it
        return x * x
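A minimal sketch of where this activation sits, using a plain `nn.Sequential` feed-forward block (Linear, squared ReLU, Linear) rather than the library's own feed-forward module. The class is redefined on plain `nn.Module` so the sketch has no `labml_helpers` dependency, and the dimensions are arbitrary assumptions:

```python
import torch
from torch import nn


class SquaredReLU(nn.Module):
    # Same computation as the class above, on plain nn.Module
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.relu(x)
        return x * x


d_model, d_ff = 16, 64
# Hypothetical position-wise feed-forward block, applied independently at every position
ffn = nn.Sequential(
    nn.Linear(d_model, d_ff),
    SquaredReLU(),
    nn.Linear(d_ff, d_model),
)

x = torch.randn(10, 2, d_model)  # [seq_len, batch_size, d_model]
out = ffn(x)
print(out.shape)  # torch.Size([10, 2, 16])
```

Swapping the activation is the only change relative to a vanilla feed-forward block; the two linear layers and their sizes stay the same.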
class SpatialDepthWiseConvolution(Module):
    def __init__(self, d_k: int, kernel_size: int = 3):
        """
        d_k is the number of channels in each head
        """
        super().__init__()
        self.kernel_size = kernel_size
        # We use PyTorch's Conv1d module. We set the number of groups equal to the number
        # of channels so that it does a separate convolution (with a different kernel) for
        # each channel. We pad both sides and later crop the rightmost kernel_size - 1
        # results, so each output position only sees the current and earlier positions.
        self.conv = nn.Conv1d(in_channels=d_k, out_channels=d_k,
                              kernel_size=(kernel_size,), padding=(kernel_size - 1,), groups=d_k)

    def forward(self, x: torch.Tensor):
        """
        x has shape [seq_len, batch_size, heads, d_k]
        """
        # Get the shape
        seq_len, batch_size, heads, d_k = x.shape
        # Permute to [batch_size, heads, d_k, seq_len]
        x = x.permute(1, 2, 3, 0)
        # Change the shape to [batch_size * heads, d_k, seq_len]
        x = x.view(batch_size * heads, d_k, seq_len)
        # 1D convolution accepts input of the form [N, channels, sequence]
        x = self.conv(x)
        # Crop the rightmost kernel_size - 1 results since we padded both sides.
        # Keeping the first seq_len outputs does exactly that and, unlike
        # x[:, :, :-(self.kernel_size - 1)], also works when kernel_size == 1
        x = x[:, :, :seq_len]
        # Reshape to [batch_size, heads, d_k, seq_len]
        x = x.view(batch_size, heads, d_k, seq_len)
        # Permute to [seq_len, batch_size, heads, d_k]
        x = x.permute(3, 0, 1, 2)
        return x
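The padding-and-crop trick should give two properties: the output keeps the shape `[seq_len, batch_size, heads, d_k]`, and each output position depends only on the current and earlier positions. Padding by `kernel_size - 1` on both sides produces `seq_len + kernel_size - 1` outputs, and keeping the first `seq_len` of them is the same as cropping the rightmost `kernel_size - 1`. Below is a standalone copy of the module on plain `nn.Module` (with `reshape` in place of `view`) that checks both properties; the sizes are arbitrary:

```python
import torch
from torch import nn


class SpatialDepthWiseConvolution(nn.Module):
    # Standalone copy of the module above, without the labml dependency
    def __init__(self, d_k: int, kernel_size: int = 3):
        super().__init__()
        self.conv = nn.Conv1d(d_k, d_k, kernel_size, padding=kernel_size - 1, groups=d_k)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        seq_len, batch_size, heads, d_k = x.shape
        x = x.permute(1, 2, 3, 0).reshape(batch_size * heads, d_k, seq_len)
        x = self.conv(x)[:, :, :seq_len]  # length seq_len + kernel_size - 1 -> seq_len
        return x.reshape(batch_size, heads, d_k, seq_len).permute(3, 0, 1, 2)


torch.manual_seed(0)
seq_len, batch_size, heads, d_k = 8, 2, 4, 16
conv = SpatialDepthWiseConvolution(d_k)
x = torch.randn(seq_len, batch_size, heads, d_k)
y = conv(x)
print(y.shape)  # torch.Size([8, 2, 4, 16])

# Causality: changing the input at position t leaves outputs before t unchanged
t = 5
x2 = x.clone()
x2[t] += 1.0
y2 = conv(x2)
print(torch.allclose(y[:t], y2[:t]))  # True
print(torch.allclose(y[t:], y2[t:]))  # False
```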
We extend our original implementation of Multi-Head Attention and add the spatial depth-wise convolution to the query, key, and value projections.
class MultiDConvHeadAttention(MultiHeadAttention):
    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
        super().__init__(heads, d_model, dropout_prob)

        # Multi-Head Attention creates the query, key, and value projection modules
        # self.query, self.key, and self.value.
        #
        # We append a spatial depth-wise convolution layer to each of them and replace
        # self.query, self.key, and self.value.
        #
        # 📝 We feel this cleaner implementation is easier to understand since it clearly
        # shows the difference between this and vanilla transformer multi-head attention.
        self.query = nn.Sequential(self.query, SpatialDepthWiseConvolution(self.d_k))
        self.key = nn.Sequential(self.key, SpatialDepthWiseConvolution(self.d_k))
        self.value = nn.Sequential(self.value, SpatialDepthWiseConvolution(self.d_k))
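Because `MultiHeadAttention` comes from `labml_nn`, here is a dependency-free sketch of the same idea for running it in isolation. Everything in it is an illustrative assumption rather than the library's interface: the hypothetical `HeadProjection` and `DConv` helpers stand in for the library's projection module and the convolution above, dropout is omitted, and the forward pass is plain causal self-attention:

```python
import math

import torch
from torch import nn


class DConv(nn.Module):
    # Hypothetical helper: causal depth-wise convolution over the sequence,
    # input and output shape [seq_len, batch_size, heads, d_k]
    def __init__(self, d_k: int, kernel_size: int = 3):
        super().__init__()
        self.conv = nn.Conv1d(d_k, d_k, kernel_size, padding=kernel_size - 1, groups=d_k)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        seq_len, batch_size, heads, d_k = x.shape
        x = x.permute(1, 2, 3, 0).reshape(batch_size * heads, d_k, seq_len)
        x = self.conv(x)[:, :, :seq_len]  # keep outputs that only see current and past positions
        return x.reshape(batch_size, heads, d_k, seq_len).permute(3, 0, 1, 2)


class HeadProjection(nn.Module):
    # Hypothetical helper: linear projection split into heads,
    # [seq_len, batch_size, d_model] -> [seq_len, batch_size, heads, d_k]
    def __init__(self, d_model: int, heads: int):
        super().__init__()
        self.heads, self.d_k = heads, d_model // heads
        self.linear = nn.Linear(d_model, heads * self.d_k)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x).view(*x.shape[:-1], self.heads, self.d_k)


class MultiDConvHeadAttentionSketch(nn.Module):
    # Causal self-attention where every projection is followed by a depth-wise convolution
    def __init__(self, heads: int, d_model: int):
        super().__init__()
        d_k = d_model // heads
        self.query = nn.Sequential(HeadProjection(d_model, heads), DConv(d_k))
        self.key = nn.Sequential(HeadProjection(d_model, heads), DConv(d_k))
        self.value = nn.Sequential(HeadProjection(d_model, heads), DConv(d_k))
        self.output = nn.Linear(heads * d_k, d_model)
        self.scale = 1 / math.sqrt(d_k)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        seq_len, batch_size, _ = x.shape
        q, k, v = self.query(x), self.key(x), self.value(x)
        # Scores indexed [query position i, key position j, batch, head]
        scores = torch.einsum('ibhd,jbhd->ijbh', q, k) * self.scale
        # Causal mask: query i may only attend to keys j <= i
        mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device).tril()
        scores = scores.masked_fill(~mask[:, :, None, None], float('-inf'))
        attn = scores.softmax(dim=1)  # normalize over key positions
        out = torch.einsum('ijbh,jbhd->ibhd', attn, v)
        return self.output(out.reshape(seq_len, batch_size, -1))


torch.manual_seed(0)
mha = MultiDConvHeadAttentionSketch(heads=4, d_model=32)
x = torch.randn(10, 2, 32)  # [seq_len, batch_size, d_model]
out = mha(x)
print(out.shape)  # torch.Size([10, 2, 32])
```

Wrapping each projection in `nn.Sequential` leaves the attention computation itself unchanged and only alters its inputs, which mirrors how the class above swaps modules after calling the parent constructor.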