This is a PyTorch implementation of the paper Pay Attention to MLPs.

This paper introduces a Multilayer Perceptron (MLP)-based architecture with gating, which the authors name **gMLP**. It consists of a stack of $L$ *gMLP* blocks.

Here is the training code for an autoregressive model based on gMLP.

```python
from typing import Optional

import torch
from torch import nn
```

Each block applies the following transformations to the input embeddings $X \in \mathbb{R}^{n \times d}$, where $n$ is the sequence length and $d$ is the dimensionality of the embeddings:

$$
\begin{aligned}
Z &= \sigma(XU) \\
\tilde{Z} &= s(Z) \\
Y &= \tilde{Z} V
\end{aligned}
$$

where $U$ and $V$ are learnable projection weights, $s(\cdot)$ is the Spatial Gating Unit defined below, and $\sigma$ is an activation function such as GELU. The output dimensionality of $s(\cdot)$ is half that of $Z$.
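As a rough shape check, the three equations can be traced on a single sequence with plain tensors. This is a minimal sketch rather than the module below: the sizes are arbitrary, and the batch dimension, the pre-norm, the LayerNorm inside $s(\cdot)$ and the shortcut connection are all omitted. The gating is written inline as $Z_1 \odot (WZ_2 + b)$, following the definition of $s(\cdot)$ given later.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, d, d_ffn = 6, 8, 16              # illustrative sizes only

X = torch.randn(n, d)               # one sequence of embeddings, shape [n, d]
U = torch.randn(d, d_ffn)           # first projection
V = torch.randn(d_ffn // 2, d)      # second projection sees only half the channels
W = torch.randn(n, n) * 0.01        # token-mixing weights of the gating unit
b = torch.ones(n)                   # one bias per position

Z = F.gelu(X @ U)                        # Z = sigma(XU)           -> [n, d_ffn]
Z1, Z2 = Z.chunk(2, dim=-1)              # split along channels    -> [n, d_ffn // 2] each
Z_tilde = Z1 * (W @ Z2 + b[:, None])     # s(Z) = Z1 * (W Z2 + b)  -> [n, d_ffn // 2]
Y = Z_tilde @ V                          # Y = Z~ V                -> [n, d]

print(Z.shape, Z_tilde.shape, Y.shape)
# torch.Size([6, 16]) torch.Size([6, 8]) torch.Size([6, 8])
```

Because $Y$ has the same shape as $X$, a shortcut connection can be added and the block stacked $L$ times.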

`class GMLPBlock(nn.Module):`

* `d_model` is the dimensionality ($d$) of $X$
* `d_ffn` is the dimensionality of $Z$
* `seq_len` is the length of the token sequence ($n$)

`def __init__(self, d_model: int, d_ffn: int, seq_len: int):`

`super().__init__()`

Normalization layer for Pre-Norm

`self.norm = nn.LayerNorm([d_model])`

Activation function $σ$

`self.activation = nn.GELU()`

Projection layer for $Z=σ(XU)$

`self.proj1 = nn.Linear(d_model, d_ffn)`

Spatial Gating Unit $s(\cdot)$

`self.sgu = SpacialGatingUnit(d_ffn, seq_len)`

Projection layer for $Y = \tilde{Z} V$

`self.proj2 = nn.Linear(d_ffn // 2, d_model)`

Embedding size (required by the Encoder). We use the encoder module from the transformer architecture and plug in the *gMLP* block as a replacement for the Transformer layer.

`self.size = d_model`

* `x` is the input embedding tensor $X$ of shape `[seq_len, batch_size, d_model]`
* `mask` is a boolean mask of shape `[seq_len, seq_len, 1]` that controls the visibility of tokens among each other.
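For autoregressive modeling, a natural choice of `mask` is a causal (lower-triangular) mask, where token $i$ may see token $j$ only when $j \le i$. The sketch below builds one in the `[seq_len, seq_len, 1]` layout described above; the helper name `causal_mask` is hypothetical and not part of this implementation.

```python
import torch


def causal_mask(seq_len: int) -> torch.Tensor:
    """Hypothetical helper: boolean mask of shape [seq_len, seq_len, 1].

    Entry [i, j, 0] is True when query position i may see key position j (j <= i).
    """
    mask = torch.tril(torch.ones(seq_len, seq_len)).bool()
    # Trailing size-1 axis stands in for the batch: one mask shared by all samples
    return mask.unsqueeze(-1)


m = causal_mask(4)
print(m.shape)           # torch.Size([4, 4, 1])
print(m[:, :, 0].int())  # ones on and below the diagonal, zeros above
```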

`def forward(self, *, x: torch.Tensor, mask: Optional[torch.Tensor] = None):`

Keep a copy for the shortcut connection

`shortcut = x`

Normalize $X$

`x = self.norm(x)`

Projection and activation $Z=σ(XU)$

`z = self.activation(self.proj1(x))`

Spatial Gating Unit $\tilde{Z} = s(Z)$

`z = self.sgu(z, mask)`

Final projection $Y = \tilde{Z} V$

`z = self.proj2(z)`

Add the shortcut connection

`return z + shortcut`

$$s(Z) = Z_1 \odot f_{W,b}(Z_2)$$

where $f_{W,b}(Z) = WZ + b$ is a linear transformation along the sequence dimension and $\odot$ is element-wise multiplication. $Z$ is split into two parts of equal size, $Z_1$ and $Z_2$, along the channel (embedding) dimension.
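A tiny worked example makes the token mixing concrete. It uses hand-picked values for $W$ and $b$ (assumptions for illustration) and omits the LayerNorm that the implementation applies to $Z_2$ before $f_{W,b}$. With two tokens, row $i$ of $W$ decides which tokens feed the gate at position $i$.

```python
import torch

# n = 2 tokens, batch of 1, 4 channels, so Z1 and Z2 get 2 channels each
Z = torch.tensor([[[1., 2., 3., 4.]],
                  [[5., 6., 7., 8.]]])       # shape [n, batch, channels] = [2, 1, 4]
Z1, Z2 = torch.chunk(Z, 2, dim=-1)           # Z1 = first 2 channels, Z2 = last 2

W = torch.tensor([[1., 0.],
                  [1., 1.]])                 # token 0 sees itself; token 1 sums tokens 0 and 1
b = torch.tensor([0., 1.])                   # one bias per position

# f_{W,b}(Z2)_i = sum_j W[i, j] * Z2_j + b_i : mixes tokens, not channels
f = torch.einsum('ij,jbd->ibd', W, Z2) + b[:, None, None]
s = Z1 * f
print(s)  # s[0] = [3, 8], s[1] = [55, 78]
```

Position 1 gates its own channels $Z_1 = [5, 6]$ with $[3, 4] + [7, 8] + 1 = [11, 13]$, information it could only get by mixing across tokens.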

`class SpacialGatingUnit(nn.Module):`

* `d_z` is the dimensionality of $Z$
* `seq_len` is the sequence length

`def __init__(self, d_z: int, seq_len: int):`

`super().__init__()`

Normalization layer before applying $f_{W,b}(⋅)$

`self.norm = nn.LayerNorm([d_z // 2])`

Weight $W$ in $f_{W,b}(⋅)$.

The paper notes that it is important to initialize $W$ to small values and $b$ to $1$, so that early in training $f_{W,b}(Z_2) \approx 1$ and $s(\cdot)$ is close to the identity (apart from the split).

`self.weight = nn.Parameter(torch.zeros(seq_len, seq_len).uniform_(-0.01, 0.01), requires_grad=True)`

Bias $b$ in $f_{W,b}(\cdot)$, initialized to $1$

`self.bias = nn.Parameter(torch.ones(seq_len), requires_grad=True)`
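To see why this initialization makes $s(\cdot)$ start out near an identity on $Z_1$, the sketch below recreates only the weight, bias and LayerNorm with the same initialization and measures how far the gate $f_{W,b}(Z_2)$ is from $1$. The sizes and random seed are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, batch, d_z = 16, 2, 32

# Same initialization as above: small uniform weights, unit bias
weight = torch.zeros(seq_len, seq_len).uniform_(-0.01, 0.01)
bias = torch.ones(seq_len)
norm = nn.LayerNorm([d_z // 2])

z = torch.randn(seq_len, batch, d_z)
z1, z2 = torch.chunk(z, 2, dim=-1)
gate = torch.einsum('ij,jbd->ibd', weight, norm(z2)) + bias[:, None, None]
out = z1 * gate

# Each gate entry is 1 plus a sum of tiny weights times unit-scale values,
# so it stays close to 1 and s(Z) is approximately Z1
print((gate - 1).abs().max().item())
print((out - z1).abs().max().item())
```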

* `z` is the input $Z$ of shape `[seq_len, batch_size, d_z]`
* `mask` is a boolean mask of shape `[seq_len, seq_len, 1]` that controls the visibility of tokens among each other. The last dimension, of size `1`, is the batch dimension; other transformer implementations have it, and it is kept here for compatibility.

`def forward(self, z: torch.Tensor, mask: Optional[torch.Tensor] = None):`

Get sequence length

`seq_len = z.shape[0]`

Split $Z$ into $Z_{1}$ and $Z_{2}$

`z1, z2 = torch.chunk(z, 2, dim=-1)`

Check mask

`if mask is not None:`

`mask` has shape `[seq_len_q, seq_len_k, batch_size]`. The batch dimension should be of size `1` because this implementation only supports the same mask for all samples in the batch.

```python
assert mask.shape[0] == 1 or mask.shape[0] == seq_len
assert mask.shape[1] == seq_len
```

Here we only support the same mask for all samples

`assert mask.shape[2] == 1`

Remove the batch dimension

`mask = mask[:, :, 0]`
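The first shape check accepts a mask whose first dimension is `1`. After the batch dimension is removed, such a mask has shape `[1, seq_len]` and broadcasts across every query row when it multiplies the weights, hiding the same keys from all positions. The example assumes a padding-style use (last token hidden); that use is our illustration, not something this implementation prescribes.

```python
import torch

seq_len = 4
weight = torch.arange(1.0, 17.0).reshape(seq_len, seq_len)   # stand-in for W

# Hypothetical padding-style mask: the last token is hidden from every query.
# Shape [1, seq_len, 1] passes the checks above because mask.shape[0] == 1.
mask = torch.tensor([True, True, True, False]).reshape(1, seq_len, 1)

mask = mask[:, :, 0]     # remove the batch dimension -> shape [1, seq_len]
masked = weight * mask   # broadcasts over all query rows -> shape [seq_len, seq_len]
print(masked)            # column 3 is zero in every row; other entries are unchanged
```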

Normalize $Z_{2}$ before $f_{W,b}(⋅)$

`z2 = self.norm(z2)`

Get the weight matrix; truncate if larger than `seq_len`

`weight = self.weight[:seq_len, :seq_len]`

Apply mask to the weights.

If $W_{i,j}$ is $0$ then $f_{W,b}(Z_{2})_{i}$ will not get any information from token $j$.

```python
if mask is not None:
    weight = weight * mask
```
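With a causal mask, zeroing $W_{i,j}$ for $j > i$ means that changing later tokens cannot affect earlier outputs of $f_{W,b}$. The check below is a minimal sketch with random weights (an assumption; a trained $W$ would differ) that perturbs only the last two tokens and compares the outputs position by position. The function name `spatial_mix` is hypothetical.

```python
import torch

torch.manual_seed(0)
n, batch, d = 5, 1, 4

weight = torch.randn(n, n)
bias = torch.ones(n)
mask = torch.tril(torch.ones(n, n)).bool()   # causal: position i sees positions j <= i


def spatial_mix(z2: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: f_{W,b}(Z2) with masked weights, mirroring the forward pass."""
    w = weight * mask                        # zero out W[i, j] for j > i
    return torch.einsum('ij,jbd->ibd', w, z2) + bias[:, None, None]


z2 = torch.randn(n, batch, d)
z2_changed = z2.clone()
z2_changed[3:] += 100.0                      # perturb only tokens 3 and 4

out, out_changed = spatial_mix(z2), spatial_mix(z2_changed)
print(torch.allclose(out[:3], out_changed[:3]))   # positions 0-2 are unaffected
print(torch.allclose(out[3:], out_changed[3:]))   # positions 3-4 change
```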

$f_{W,b}(Z_{2})=WZ_{2}+b$

`z2 = torch.einsum('ij,jbd->ibd', weight, z2) + self.bias[:seq_len, None, None]`
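The `einsum` contracts the column index $j$ of $W$ with the sequence axis of $Z_2$, applying the same $n \times n$ matrix to every (batch, channel) column. As a sanity check, the sketch below compares it with an equivalent plain matrix product obtained by flattening the batch and channel axes; the sizes are arbitrary.

```python
import torch

torch.manual_seed(0)
n, batch, d = 7, 3, 5
W = torch.randn(n, n)
b = torch.randn(n)
z2 = torch.randn(n, batch, d)

# Form used in the implementation: contract W's column index j with the sequence axis
via_einsum = torch.einsum('ij,jbd->ibd', W, z2) + b[:, None, None]

# Equivalent matrix product: flatten (batch, channel) into columns, multiply, reshape back
via_matmul = (W @ z2.reshape(n, batch * d)).reshape(n, batch, d) + b[:, None, None]

print(torch.allclose(via_einsum, via_matmul, atol=1e-5))
```

The `einsum` form avoids the explicit reshapes and keeps the `[seq_len, batch_size, d]` layout readable.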

$Z_{1}⊙f_{W,b}(Z_{2})$

`return z1 * z2`
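Putting the fragments above together gives a single runnable file. This is our assembly of the lines shown in this walkthrough (indentation and ordering are reconstructed). The usage at the bottom, with its sizes and causal mask, is illustrative; the sequence passed in is shorter than the `seq_len` used at construction, which exercises the truncation of $W$ and $b$.

```python
from typing import Optional

import torch
from torch import nn


class SpacialGatingUnit(nn.Module):
    def __init__(self, d_z: int, seq_len: int):
        super().__init__()
        # LayerNorm applied to Z2 before f_{W,b}
        self.norm = nn.LayerNorm([d_z // 2])
        # Small W and unit b so that s(.) starts close to the identity on Z1
        self.weight = nn.Parameter(torch.zeros(seq_len, seq_len).uniform_(-0.01, 0.01), requires_grad=True)
        self.bias = nn.Parameter(torch.ones(seq_len), requires_grad=True)

    def forward(self, z: torch.Tensor, mask: Optional[torch.Tensor] = None):
        seq_len = z.shape[0]
        # Split Z into Z1 and Z2 along the channel dimension
        z1, z2 = torch.chunk(z, 2, dim=-1)
        if mask is not None:
            assert mask.shape[0] == 1 or mask.shape[0] == seq_len
            assert mask.shape[1] == seq_len
            assert mask.shape[2] == 1      # same mask for all samples
            mask = mask[:, :, 0]           # drop the batch dimension
        z2 = self.norm(z2)
        # Truncate W to the actual sequence length
        weight = self.weight[:seq_len, :seq_len]
        if mask is not None:
            weight = weight * mask
        # f_{W,b}(Z2) = W Z2 + b, mixing along the sequence dimension
        z2 = torch.einsum('ij,jbd->ibd', weight, z2) + self.bias[:seq_len, None, None]
        return z1 * z2


class GMLPBlock(nn.Module):
    def __init__(self, d_model: int, d_ffn: int, seq_len: int):
        super().__init__()
        self.norm = nn.LayerNorm([d_model])          # pre-norm
        self.activation = nn.GELU()                  # sigma
        self.proj1 = nn.Linear(d_model, d_ffn)       # U
        self.sgu = SpacialGatingUnit(d_ffn, seq_len)  # s(.)
        self.proj2 = nn.Linear(d_ffn // 2, d_model)  # V
        self.size = d_model                          # required by the Encoder

    def forward(self, *, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
        shortcut = x
        x = self.norm(x)
        z = self.activation(self.proj1(x))  # Z = sigma(XU)
        z = self.sgu(z, mask)               # Z~ = s(Z)
        z = self.proj2(z)                   # Y = Z~ V
        return z + shortcut


# Illustrative usage: 10 tokens, batch of 4, causal mask
torch.manual_seed(0)
block = GMLPBlock(d_model=32, d_ffn=64, seq_len=16)
x = torch.randn(10, 4, 32)
mask = torch.tril(torch.ones(10, 10)).bool().unsqueeze(-1)
y = block(x=x, mask=mask)
print(y.shape)  # torch.Size([10, 4, 32])
```

Since the output has the same shape as the input, several blocks can be stacked in place of Transformer layers, as described for the Encoder above.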