This is a PyTorch implementation of the paper FNet: Mixing Tokens with Fourier Transforms.

This paper replaces the self-attention layer with two Fourier transforms to *mix* tokens. The paper reports that this is up to $7 \times$ more efficient than self-attention, while the model reaches about 92% of BERT's accuracy on the GLUE benchmark.

We apply the Fourier transform along the hidden dimension (embedding dimension) and then along the sequence dimension.

$R(F_{seq}(F_{hidden}(x)))$

where $x$ is the embedding input, $F$ stands for the Fourier transform and $R$ stands for the real component of a complex number.
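For reference, and assuming the standard unnormalized discrete Fourier transform convention (which is the default of `torch.fft.fft`), the transform of a length-$N$ vector $x$ along one dimension can be written as:

$F(x)_k = \sum_{n=0}^{N-1} x_n \, e^{-\frac{2 \pi i}{N} n k}, \quad k = 0, \dots, N-1$

Each output element is a weighted sum of all $N$ input elements, which is why applying the transform along a dimension mixes information across every position in that dimension.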

This is very simple to implement in PyTorch: just one line of code. The paper also suggests using a precomputed DFT matrix and a matrix multiplication to compute the Fourier transform.
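The DFT-matrix approach can be sketched as follows. This is only an illustration and not the paper's code: the helper name `dft_matrix`, the tensor sizes, and the use of double precision are assumptions made for this example. It builds the matrix once and checks that multiplying by it gives the same result as `torch.fft.fft` along the sequence dimension.

```python
import math

import torch


def dft_matrix(n: int) -> torch.Tensor:
    """Hypothetical helper: n x n DFT matrix W with W[j, k] = exp(-2*pi*i*j*k / n)."""
    idx = torch.arange(n, dtype=torch.float64)
    angles = -2.0 * math.pi * torch.outer(idx, idx) / n
    # torch.polar(abs, angle) builds abs * (cos(angle) + i * sin(angle))
    return torch.polar(torch.ones_like(angles), angles)


torch.manual_seed(0)
seq_len, d_model = 8, 4  # arbitrary sizes for the illustration
x = torch.randn(seq_len, d_model, dtype=torch.float64)

# Precompute the matrix once; it depends only on the sequence length
w_seq = dft_matrix(seq_len)

# Fourier transform along the sequence dimension as a matrix multiplication
via_matmul = w_seq @ x.to(torch.complex128)

# Reference result from the built-in FFT
via_fft = torch.fft.fft(x, dim=0)

# Largest absolute difference between the two results
max_error = (via_matmul - via_fft).abs().max().item()
print(max_error)
```

The matrix multiplication costs $O(N^2)$ per transformed vector, compared with $O(N \log N)$ for the FFT, so which one is faster in practice depends on the sequence length and the hardware.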

Here is the training code for an FNet-based model that classifies AG News.

```
from typing import Optional

import torch
from torch import nn
```

This module simply implements $R(F_{seq}(F_{hidden}(x)))$

The structure of this module is made similar to a standard attention module so that we can simply replace it.

`class FNetMix(nn.Module):`

A standard attention module can be fed different token embeddings for $query$, $key$, and $value$, along with a mask.

We follow the same function signature so that we can replace it directly.

For FNet mixing, $x = query = key = value$ and masking is not possible. The shape of `query` (and `key` and `value`) is `[seq_len, batch_size, d_model]`.

`def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: Optional[torch.Tensor] = None):`

$query$, $key$, and $value$ should all be equal to $x$ for token mixing

`assert query is key and key is value`

Token mixing doesn't support masking, i.e. all tokens will see all other token embeddings.

`assert mask is None`

Assign to `x` for clarity

`x = query`

Apply the Fourier transform along the hidden (embedding) dimension $F_{hidden}(x)$

The output of the Fourier transform is a tensor of complex numbers.

`fft_hidden = torch.fft.fft(x, dim=2)`
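As a quick illustration of the complex output, the snippet below applies the same call to a random tensor and inspects the result. The tensor sizes are arbitrary and chosen only for this example.

```python
import torch

# [seq_len, batch_size, d_model]; sizes are arbitrary for this illustration
x = torch.randn(5, 2, 8, dtype=torch.float32)

# Fourier transform along the hidden (embedding) dimension
fft_hidden = torch.fft.fft(x, dim=2)

input_dtype = x.dtype             # real floating point input
output_dtype = fft_hidden.dtype   # complex output
same_shape = fft_hidden.shape == x.shape

print(input_dtype, output_dtype, same_shape)
```

The transform keeps the tensor shape and only changes the dtype, so the second transform along the sequence dimension can be applied to its output directly.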

Apply the Fourier transform along the sequence dimension $F_{seq}(F_{hidden}(x))$

`fft_seq = torch.fft.fft(fft_hidden, dim=0)`

Get the real component $R(F_{seq}(F_{hidden}(x)))$

`return torch.real(fft_seq)`
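Putting the annotated lines together, the module can be assembled and exercised as in the sketch below. The class body repeats the lines shown above. The tensor sizes, the parameter count, and the comparison against `torch.fft.fft2` are additions made only for this illustration.

```python
from typing import Optional

import torch
from torch import nn


class FNetMix(nn.Module):
    """Token mixing R(F_seq(F_hidden(x))) with an attention-like signature."""

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                mask: Optional[torch.Tensor] = None):
        # query, key and value must be the very same tensor
        assert query is key and key is value
        # Masking is not supported
        assert mask is None
        x = query
        # Fourier transform along the hidden dimension, then the sequence dimension
        fft_hidden = torch.fft.fft(x, dim=2)
        fft_seq = torch.fft.fft(fft_hidden, dim=0)
        # Keep only the real component
        return torch.real(fft_seq)


torch.manual_seed(0)
x = torch.randn(10, 2, 16)  # [seq_len, batch_size, d_model], arbitrary sizes

mix = FNetMix()
out = mix(x, x, x)

# The layer as written here has no learnable parameters
num_params = sum(p.numel() for p in mix.parameters())

# Two 1D transforms over dims 2 and 0 should agree with one 2D transform
reference = torch.fft.fft2(x, dim=(0, 2)).real
max_error = (out - reference).abs().max().item()

print(out.shape, out.dtype, num_params, max_error)
```

Because `forward` checks identity with `is`, passing a copy such as `x.clone()` as `key` would trip the first assertion even though the values are equal.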