FNet: Mixing Tokens with Fourier Transforms

This is a PyTorch implementation of the paper FNet: Mixing Tokens with Fourier Transforms.

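This paper replaces the self-attention layer with two Fourier transforms to mix tokens. This is more efficient than self-attention, because the FFT costs O(n log n) in the sequence length instead of O(n²), and the mixing layer has no learnable parameters. On the GLUE benchmark, an FNet version of BERT reaches about 92% of the accuracy of the self-attention model.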

Mixing tokens with two Fourier transforms

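We apply the Fourier transform along the hidden dimension (embedding dimension) and then along the sequence dimension:

$$y = \mathcal{R}\big(\mathcal{F}_\text{seq}\big(\mathcal{F}_\text{hidden}(x)\big)\big)$$

where $x$ is the embedding input, $\mathcal{F}$ stands for the Fourier transform and $\mathcal{R}$ stands for the real component of a complex number.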
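The two 1-D transforms act on different dimensions, so their order does not matter, and a single N-dimensional FFT over the sequence and hidden dimensions computes the same thing. The sketch below checks this, assuming the [seq_len, batch_size, d_model] layout used later in this module; the function names are hypothetical.

```python
import torch


def fnet_mix_two_step(x: torch.Tensor) -> torch.Tensor:
    # F_hidden along dim 2 (d_model), then F_seq along dim 0 (seq_len); keep the real part
    return torch.fft.fft(torch.fft.fft(x, dim=2), dim=0).real


def fnet_mix_fftn(x: torch.Tensor) -> torch.Tensor:
    # One N-D FFT over dims (0, 2) applies both 1-D transforms at once
    return torch.fft.fftn(x, dim=(0, 2)).real


x = torch.randn(12, 3, 16)  # [seq_len, batch_size, d_model]
print(torch.allclose(fnet_mix_two_step(x), fnet_mix_fftn(x), atol=1e-4))  # True
```

The batch dimension (dim 1) is never transformed, so sequences in a batch do not mix with each other.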

This is very simple to implement in PyTorch: just one line of code. The paper suggests using a precomputed DFT matrix and matrix multiplication to compute the Fourier transform.
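A minimal sketch of the DFT-matrix approach, assuming the unnormalized DFT convention that `torch.fft.fft` uses by default, $W_{jk} = e^{-2\pi i jk/N}$. `dft_matrix` and `fnet_mix_dft` are hypothetical helper names, and the einsum formulation is one possible way to apply the matrices, not necessarily the paper's.

```python
import math

import torch


def dft_matrix(n: int, dtype: torch.dtype = torch.complex64) -> torch.Tensor:
    # W[j, k] = exp(-2*pi*i * j*k / n), the unnormalized DFT matrix
    idx = torch.arange(n)
    # Reduce j*k mod n before converting to an angle so phases stay accurate for large n
    phase = (torch.outer(idx, idx) % n).to(torch.float64) * (-2.0 * math.pi / n)
    return torch.polar(torch.ones_like(phase), phase).to(dtype)


def fnet_mix_dft(x: torch.Tensor, w_seq: torch.Tensor, w_hidden: torch.Tensor) -> torch.Tensor:
    # x: [seq_len, batch_size, d_model]; w_seq: [seq_len, seq_len]; w_hidden: [d_model, d_model]
    xc = x.to(w_hidden.dtype)
    # Hidden-dimension DFT: h[s, b, k] = sum_n x[s, b, n] * W_hidden[n, k]
    h = torch.einsum("sbn,nk->sbk", xc, w_hidden)
    # Sequence-dimension DFT: y[t, b, k] = sum_s W_seq[t, s] * h[s, b, k]
    y = torch.einsum("ts,sbk->tbk", w_seq, h)
    return y.real


seq_len, batch_size, d_model = 16, 2, 8
# Precompute once (for example in a module's __init__) and reuse on every forward pass
w_seq, w_hidden = dft_matrix(seq_len), dft_matrix(d_model)

x = torch.randn(seq_len, batch_size, d_model)
y_fft = torch.fft.fft(torch.fft.fft(x, dim=2), dim=0).real
print(torch.allclose(fnet_mix_dft(x, w_seq, w_hidden), y_fft, atol=1e-4))  # True
```

The matrix form costs O(n²) per transform instead of the FFT's O(n log n), so it only pays off when dense matrix multiplication is much faster on the target hardware than the FFT kernel.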

Here is the training code for using an FNet-based model to classify AG News.

from typing import Optional
import torch
from torch import nn

FNet - Mix tokens

This module simply implements

$$y = \mathcal{R}\big(\mathcal{F}_\text{seq}\big(\mathcal{F}_\text{hidden}(x)\big)\big)$$

The structure of this module is made similar to a standard attention module so that we can simply replace it.

class FNetMix(nn.Module):

The normal attention module can be fed with different token embeddings for $\text{query}$, $\text{key}$, and $\text{value}$, and a mask.

We follow the same function signature so that we can replace it directly.

For FNet mixing, $x = \text{query} = \text{key} = \text{value}$ and masking is not possible. The shape of query (and key and value) is [seq_len, batch_size, d_model].

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: Optional[torch.Tensor] = None):

$\text{query}$, $\text{key}$, and $\text{value}$ should all be equal to $x$ for token mixing

        assert query is key and key is value

Token mixing doesn't support masking, i.e., every token sees all other token embeddings.

        assert mask is None
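To see why a mask has nowhere to go, the sketch below perturbs only the last token and checks the output at the first position. `fnet_mix` is a hypothetical standalone function computing the same two FFTs as this module.

```python
import torch


def fnet_mix(x: torch.Tensor) -> torch.Tensor:
    # x: [seq_len, batch_size, d_model]; FFT over hidden (dim 2), then sequence (dim 0)
    return torch.fft.fft(torch.fft.fft(x, dim=2), dim=0).real


torch.manual_seed(0)
x = torch.randn(6, 1, 4)
y = fnet_mix(x)

# Perturb only the last token
x_perturbed = x.clone()
x_perturbed[-1] += 1.0
y_perturbed = fnet_mix(x_perturbed)

# The output at the first position changes too, so earlier positions see later tokens
print(torch.allclose(y[0], y_perturbed[0]))  # False
```

Every output position is a fixed linear combination of every input position, so there is no attention-score matrix whose entries a mask could zero out.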

Assign to x for clarity

        x = query

Apply the Fourier transform along the hidden (embedding) dimension

The output of the Fourier transform is a tensor of complex numbers.

        fft_hidden = torch.fft.fft(x, dim=2)
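The FFT along the hidden dimension only mixes features within each token; it is the second FFT, along the sequence dimension, that mixes tokens. A quick check with random inputs, assuming the [seq_len, batch_size, d_model] layout:

```python
import torch

torch.manual_seed(0)
x = torch.randn(5, 2, 8)  # [seq_len, batch_size, d_model]
fft_hidden = torch.fft.fft(x, dim=2)
print(fft_hidden.dtype)  # torch.complex64

# Replace token 3 only; the hidden-dimension FFT of token 0 does not change,
# because this transform never combines different sequence positions
x_changed = x.clone()
x_changed[3] = torch.randn(2, 8)
print(torch.allclose(torch.fft.fft(x_changed, dim=2)[0], fft_hidden[0]))  # True
```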

Apply the Fourier transform along the sequence dimension

        fft_seq = torch.fft.fft(fft_hidden, dim=0)

Get the real component

        return torch.real(fft_seq)
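Because FNetMix copies the attention module's forward signature, it can be plugged in wherever an encoder layer calls its attention sublayer with `query`, `key`, `value` and `mask`. The sketch below assumes a hypothetical post-norm `EncoderLayer` (residual plus LayerNorm around the mixer and around a feed-forward block); it is not the training model referenced above, and FNetMix is restated in condensed form so the block runs on its own.

```python
from typing import Optional

import torch
from torch import nn


class FNetMix(nn.Module):
    # Condensed restatement of the module above: Re(F_seq(F_hidden(x)))
    def forward(self, query, key, value, mask: Optional[torch.Tensor] = None):
        assert query is key and key is value
        assert mask is None
        return torch.fft.fft(torch.fft.fft(query, dim=2), dim=0).real


class EncoderLayer(nn.Module):
    """Hypothetical post-norm encoder layer with a pluggable token mixer.

    `mixer` is called as mixer(query=x, key=x, value=x, mask=mask), the
    attention-style signature that FNetMix follows.
    """

    def __init__(self, d_model: int, d_ff: int, mixer: nn.Module):
        super().__init__()
        self.mixer = mixer
        self.norm_mix = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))
        self.norm_ff = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # x: [seq_len, batch_size, d_model]; the same tensor is passed as query, key and value
        x = self.norm_mix(x + self.mixer(query=x, key=x, value=x, mask=mask))
        return self.norm_ff(x + self.ff(x))


layer = EncoderLayer(d_model=32, d_ff=64, mixer=FNetMix())
x = torch.randn(10, 4, 32)
y = layer(x)
print(y.shape)  # torch.Size([10, 4, 32])
```

Passing a non-`None` mask to this layer trips the `assert mask is None` in FNetMix, which makes an accidental attempt at masked or causal use fail loudly instead of silently ignoring the mask.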