Pay Attention to MLPs (gMLP)

This is a PyTorch implementation of the paper Pay Attention to MLPs.

This paper introduces a Multilayer Perceptron (MLP)-based architecture with gating, which the authors name gMLP. It consists of a stack of gMLP blocks.

Here is the training code for an autoregressive model based on gMLP.


from typing import Optional
import torch
from torch import nn

gMLP Block

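Each block applies the following transformations to the input embeddings $X \in \mathbb{R}^{n \times d}$, where $n$ is the sequence length and $d$ is the dimensionality of the embeddings:

$$\begin{aligned} Z &= \sigma(XU) \\ \tilde{Z} &= s(Z) \\ Y &= \tilde{Z}V \end{aligned}$$

where $U$ and $V$ are learnable projection weights and $s(\cdot)$ is the Spatial Gating Unit defined below. The output dimensionality of $s(\cdot)$ is half that of $Z$. $\sigma$ is an activation function such as GeLU.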
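The block's transformations, $Z = \sigma(XU)$, $\tilde{Z} = s(Z)$ and $Y = \tilde{Z}V$, can be traced in a few lines of NumPy. This is a framework-free sketch for one unbatched sequence under several assumptions: LayerNorm without learnable scale and shift, the tanh approximation of GeLU, bias-free projections, and the pre-norm plus shortcut connection that the code below uses. The helper names are hypothetical. It mainly shows how the width changes: $d \to d_{ffn} \to d_{ffn}/2 \to d$.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize over the last (channel) dimension; no learnable scale/shift (assumption)
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def gelu(x):
    # tanh approximation of GeLU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))

def sgu(z, w, b):
    # s(Z) = Z1 * (W norm(Z2) + b); z has shape [n, d_ffn]
    z1, z2 = np.split(z, 2, axis=-1)
    return z1 * (w @ layer_norm(z2) + b[:, None])

def gmlp_block(x, u, v, w, b):
    # x: [n, d], u: [d, d_ffn], v: [d_ffn // 2, d], w: [n, n], b: [n]
    z = gelu(layer_norm(x) @ u)   # Z = sigma(X U)  -> [n, d_ffn]
    z_tilde = sgu(z, w, b)        # Z~ = s(Z)       -> [n, d_ffn // 2]
    y = z_tilde @ v               # Y = Z~ V        -> [n, d]
    return x + y                  # shortcut connection

rng = np.random.default_rng(0)
n, d, d_ffn = 4, 8, 16
x = rng.normal(size=(n, d))
u = rng.normal(size=(d, d_ffn)) * 0.1
v = rng.normal(size=(d_ffn // 2, d)) * 0.1
w = rng.uniform(-0.01, 0.01, size=(n, n))
b = np.ones(n)
print(gmlp_block(x, u, v, w, b).shape)  # (4, 8)
```

Setting $V$ to zero leaves only the shortcut, so the block then returns $X$ unchanged.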

class GMLPBlock(nn.Module):
  • d_model is the dimensionality ($d$) of $X$
  • d_ffn is the dimensionality of $Z$
  • seq_len is the length of the token sequence ($n$)
    def __init__(self, d_model: int, d_ffn: int, seq_len: int):
        super().__init__()

Normalization layer for Pre-Norm

        self.norm = nn.LayerNorm([d_model])

Activation function

        self.activation = nn.GELU()

Projection layer for $Z = \sigma(XU)$

        self.proj1 = nn.Linear(d_model, d_ffn)

Spatial Gating Unit

        self.sgu = SpacialGatingUnit(d_ffn, seq_len)

Projection layer for $Y = \tilde{Z}V$

        self.proj2 = nn.Linear(d_ffn // 2, d_model)

Embedding size (required by Encoder). We use the encoder module from the transformer architecture and plug in the gMLP block as a replacement for the Transformer Layer.

        self.size = d_model
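Counting the trainable parameters makes the role of each constructor argument concrete. The sketch below assumes PyTorch's defaults (`nn.Linear` has a bias; `nn.LayerNorm` has a learnable weight and bias of its normalized width) and includes the Spatial Gating Unit's own LayerNorm, $W$ and $b$; the function name is hypothetical. Only $W$ and $b$ depend on `seq_len`, growing quadratically and linearly with it.

```python
def gmlp_block_param_count(d_model: int, d_ffn: int, seq_len: int) -> int:
    """Trainable parameters in one gMLP block, assuming PyTorch default layer settings."""
    norm = 2 * d_model                        # LayerNorm([d_model]): weight + bias
    proj1 = d_model * d_ffn + d_ffn           # Linear(d_model, d_ffn)
    sgu_norm = 2 * (d_ffn // 2)               # LayerNorm([d_ffn // 2]) inside the gating unit
    sgu_weight = seq_len * seq_len            # W, mixes tokens along the sequence
    sgu_bias = seq_len                        # b, one value per position
    proj2 = (d_ffn // 2) * d_model + d_model  # Linear(d_ffn // 2, d_model)
    return norm + proj1 + sgu_norm + sgu_weight + sgu_bias + proj2

print(gmlp_block_param_count(d_model=512, d_ffn=2048, seq_len=256))  # 1644288
```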
  • x is the input embedding tensor $X$ of shape [seq_len, batch_size, d_model]
  • mask is a boolean mask of shape [seq_len, seq_len, 1] that controls the visibility of tokens among each other.
    def forward(self, *, x: torch.Tensor, mask: Optional[torch.Tensor] = None):

Keep a copy for the shortcut connection

        shortcut = x

Normalize $X$ (Pre-Norm)

        x = self.norm(x)

Projection and activation $Z = \sigma(XU)$

        z = self.activation(self.proj1(x))

Spatial Gating Unit $\tilde{Z} = s(Z)$

        z = self.sgu(z, mask)

Final projection $Y = \tilde{Z}V$

        z = self.proj2(z)

Add the shortcut connection

        return z + shortcut

Spatial Gating Unit

$$s(Z) = Z_1 \odot f_{W,b}(Z_2)$$

where $f_{W,b}(Z) = WZ + b$ is a linear transformation along the sequence dimension, and $\odot$ is element-wise multiplication. $Z$ is split into two parts of equal size, $Z_1$ and $Z_2$, along the channel (embedding) dimension.
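A hand-sized example of $s(Z) = Z_1 \odot (WZ_2 + b)$ shows what "along the sequence dimension" means: $W$ combines rows (tokens) of $Z_2$, while $Z_1$ only gates element-wise. The numbers are chosen by hand, and the normalization of $Z_2$ that the code applies is omitted.

```python
import numpy as np

# Two tokens (rows) and two channels (columns) in each half of Z
z1 = np.array([[2.0, 2.0],
               [1.0, 1.0]])
z2 = np.array([[1.0, 2.0],
               [3.0, 4.0]])

# W mixes tokens: token 0 keeps itself, token 1 averages tokens 0 and 1
w = np.array([[1.0, 0.0],
              [0.5, 0.5]])
b = np.array([0.0, 10.0])   # one bias value per token position

f = w @ z2 + b[:, None]     # f_{W,b}(Z_2) = [[1, 2], [12, 13]]
s = z1 * f                  # element-wise gate
print(s)                    # values: [[2, 4], [12, 13]]
```

Because this $W$ is lower-triangular, the first token's output does not depend on the second token.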

class SpacialGatingUnit(nn.Module):
  • d_z is the dimensionality of $Z$
  • seq_len is the sequence length
    def __init__(self, d_z: int, seq_len: int):
        super().__init__()

Normalization layer before applying $f_{W,b}(\cdot)$

        self.norm = nn.LayerNorm([d_z // 2])

Weight $W$ in $f_{W,b}(\cdot)$.

The paper notes that it is important to initialize the weights to small values and the bias to $1$, so that at the start of training $s(\cdot)$ is close to the identity (apart from the split).

        self.weight = nn.Parameter(torch.zeros(seq_len, seq_len).uniform_(-0.01, 0.01), requires_grad=True)

Bias $b$ in $f_{W,b}(\cdot)$

The paper notes that it is important to initialize the bias to $1$.

        self.bias = nn.Parameter(torch.ones(seq_len), requires_grad=True)
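To see why this initialization matters, the NumPy sketch below compares the gate $f_{W,b}(Z_2)$ under the code's initialization ($W \sim U(-0.01, 0.01)$, $b = 1$) with a naive unit-variance $W$. The sizes, the seed and the random stand-in for the normalized $Z_2$ are arbitrary assumptions. When the gate stays near $1$, $s(Z) = Z_1 \odot f_{W,b}(Z_2)$ stays near $Z_1$.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_half = 64, 32
z2 = rng.normal(size=(n, d_half))   # stands in for the normalized Z_2
b = np.ones(n)                      # bias initialized to 1

def gate_deviation(w):
    # Mean of |f_{W,b}(Z_2) - 1|: how far the gate is from a pass-through
    gate = w @ z2 + b[:, None]
    return np.abs(gate - 1.0).mean()

w_small = rng.uniform(-0.01, 0.01, size=(n, n))  # the initialization used in the code
w_large = rng.normal(size=(n, n))                # a naive unit-variance initialization

print(gate_deviation(w_small))  # small: s(Z) starts close to Z_1
print(gate_deviation(w_large))  # large: the gate scrambles Z_1 at the start
```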
  • z is the input $Z$ of shape [seq_len, batch_size, d_z]
  • mask is a boolean mask of shape [seq_len, seq_len, 1] that controls the visibility of tokens among each other. The last dimension, of size 1, is the batch dimension; the other transformer implementations use it, and it is kept here for compatibility.
    def forward(self, z: torch.Tensor, mask: Optional[torch.Tensor] = None):

Get sequence length

        seq_len = z.shape[0]

Split $Z$ into $Z_1$ and $Z_2$

        z1, z2 = torch.chunk(z, 2, dim=-1)

Check mask

        if mask is not None:

mask has shape [seq_len_q, seq_len_k, batch_size]. The first dimension may be 1, in which case the mask is broadcast over all query positions. The batch dimension must be of size 1, because this implementation supports only the same mask for all samples in the batch.

            assert mask.shape[0] == 1 or mask.shape[0] == seq_len
            assert mask.shape[1] == seq_len

Here we only support the same mask for all samples

            assert mask.shape[2] == 1

Remove the batch dimension

            mask = mask[:, :, 0]
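The shape checks accept either a full [seq_len, seq_len, 1] mask or one whose first dimension is 1. In the latter case the single row broadcasts over every query position when it multiplies the weight matrix, which suits, for example, hiding the same key positions from every token. The NumPy sketch below assumes that use case; the mask values are made up for illustration.

```python
import numpy as np

seq_len = 4
weight = np.arange(1.0, seq_len * seq_len + 1).reshape(seq_len, seq_len)

# A [1, seq_len, 1] mask: hide key position 3 from every query position
mask = np.array([True, True, True, False]).reshape(1, seq_len, 1)

# Mirror the checks in the code, then drop the batch dimension
assert mask.shape[0] in (1, seq_len)
assert mask.shape[1] == seq_len
assert mask.shape[2] == 1
mask = mask[:, :, 0]           # shape [1, seq_len]

masked = weight * mask         # broadcasts to [seq_len, seq_len]
print(masked[:, 3])            # column 3 is zeroed for every row
```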

Normalize $Z_2$ before $f_{W,b}(\cdot)$

        z2 = self.norm(z2)

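Get the weight matrix $W$, sliced to the current sequence length (the parameter is allocated for the seq_len given to the constructor, which may be longer than the input)

        weight = self.weight[:seq_len, :seq_len]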
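Because $W$ and $b$ are allocated for the constructor's seq_len, a shorter input uses only their top-left block. The NumPy sketch below, with arbitrary sizes and a hypothetical `spatial_mix` helper, also shows a caveat: slicing past the end does not raise, it simply returns the whole allocated block, so an input longer than the allocated length only fails later, in the matrix product. PyTorch slicing clamps the same way, so we would expect an over-long input to fail at the `einsum` rather than at the slice.

```python
import numpy as np

max_len, d = 8, 3
rng = np.random.default_rng(1)
weight_full = rng.uniform(-0.01, 0.01, size=(max_len, max_len))  # allocated for max_len
bias_full = np.ones(max_len)

def spatial_mix(z2):
    # Slice W and b to the input's length, as the code does, then apply W Z_2 + b
    seq_len = z2.shape[0]
    weight = weight_full[:seq_len, :seq_len]
    return weight @ z2 + bias_full[:seq_len, None]

print(spatial_mix(rng.normal(size=(5, d))).shape)  # (5, 3): a shorter input works

try:
    spatial_mix(rng.normal(size=(10, d)))          # longer than max_len
except ValueError as err:
    # weight_full[:10, :10] is still 8 x 8, so the mismatch surfaces in the product
    print("shape mismatch:", err)
```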

Apply the mask to the weights.

If $W_{i,j}$ is $0$, then $f_{W,b}(Z_2)_i$ will not get any information from token $j$.

        if mask is not None:
            weight = weight * mask
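For the autoregressive model, the natural mask is causal: token $i$ may only see tokens $j \le i$. The NumPy sketch below builds such a mask in the code's [seq_len, seq_len, 1] layout and checks that perturbing the last token leaves every earlier output untouched. The constant weights are an assumption chosen to make the check deterministic; `spatial_mix` is a hypothetical helper.

```python
import numpy as np

seq_len, d = 5, 2

# Causal mask in the [seq_len, seq_len, 1] layout: token i may see tokens j <= i
mask = np.tril(np.ones((seq_len, seq_len), dtype=bool))[:, :, None]

# Constant weights for a deterministic check; masking zeroes W[i, j] for j > i
weight = np.full((seq_len, seq_len), 0.01) * mask[:, :, 0]
bias = np.ones(seq_len)

def spatial_mix(z2):
    # f_{W,b}(Z_2) = W Z_2 + b for one sequence of shape [seq_len, d]
    return weight @ z2 + bias[:, None]

z2 = np.random.default_rng(2).normal(size=(seq_len, d))
z2_changed = z2.copy()
z2_changed[-1] += 100.0  # perturb only the last token

before, after = spatial_mix(z2), spatial_mix(z2_changed)
print(np.allclose(before[:-1], after[:-1]))  # True: earlier positions never read the last token
print(np.allclose(before[-1], after[-1]))    # False: the last position reads its own token
```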

$f_{W,b}(Z_2) = WZ_2 + b$

        z2 = torch.einsum('ij,jbd->ibd', weight, z2) + self.bias[:seq_len, None, None]
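The `einsum` subscripts `'ij,jbd->ibd'` compute $WZ_2$ for every batch element at once, in the sequence-first layout [seq_len, batch_size, channels]. The NumPy sketch below (random data, arbitrary sizes) checks that it matches an explicit per-sample matrix product; `np.einsum` accepts the same subscript string.

```python
import numpy as np

seq_len, batch_size, d = 4, 3, 2
rng = np.random.default_rng(3)
weight = rng.normal(size=(seq_len, seq_len))
bias = rng.normal(size=seq_len)
z2 = rng.normal(size=(seq_len, batch_size, d))  # [seq_len, batch_size, d_z // 2]

# 'ij,jbd->ibd': out[i, b, c] = sum_j weight[i, j] * z2[j, b, c]
out = np.einsum('ij,jbd->ibd', weight, z2) + bias[:, None, None]

# Same result one sample at a time: W @ Z2 for each batch element
loop = np.stack([weight @ z2[:, k, :] for k in range(batch_size)], axis=1) + bias[:, None, None]
print(np.allclose(out, loop))  # True
```

The `bias[:, None, None]` term adds one value per output position, shared across the batch and channel dimensions.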

$Z_1 \odot f_{W,b}(Z_2)$

        return z1 * z2