Pay Attention to MLPs (gMLP)

This is a PyTorch implementation of the paper Pay Attention to MLPs.

This paper introduces an MLP-based (Multilayer Perceptron) architecture with gating, which the authors name gMLP. It consists of a stack of $L$ gMLP blocks.

Here is the training code for an autoregressive model based on gMLP.


from typing import Optional
import torch
from torch import nn

gMLP Block

Each block applies the following transformations to the input embeddings $X \in \mathbb{R}^{n \times d}$, where $n$ is the sequence length and $d$ is the dimensionality of the embeddings:

$$Z = \sigma(XU), \qquad \tilde{Z} = s(Z), \qquad Y = \tilde{Z}V$$

where $V$ and $U$ are learnable projection weights, $s(\cdot)$ is the Spatial Gating Unit defined below, and $\sigma$ is an activation function such as GELU. The output dimensionality of $s(\cdot)$ is half that of $Z$.
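The block's core transformations $Z = \sigma(XU)$, $\tilde{Z} = s(Z)$, $Y = \tilde{Z}V$ can be sketched on a single sequence, with no batch dimension, no normalization and no shortcut. Plain tensors stand in for $U$ and $V$. The function first_half is a hypothetical placeholder for $s(\cdot)$ that only drops half of the channels. It shows how the shapes flow and is not the Spatial Gating Unit itself.

```python
import torch
import torch.nn.functional as F


def gmlp_block_core(x: torch.Tensor, u: torch.Tensor, v: torch.Tensor, s) -> torch.Tensor:
    """Z = sigma(XU), Z~ = s(Z), Y = Z~ V for X of shape [n, d] (no batch, norm or shortcut)."""
    z = F.gelu(x @ u)      # [n, d_ffn]
    z_tilde = s(z)         # [n, d_ffn // 2]
    return z_tilde @ v     # [n, d]


def first_half(z: torch.Tensor) -> torch.Tensor:
    # Hypothetical placeholder for s(.): keeps only the first half of the channels
    return z[:, : z.shape[-1] // 2]


n, d, d_ffn = 6, 8, 16
x = torch.randn(n, d)
u = torch.randn(d, d_ffn)
v = torch.randn(d_ffn // 2, d)
y = gmlp_block_core(x, u, v, first_half)
print(y.shape)  # torch.Size([6, 8])
```

Because $s(\cdot)$ halves the channel dimension, $V$ maps from d_ffn / 2 (not d_ffn) back to $d$.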

class GMLPBlock(nn.Module):

  • d_model is the dimensionality ($d$) of $X$
  • d_ffn is the dimensionality of $Z$; it should be even, because $Z$ is split into two halves
  • seq_len is the length of the token sequence ($n$)

    def __init__(self, d_model: int, d_ffn: int, seq_len: int):
        super().__init__()

Normalization layer for Pre-Norm

        self.norm = nn.LayerNorm([d_model])

Activation function $\sigma$

        self.activation = nn.GELU()

Projection layer for $Z = \sigma(XU)$

        self.proj1 = nn.Linear(d_model, d_ffn)

Spatial Gating Unit $s(\cdot)$

        self.sgu = SpacialGatingUnit(d_ffn, seq_len)

Projection layer for $Y = \tilde{Z}V$

        self.proj2 = nn.Linear(d_ffn // 2, d_model)

Embedding size (required by the Encoder). We use the encoder module from the transformer architecture and plug in the gMLP block as a replacement for the transformer layer.

        self.size = d_model
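The following sketch counts the parameters created by this constructor together with the SpacialGatingUnit it builds (defined below), which makes the layer sizes concrete. Only the token-mixing term ($W$ and $b$) depends on seq_len, and it grows quadratically with it. The helper names are hypothetical. The values 512 / 3072 / 256 are arbitrary illustration, not settings taken from the paper.

```python
from torch import nn


def gmlp_block_param_count(d_model: int, d_ffn: int, seq_len: int) -> int:
    """Closed-form parameter count for one gMLP block as laid out in this implementation."""
    norm = 2 * d_model                          # pre-norm LayerNorm (gain and bias)
    proj1 = d_model * d_ffn + d_ffn             # U and its bias
    sgu_norm = 2 * (d_ffn // 2)                 # LayerNorm applied to Z_2
    sgu_mix = seq_len * seq_len + seq_len       # W and b in f_{W,b}
    proj2 = (d_ffn // 2) * d_model + d_model    # V and its bias
    return norm + proj1 + sgu_norm + sgu_mix + proj2


def count_from_layers(d_model: int, d_ffn: int, seq_len: int) -> int:
    """Cross-check: build the same dense PyTorch layers and count their parameters."""
    layers = [
        nn.LayerNorm([d_model]),
        nn.Linear(d_model, d_ffn),
        nn.LayerNorm([d_ffn // 2]),
        nn.Linear(d_ffn // 2, d_model),
    ]
    dense = sum(p.numel() for layer in layers for p in layer.parameters())
    return dense + seq_len * seq_len + seq_len  # W is [seq_len, seq_len], b is [seq_len]


print(gmlp_block_param_count(512, 3072, 256))  # 2432768
print(count_from_layers(512, 3072, 256))       # 2432768
```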
  • x is the input embedding tensor $X$ of shape [seq_len, batch_size, d_model]
  • mask is a boolean mask of shape [seq_len, seq_len, 1] that controls the visibility of tokens among each other.
    def forward(self, *, x: torch.Tensor, mask: Optional[torch.Tensor] = None):

Keep a copy for shortcut connection

        shortcut = x

Normalize $X$

        x = self.norm(x)

Projection and activation $Z = \sigma(XU)$

        z = self.activation(self.proj1(x))

Spatial Gating Unit $\tilde{Z} = s(Z)$

        z = self.sgu(z, mask)

Final projection $Y = \tilde{Z}V$

        z = self.proj2(z)

Add the shortcut connection

        return z + shortcut
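The forward pass follows the pre-norm residual pattern: the output is the input plus a sub-network applied to LayerNorm of the input. One consequence is that if the final projection outputs zeros, the whole block reduces to the identity. The sketch below checks this. The PreNormResidual wrapper and the nn.Sequential body are hypothetical stand-ins, and the body deliberately omits the gating unit.

```python
import torch
from torch import nn


class PreNormResidual(nn.Module):
    """Hypothetical wrapper: returns body(LayerNorm(x)) + x, mirroring GMLPBlock.forward."""

    def __init__(self, d_model: int, body: nn.Module):
        super().__init__()
        self.norm = nn.LayerNorm([d_model])
        self.body = body

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = x
        return self.body(self.norm(x)) + shortcut


d_model, d_ffn = 8, 16
# Stand-in body: projection, activation, output projection (no gating unit)
body = nn.Sequential(nn.Linear(d_model, d_ffn), nn.GELU(), nn.Linear(d_ffn, d_model))
block = PreNormResidual(d_model, body)

# Zero the final projection so the body contributes nothing
with torch.no_grad():
    body[2].weight.zero_()
    body[2].bias.zero_()

x = torch.randn(5, 2, d_model)  # [seq_len, batch_size, d_model]
print(torch.equal(block(x), x))  # True
```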

Spatial Gating Unit

$$s(Z) = Z_1 \odot f_{W,b}(Z_2)$$

where $f_{W,b}(Z) = W Z + b$ is a linear transformation along the sequence dimension, and $\odot$ is element-wise multiplication. $Z$ is split into two parts of equal size, $Z_1$ and $Z_2$, along the channel (embedding) dimension.
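The gating $s(Z) = Z_1 \odot f_{W,b}(Z_2)$ translates almost directly into tensor code. This functional sketch works on a single sequence $Z$ of shape [n, d_z] with no batch dimension. It omits the layer normalization that the implementation below applies to $Z_2$. The name spatial_gating is hypothetical.

```python
import torch


def spatial_gating(z: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    """s(Z) = Z1 * (W Z2 + b) for Z of shape [n, d_z]; W is [n, n], b is [n]."""
    z1, z2 = torch.chunk(z, 2, dim=-1)   # split along the channel dimension
    f_z2 = weight @ z2 + bias[:, None]   # mixes tokens: row i is sum_j W[i, j] * Z2[j] + b[i]
    return z1 * f_z2                     # element-wise gating


n, d_z = 4, 6
z = torch.randn(n, d_z)
w = torch.randn(n, n)
b = torch.randn(n)
out = spatial_gating(z, w, b)
print(out.shape)  # torch.Size([4, 3])
```

The mixing matrix $W$ acts across tokens, while the gate is applied per channel. The layer therefore mixes information between positions, which a standard channel MLP cannot do.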

class SpacialGatingUnit(nn.Module):
  • d_z is the dimensionality of $Z$
  • seq_len is the sequence length
    def __init__(self, d_z: int, seq_len: int):
        super().__init__()

Normalization layer before applying $f_{W,b}(\cdot)$

        self.norm = nn.LayerNorm([d_z // 2])

Weight $W$ in $f_{W,b}(\cdot)$.

The paper notes that it’s important to initialize the weights to small values and the bias to $1$, so that early in training $s(\cdot)$ is close to the identity (apart from the split).

        self.weight = nn.Parameter(torch.zeros(seq_len, seq_len).uniform_(-0.01, 0.01), requires_grad=True)

Bias $b$ in $f_{W,b}(\cdot)$

The paper notes that it’s important to initialize the bias to $1$.

        self.bias = nn.Parameter(torch.ones(seq_len), requires_grad=True)
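A quick numerical check shows why this initialization makes $s(\cdot)$ close to the identity on $Z_1$. With $|W_{i,j}| < 0.01$ and a bias of $1$, $f_{W,b}(Z_2)$ stays near $1$, so $Z_1 \odot f_{W,b}(Z_2) \approx Z_1$. This is a standalone sketch that reproduces the same initialization and normalization outside the module. The tolerance 0.25 is a bound chosen for these particular sizes, not a value from the paper.

```python
import torch
from torch import nn

seq_len, d_z = 8, 16

with torch.no_grad():
    # Same initialization as the unit above: small uniform weights, bias of ones
    weight = torch.zeros(seq_len, seq_len).uniform_(-0.01, 0.01)
    bias = torch.ones(seq_len)
    norm = nn.LayerNorm([d_z // 2])

    z = torch.randn(seq_len, 3, d_z)    # [seq_len, batch_size, d_z]
    z1, z2 = torch.chunk(z, 2, dim=-1)
    f_z2 = torch.einsum('ij,jbd->ibd', weight, norm(z2)) + bias[:, None, None]
    gated = z1 * f_z2

# f(Z2) stays close to 1, so the gated output stays close to Z1
print((f_z2 - 1).abs().max().item() < 0.25)  # True
```

The bound holds here because LayerNorm over 8 channels keeps each value within $\sqrt{7}$, so each row of $W Z_2$ is at most $8 \times 0.01 \times \sqrt{7} \approx 0.21$ in magnitude.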
  • z is the input $Z$ of shape [seq_len, batch_size, d_z]
  • mask is a boolean mask of shape [seq_len, seq_len, 1] that controls the visibility of tokens among each other. The last dimension, of size 1, is the batch dimension; other transformer implementations have it, and it was kept here for compatibility.
    def forward(self, z: torch.Tensor, mask: Optional[torch.Tensor] = None):

Get sequence length

        seq_len = z.shape[0]

Split $Z$ into $Z_1$ and $Z_2$

        z1, z2 = torch.chunk(z, 2, dim=-1)

Check mask

        if mask is not None:

mask has shape [seq_len_q, seq_len_k, batch_size]. The batch dimension must be of size 1 because this implementation supports only the same mask for all samples in the batch.

            assert mask.shape[0] == 1 or mask.shape[0] == seq_len
            assert mask.shape[1] == seq_len

Here we only support the same mask for all samples

            assert mask.shape[2] == 1

Remove the batch dimension

            mask = mask[:, :, 0]
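The shape checks can be exercised with a concrete mask. This sketch assumes the convention that mask[i, j, 0] is True when token $i$ may use information from token $j$. That convention matches the weights being multiplied by the mask further down. The key_mask case, with first dimension 1, is a hypothetical illustration (for example, hiding padded tokens from every query) of why mask.shape[0] == 1 is accepted: such a mask broadcasts over all query positions.

```python
import torch

seq_len = 5

# Causal mask: mask[i, j, 0] is True when token i may use information from token j (j <= i)
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)).unsqueeze(-1)
print(causal_mask.shape)  # torch.Size([5, 5, 1])

# The same checks and batch-dimension removal as in forward()
assert causal_mask.shape[0] == 1 or causal_mask.shape[0] == seq_len
assert causal_mask.shape[1] == seq_len
assert causal_mask.shape[2] == 1
mask_2d = causal_mask[:, :, 0]  # [seq_len, seq_len]

# Hypothetical key mask with first dimension 1: hides the last two tokens from every query
key_mask = torch.tensor([[True, True, True, False, False]]).unsqueeze(-1)  # [1, 5, 1]
weight = torch.ones(seq_len, seq_len)
masked = weight * key_mask[:, :, 0]  # broadcasts [1, 5] over all 5 rows
```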

Normalize $Z_2$ before $f_{W,b}(\cdot)$

        z2 = self.norm(z2)

Get the weight matrix; truncate if larger than seq_len

        weight = self.weight[:seq_len, :seq_len]
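The weight is allocated for the maximum seq_len and sliced at run time, so the unit can process shorter sequences. Under a causal mask, this gives a useful property: a shorter input yields the same outputs (up to floating-point rounding) as the matching prefix of a longer one. Truncation leaves $W_{i,j}$ unchanged for $i, j < n$, and masked entries contribute nothing. The sketch below checks this with a hypothetical functional mix that omits the normalization. The normalization acts per token, so it would not change the conclusion.

```python
import torch

torch.manual_seed(0)
max_len, batch_size, d = 8, 2, 4
weight = torch.randn(max_len, max_len)  # parameter sized for the longest sequence
bias = torch.ones(max_len)


def mix(z2: torch.Tensor) -> torch.Tensor:
    """Causal f_{W,b}(Z2) that slices W and b down to the actual sequence length."""
    n = z2.shape[0]
    w = weight[:n, :n] * torch.tril(torch.ones(n, n))  # truncate, then apply causal mask
    return torch.einsum('ij,jbd->ibd', w, z2) + bias[:n, None, None]


long_seq = torch.randn(max_len, batch_size, d)
short_seq = long_seq[:5]  # first 5 tokens only

# Under a causal mask, the short run matches the first 5 positions of the long run
print(torch.allclose(mix(short_seq), mix(long_seq)[:5], atol=1e-5))  # True
```

Without the causal mask this would not hold, because position $i$ of the long run would also see tokens 5 to 7.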

Apply mask to the weights.

If $W_{i,j}$ is $0$ then $f_{W,b}(Z_2)_i$ will not get any information from token $j$.

        if mask is not None:
            weight = weight * mask
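The effect of zeroing $W_{i,j}$ can be checked directly. With a causal mask, changing only the last token leaves $f_{W,b}(Z_2)$ unchanged at every earlier position. This standalone sketch uses a hypothetical token_mixing function and random weights. It skips the LayerNorm, which normalizes each token independently and so does not leak information across positions.

```python
import torch

torch.manual_seed(0)
seq_len, batch_size, d = 6, 2, 4

weight = torch.randn(seq_len, seq_len)
bias = torch.ones(seq_len)
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))


def token_mixing(z2: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """f_{W,b}(Z2) with masked weights; z2 has shape [seq_len, batch_size, d]."""
    w = weight * mask  # zero W[i, j] wherever token i may not see token j
    return torch.einsum('ij,jbd->ibd', w, z2) + bias[:, None, None]


z2 = torch.randn(seq_len, batch_size, d)
perturbed = z2.clone()
perturbed[-1] += 10.0  # change only the last token

out_a = token_mixing(z2, causal)
out_b = token_mixing(perturbed, causal)
# Earlier positions are unaffected; only the last position changes
print(torch.allclose(out_a[:-1], out_b[:-1]))  # True
```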

$f_{W,b}(Z_2) = W Z_2 + b$

        z2 = torch.einsum('ij,jbd->ibd', weight, z2) + self.bias[:seq_len, None, None]
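The einsum 'ij,jbd->ibd' is a matrix multiplication along the sequence dimension, applied independently to every (batch, channel) column. The bias is indexed by output position $i$ and broadcast over batch and channels. The sketch below checks the einsum against an explicit reshape followed by @. It only demonstrates equivalence and makes no claim about which form is faster.

```python
import torch

seq_len, batch_size, d = 5, 3, 4
weight = torch.randn(seq_len, seq_len)
bias = torch.randn(seq_len)
z2 = torch.randn(seq_len, batch_size, d)

# Token mixing via einsum, as in the implementation
via_einsum = torch.einsum('ij,jbd->ibd', weight, z2) + bias[:, None, None]

# Equivalent: flatten batch and channel dims into columns, multiply by W, restore the shape
flat = z2.reshape(seq_len, -1)  # [seq_len, batch_size * d]
via_matmul = (weight @ flat).reshape(seq_len, batch_size, d) + bias[:, None, None]

print(torch.allclose(via_einsum, via_matmul, atol=1e-5))  # True
```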

$Z_1 \odot f_{W,b}(Z_2)$

        return z1 * z2