This is a PyTorch implementation of the paper FNet: Mixing Tokens with Fourier Transforms.
This paper replaces the self-attention layer with two Fourier transforms to mix tokens. This is more efficient than self-attention, and the model retains about 92% of BERT's accuracy on the GLUE benchmark.
We apply the Fourier transform along the hidden dimension (embedding dimension) and then along the sequence dimension:

$$y = \mathcal{R}\big(\mathcal{F}_{seq}\big(\mathcal{F}_{hidden}(x)\big)\big)$$

where $x$ is the embedding input, $\mathcal{F}$ stands for the Fourier transform, and $\mathcal{R}$ stands for the real component of a complex number.
This is very simple to implement in PyTorch - just a single line of code. The paper suggests using a precomputed DFT matrix and doing matrix multiplication to get the Fourier transform.
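Here is a minimal sketch of that single line, assuming an input tensor x of shape [seq_len, batch_size, d_model] (the sizes below are hypothetical, for illustration only):

```python
import torch

# Hypothetical input: 128 tokens, batch of 4, 256-dimensional embeddings
x = torch.randn(128, 4, 256)

# Fourier transform along the hidden dimension, then along the sequence
# dimension, keeping only the real component of the result
mixed = torch.fft.fft(torch.fft.fft(x, dim=-1), dim=0).real

assert mixed.shape == x.shape
```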
Here is the training code for using an FNet-based model for classifying AG News.
from typing import Optional

import torch
from torch import nn
This module simply implements $\mathcal{R}\big(\mathcal{F}_{seq}\big(\mathcal{F}_{hidden}(x)\big)\big)$.
The structure of this module is made similar to a standard attention module so that we can simply replace the attention layer with it.
class FNetMix(nn.Module):
The normal attention module can be fed with different token embeddings for query, key and value, and a mask.
We follow the same function signature so that we can replace it directly.
For FNet mixing, $x = q = k = v$ and masking is not possible. The shape of query (and key and value) is [seq_len, batch_size, d_model].
    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: Optional[torch.Tensor] = None):
query, key and value should all be equal to $x$ for token mixing.
        assert query is key and key is value
Token mixing doesn't support masking, i.e. all tokens will see all other token embeddings.
        assert mask is None
Assign query to x for clarity.
        x = query
Apply the Fourier transform along the hidden (embedding) dimension
The output of the Fourier transform is a tensor of complex numbers.
        fft_hidden = torch.fft.fft(x, dim=2)
Apply the Fourier transform along the sequence dimension
        fft_seq = torch.fft.fft(fft_hidden, dim=0)
Get the real component
        return torch.real(fft_seq)
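A quick usage check (the sizes here are hypothetical, just to illustrate the drop-in signature): since FNetMix follows the attention interface, it is called with the same tensor for query, key and value, and no mask.

```python
mix = FNetMix()

# [seq_len, batch_size, d_model]
x = torch.randn(100, 8, 512)

# query, key and value must be the same tensor, and no mask is passed
out = mix(x, x, x)

assert out.shape == (100, 8, 512)
assert out.dtype == torch.float32  # only the real component is returned
```

Because the signature matches, an existing transformer layer can in principle hold an FNetMix instance wherever it would normally hold a self-attention module.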