This is a PyTorch implementation of the paper Pay Attention to MLPs.
This paper introduces a Multilayer Perceptron (MLP)-based architecture with gating, which the authors name gMLP. It consists of a stack of gMLP blocks.
Here is the training code for a gMLP-based autoregressive model.
from typing import Optional

import torch
from torch import nn
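Each block applies the following transformations to the input embeddings $X \in \mathbb{R}^{n \times d}$, where $n$ is the sequence length and $d$ is the dimensionality of the embeddings:

$$\begin{aligned} Z &= \sigma(XU) \\ \tilde{Z} &= s(Z) \\ Y &= \tilde{Z}V \end{aligned}$$

where $U$ and $V$ are learnable projection weights, $s(\cdot)$ is the Spatial Gating Unit defined below, and $\sigma$ is an activation function such as GeLU. The output dimensionality of $s(\cdot)$ is half that of $Z$.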
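As a rough sketch, the three equations can be traced with plain tensors to follow the shape bookkeeping: $U$ maps $d$ to d_ffn, $s(\cdot)$ halves the channels, and $V$ maps back to $d$. The sizes, the random weights, and the stand-in function s (which omits the LayerNorm that the implementation applies to $Z_2$) are illustrative assumptions, not part of the original code.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Illustrative sizes (assumptions): sequence length n, batch, d, d_ffn.
n, batch, d, d_ffn = 6, 2, 4, 16

X = torch.randn(n, batch, d)              # input embeddings, [n, batch, d]
U = torch.randn(d, d_ffn) * 0.1           # U: d -> d_ffn
V = torch.randn(d_ffn // 2, d) * 0.1      # V: d_ffn / 2 -> d
W = torch.randn(n, n) * 0.01              # spatial weight, mixes positions
b = torch.ones(n)                         # spatial bias, one value per position


def s(Z: torch.Tensor) -> torch.Tensor:
    """Stand-in spatial gating Z1 * (W Z2 + b); omits the LayerNorm on Z2."""
    Z1, Z2 = Z.chunk(2, dim=-1)
    return Z1 * (torch.einsum('ij,jbd->ibd', W, Z2) + b[:, None, None])


Z = F.gelu(X @ U)        # Z = sigma(XU)   -> [n, batch, d_ffn]
Z_tilde = s(Z)           # Z~ = s(Z)       -> [n, batch, d_ffn // 2]
Y = Z_tilde @ V          # Y = Z~ V        -> [n, batch, d]

print(Z.shape, Z_tilde.shape, Y.shape)
# torch.Size([6, 2, 16]) torch.Size([6, 2, 8]) torch.Size([6, 2, 4])
```

The class below wraps the same three steps with a pre-norm LayerNorm on $X$ and a residual connection around the whole block.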
27class GMLPBlock(nn.Module):
d_model is the dimensionality ($d$) of $X$
d_ffn is the dimensionality of $Z$
seq_len is the length of the token sequence ($n$)
    def __init__(self, d_model: int, d_ffn: int, seq_len: int):
        super().__init__()
Normalization layer for Pre-Norm
        self.norm = nn.LayerNorm([d_model])
Activation function
        self.activation = nn.GELU()
Projection layer for $Z = \sigma(XU)$
        self.proj1 = nn.Linear(d_model, d_ffn)
Spatial Gating Unit $s(\cdot)$
        self.sgu = SpacialGatingUnit(d_ffn, seq_len)
Projection layer for $Y = \tilde{Z}V$
        self.proj2 = nn.Linear(d_ffn // 2, d_model)
Embedding size (required by the Encoder). We use the encoder module from the transformer architecture and plug in the gMLP block as a replacement for the Transformer layer.
        self.size = d_model
x is the input embedding tensor of shape [seq_len, batch_size, d_model]
mask is a boolean mask of shape [seq_len, seq_len, 1] that controls the visibility of tokens among each other.
    def forward(self, *, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
Keep a copy for the shortcut connection
        shortcut = x
Normalize
        x = self.norm(x)
Projection and activation, $Z = \sigma(XU)$
        z = self.activation(self.proj1(x))
Spatial Gating Unit, $\tilde{Z} = s(Z)$
        z = self.sgu(z, mask)
Final projection, $Y = \tilde{Z}V$
        z = self.proj2(z)
Add the shortcut connection
        return z + shortcut
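For an autoregressive model, one natural mask for the forward methods above is causal: token $i$ may see token $j$ only when $j \le i$. The helper below, causal_mask, is a hypothetical name introduced for illustration; it builds a boolean tensor in the [seq_len, seq_len, 1] layout described above, assuming True means "visible" (the Spatial Gating Unit multiplies its weight matrix by the mask, so False entries are zeroed).

```python
import torch


def causal_mask(seq_len: int) -> torch.Tensor:
    """Hypothetical helper: causal boolean mask of shape [seq_len, seq_len, 1].

    mask[i, j, 0] is True when position i may gather from position j,
    i.e. when j <= i. The trailing size-1 axis is the shared batch dimension.
    """
    pos = torch.arange(seq_len)
    visible = pos[None, :] <= pos[:, None]   # entry [i, j] is (j <= i)
    return visible.unsqueeze(-1)


mask = causal_mask(4)
```

A block built with seq_len=4 would then be called as block(x=x, mask=mask), where block and x are placeholders; note that forward takes keyword-only arguments.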
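The Spatial Gating Unit computes

$$s(Z) = Z_1 \odot f_{W,b}(Z_2)$$

where $f_{W,b}(Z) = WZ + b$ is a linear transformation along the sequence dimension, and $\odot$ is element-wise multiplication. $Z$ is split into two parts of equal size, $Z_1$ and $Z_2$, along the channel (embedding) dimension.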
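The implementation evaluates $f_{W,b}(Z_2)$ with the einsum 'ij,jbd->ibd', which contracts $W$ with the sequence axis only. A small check, with arbitrary sizes and random tensors chosen for illustration, that this matches applying $W$ to each (batch, channel) column separately, and also a single matmul over a flattened view:

```python
import torch

torch.manual_seed(0)

# Arbitrary sizes for illustration: n positions, batch, d_z // 2 channels.
n, batch, d_half = 5, 3, 4
W = torch.randn(n, n)
b = torch.randn(n)
Z2 = torch.randn(n, batch, d_half)        # [seq_len, batch_size, d_z // 2]

# f_{W,b}(Z2) as in the implementation: sum over the sequence index j,
# then add one bias value per output position i.
f_einsum = torch.einsum('ij,jbd->ibd', W, Z2) + b[:, None, None]

# Same thing column by column: every (batch, channel) column of length n
# is multiplied by W, so mixing happens across positions only.
f_loop = torch.empty_like(Z2)
for bi in range(batch):
    for c in range(d_half):
        f_loop[:, bi, c] = W @ Z2[:, bi, c] + b

# Same thing as one matmul, flattening (batch, channel) into columns.
f_matmul = (W @ Z2.reshape(n, -1)).reshape(n, batch, d_half) + b[:, None, None]
```

Because channels never mix inside $f_{W,b}$, all cross-channel interaction in a gMLP block comes from the projections $U$ and $V$.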
91class SpacialGatingUnit(nn.Module):
d_z is the dimensionality of $Z$
seq_len is the sequence length
    def __init__(self, d_z: int, seq_len: int):
        super().__init__()
Normalization layer before applying $f_{W,b}(\cdot)$
        self.norm = nn.LayerNorm([d_z // 2])
Weight $W$ in $f_{W,b}(\cdot)$.
The paper notes that it is important to initialize the weights to small values and the bias to $1$, so that during the initial training $s(\cdot)$ is close to identity (apart from the split).
        self.weight = nn.Parameter(torch.zeros(seq_len, seq_len).uniform_(-0.01, 0.01), requires_grad=True)
Bias $b$ in $f_{W,b}(\cdot)$
        self.bias = nn.Parameter(torch.ones(seq_len), requires_grad=True)
z is the input of shape [seq_len, batch_size, d_z]
mask is a boolean mask of shape [seq_len, seq_len, 1] that controls the visibility of tokens among each other. The last dimension, of size 1, is the batch dimension; it is present in the other transformer implementations and is kept here for compatibility.
    def forward(self, z: torch.Tensor, mask: Optional[torch.Tensor] = None):
Get sequence length
        seq_len = z.shape[0]
Split $Z$ into $Z_1$ and $Z_2$
        z1, z2 = torch.chunk(z, 2, dim=-1)
Check mask
        if mask is not None:
mask has shape [seq_len_q, seq_len_k, batch_size]. The batch dimension should be of size 1 because this implementation supports only the same mask for all samples in the batch.
            assert mask.shape[0] == 1 or mask.shape[0] == seq_len
            assert mask.shape[1] == seq_len
Here we only support the same mask for all samples
            assert mask.shape[2] == 1
Remove the batch dimension
            mask = mask[:, :, 0]
Normalize $Z_2$ before applying $f_{W,b}(\cdot)$
        z2 = self.norm(z2)
Get the weight matrix; truncate if larger than seq_len
        weight = self.weight[:seq_len, :seq_len]
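Apply the mask to the weights: wherever mask[i, j] is False, $W_{i,j}$ becomes $0$, so position $i$ receives no information from token $j$
        if mask is not None:
            weight = weight * mask
$f_{W,b}(Z_2) = WZ_2 + b$
        z2 = torch.einsum('ij,jbd->ibd', weight, z2) + self.bias[:seq_len, None, None]
Return $Z_1 \odot f_{W,b}(Z_2)$
        return z1 * z2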
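To see the near-identity initialization and the masking in action, the sketch below re-declares a condensed copy of SpacialGatingUnit under the hypothetical name SGU (same arithmetic, mask assertions dropped) so that it runs on its own. With arbitrary example sizes, it checks that at initialization $s(Z)$ stays close to $Z_1$, and that with a lower-triangular mask, replacing the last token leaves the outputs at earlier positions unchanged, whereas the unmasked unit lets the change leak backwards.

```python
from typing import Optional

import torch
from torch import nn


class SGU(nn.Module):
    """Hypothetical condensed copy of SpacialGatingUnit (mask assertions dropped)."""

    def __init__(self, d_z: int, seq_len: int):
        super().__init__()
        self.norm = nn.LayerNorm([d_z // 2])
        self.weight = nn.Parameter(torch.zeros(seq_len, seq_len).uniform_(-0.01, 0.01))
        self.bias = nn.Parameter(torch.ones(seq_len))

    def forward(self, z: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        seq_len = z.shape[0]
        z1, z2 = torch.chunk(z, 2, dim=-1)
        z2 = self.norm(z2)
        weight = self.weight[:seq_len, :seq_len]
        if mask is not None:
            weight = weight * mask[:, :, 0]  # zero W[i, j] wherever j is hidden from i
        z2 = torch.einsum('ij,jbd->ibd', weight, z2) + self.bias[:seq_len, None, None]
        return z1 * z2


torch.manual_seed(0)
n, batch, d_z = 6, 2, 8  # illustrative sizes
sgu = SGU(d_z, n)
z = torch.randn(n, batch, d_z)
causal = (torch.arange(n)[None, :] <= torch.arange(n)[:, None]).unsqueeze(-1)

with torch.no_grad():
    # 1) Near identity at init: |W_ij| <= 0.01, b = 1, and LayerNorm over 4 channels
    #    gives |values| <= sqrt(3), so |f - 1| <= 6 * 0.01 * sqrt(3) ~= 0.104.
    z1 = z[..., : d_z // 2]
    rel_dev = ((sgu(z) - z1).abs().max() / z1.abs().max()).item()

    # 2) Causality: replace the last token with fresh random values.
    z_new = z.clone()
    z_new[-1] = torch.randn(batch, d_z)
    out_old, out_new = sgu(z, causal), sgu(z_new, causal)
    # Without a mask, earlier positions do see the change through W.
    leaks_unmasked = not torch.allclose(sgu(z)[:-1], sgu(z_new)[:-1])

print(f"max relative deviation from Z1 at init: {rel_dev:.3f}")
```

Since the gate $f_{W,b}$ starts near $1$, each gMLP block initially behaves roughly like a plain feed-forward layer acting on $Z_1$. Causality comes entirely from zeroing entries of $W$, because the LayerNorm and the element-wise gating act on each position independently.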