This is a PyTorch implementation of the paper FNet: Mixing Tokens with Fourier Transforms.

This paper replaces the self-attention layer with two Fourier transforms to *mix* tokens.
The paper reports that this is up to $7 \times$ faster than self-attention, while FNet retains about 92% of BERT's accuracy on the GLUE benchmark.

We apply the Fourier transform along the hidden (embedding) dimension and then along the sequence dimension:

$$
\mathcal{R}\big(\mathcal{F}_\text{seq} \big(\mathcal{F}_\text{hidden} (x) \big) \big)
$$

where $x$ is the embedding input, $\mathcal{F}$ stands for the Fourier transform, and $\mathcal{R}$ takes the real component of a complex number.
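To make the formula concrete, here is a minimal sketch of it with `torch.fft`, assuming the `[seq_len, batch_size, d_model]` layout used by the module on this page; the helper name `fnet_mix` is hypothetical. Because the two 1D transforms act on different dimensions, they commute, so the result should also match the real part of a single 2D FFT over the sequence and hidden dimensions (`torch.fft.fftn` with `dim=(0, 2)`).

```python
import torch


def fnet_mix(x: torch.Tensor) -> torch.Tensor:
    # x: [seq_len, batch_size, d_model], real-valued embeddings
    # F_hidden: FFT along the embedding dimension (dim=2)
    fft_hidden = torch.fft.fft(x, dim=2)
    # F_seq: FFT along the sequence dimension (dim=0)
    fft_seq = torch.fft.fft(fft_hidden, dim=0)
    # R: keep only the real component
    return torch.real(fft_seq)


torch.manual_seed(0)
x = torch.randn(8, 2, 16)
y = fnet_mix(x)

# The two 1D FFTs commute, so this equals one 2D FFT over (seq, hidden)
y_2d = torch.real(torch.fft.fftn(x, dim=(0, 2)))

print(y.shape)  # torch.Size([8, 2, 16])
print(torch.allclose(y, y_2d, atol=1e-4))
```

The output has the same shape as the input and is real-valued, which is what lets it slot in where an attention output would go.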

This is very simple to implement in PyTorch: just one line of code. The paper suggests using a precomputed DFT matrix and computing the Fourier transform by matrix multiplication.
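The DFT-matrix approach can be sketched as follows. This illustrates the idea rather than reproducing the paper's code: `dft_matrix` and `fnet_mix_dft` are hypothetical helpers, and the matrices are built for a fixed `seq_len` and `d_model`, which is what makes precomputing them possible. The last line compares the result against the FFT version.

```python
import math

import torch


def dft_matrix(n: int) -> torch.Tensor:
    # Unnormalized DFT matrix W[j, k] = exp(-2*pi*i*j*k / n),
    # matching torch.fft.fft's default (norm="backward") forward transform.
    # Angles are computed in float64 to limit rounding for large j*k.
    idx = torch.arange(n, dtype=torch.float64)
    angle = -2 * math.pi * torch.outer(idx, idx) / n
    return torch.polar(torch.ones_like(angle), angle).to(torch.complex64)


def fnet_mix_dft(x: torch.Tensor, w_seq: torch.Tensor, w_hidden: torch.Tensor) -> torch.Tensor:
    # x: [seq_len, batch_size, d_model], real-valued
    xc = x.to(torch.complex64)
    # Along the hidden dim: h[..., k] = sum_n x[..., n] * W[n, k]
    h = xc @ w_hidden
    # Along the sequence dim: s[i] = sum_j W[i, j] * h[j]
    s = torch.einsum("ij,jbd->ibd", w_seq, h)
    # Keep the real component, as in the FFT version
    return s.real


seq_len, batch_size, d_model = 8, 2, 16
# Precompute once per (seq_len, d_model) and reuse for every batch
w_seq = dft_matrix(seq_len)
w_hidden = dft_matrix(d_model)

torch.manual_seed(0)
x = torch.randn(seq_len, batch_size, d_model)
y_dft = fnet_mix_dft(x, w_seq, w_hidden)
y_fft = torch.real(torch.fft.fft(torch.fft.fft(x, dim=2), dim=0))
print(torch.allclose(y_dft, y_fft, atol=1e-4))
```

The matrix form costs $O(n^2)$ per transformed vector versus $O(n \log n)$ for the FFT, so whether it is faster depends on how well the hardware handles dense matrix multiplication and on the sequences being short.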

Here is the training code for an FNet-based model that classifies AG News.

```
from typing import Optional

import torch
from torch import nn
```

This module simply implements

$$
\mathcal{R}\big(\mathcal{F}_\text{seq} \big(\mathcal{F}_\text{hidden} (x) \big) \big)
$$

The structure of this module mirrors a standard attention module so that it can be used as a drop-in replacement for one.

`class FNetMix(nn.Module):`

The normal attention module can be fed different token embeddings for $\text{query}$, $\text{key}$, and $\text{value}$, along with a mask.

We follow the same function signature so that we can replace it directly.

For FNet mixing, $x = \text{query} = \text{key} = \text{value}$, and masking is not possible.
The shape of `query` (and `key` and `value`) is `[seq_len, batch_size, d_model]`.

`def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: Optional[torch.Tensor] = None):`

$\text{query}$, $\text{key}$, and $\text{value}$ should all be the same tensor $x$ for token mixing.

`assert query is key and key is value`

Token mixing doesn't support masking; every token sees the embeddings of all other tokens.
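A small check of this, using a hypothetical `fnet_mix` helper that applies the same two FFTs: perturbing only the last token changes the output at the first position. Information flows in both directions, and since there is no attention matrix, there is nothing a causal or padding mask could be applied to.

```python
import torch


def fnet_mix(x: torch.Tensor) -> torch.Tensor:
    # Same operation as the module: FFT over hidden, FFT over sequence, real part
    return torch.real(torch.fft.fft(torch.fft.fft(x, dim=2), dim=0))


torch.manual_seed(0)
x = torch.randn(6, 1, 4)  # [seq_len, batch_size, d_model]
x_changed = x.clone()
x_changed[-1] += 1.0  # perturb only the last token

y = fnet_mix(x)
y_changed = fnet_mix(x_changed)

# The first token's output changes even though only the last token did
print(torch.allclose(y[0], y_changed[0]))  # False
```

Because the mixing is linear, the change at sequence position 0 and hidden index 0 equals the sum of all perturbed entries, which is $4$ here (four hidden entries, each increased by $1$).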

`assert mask is None`

Assign to `x` for clarity.

`x = query`

Apply the Fourier transform along the hidden (embedding) dimension, $\mathcal{F}_\text{hidden}(x)$.

The output of the Fourier transform is a tensor of complex numbers.

`fft_hidden = torch.fft.fft(x, dim=2)`
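To see what this step produces, the sketch below (on a random input with an arbitrarily chosen shape) checks that `torch.fft.fft` returns a complex tensor of the same shape. It also checks that, for real input, the spectrum is conjugate-symmetric along the transformed dimension, i.e. $X_k = \overline{X_{(n-k) \bmod n}}$.

```python
import torch

torch.manual_seed(0)
x = torch.randn(5, 3, 8)  # [seq_len, batch_size, d_model]
fft_hidden = torch.fft.fft(x, dim=2)

print(fft_hidden.dtype)  # torch.complex64
print(fft_hidden.shape)  # torch.Size([5, 3, 8])

# Conjugate symmetry for real input: X[..., k] == conj(X[..., (n - k) % n])
n = x.shape[2]
k = torch.arange(n)
mirrored = fft_hidden[..., (n - k) % n]
real_ok = torch.allclose(fft_hidden.real, mirrored.real, atol=1e-4)
imag_ok = torch.allclose(fft_hidden.imag, -mirrored.imag, atol=1e-4)
print(real_ok and imag_ok)  # True
```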

Apply the Fourier transform along the sequence dimension, $\mathcal{F}_\text{seq}\big(\mathcal{F}_\text{hidden}(x)\big)$.

`fft_seq = torch.fft.fft(fft_hidden, dim=0)`

Take the real component, $\mathcal{R}\big(\mathcal{F}_\text{seq}\big(\mathcal{F}_\text{hidden}(x)\big)\big)$.

`return torch.real(fft_seq)`
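Putting the pieces above together gives the runnable sketch below, followed by a usage example. It mirrors the module on this page; the sizes are arbitrary choices for illustration. Note that the module has no learnable parameters, and that the `is` checks require the very same tensor object, so an equal-valued copy is rejected (assertions are skipped entirely under `python -O`).

```python
from typing import Optional

import torch
from torch import nn


class FNetMix(nn.Module):
    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # query, key and value must be the same tensor x: [seq_len, batch_size, d_model]
        assert query is key and key is value
        # Token mixing cannot be masked
        assert mask is None
        x = query
        # FFT along the hidden (embedding) dimension; result is complex
        fft_hidden = torch.fft.fft(x, dim=2)
        # FFT along the sequence dimension
        fft_seq = torch.fft.fft(fft_hidden, dim=0)
        # Keep the real component
        return torch.real(fft_seq)


seq_len, batch_size, d_model = 10, 4, 32
mix = FNetMix()
x = torch.randn(seq_len, batch_size, d_model)
out = mix(query=x, key=x, value=x)
print(out.shape)  # torch.Size([10, 4, 32])

# No learnable parameters: the Fourier transform is fixed
print(sum(p.numel() for p in mix.parameters()))  # 0

# An equal-valued copy fails the identity check
raised = False
try:
    mix(query=x, key=x.clone(), value=x)
except AssertionError:
    raised = True
print(raised)  # True
```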