FNet: フーリエ変換によるトークンの混合

これは、論文「FNet: トークンをフーリエ変換と混合する」をPyTorchで実装したものです

この論文では、セルフアテンション層を2つのフーリエ変換に置き換えてトークンを混合しますこれは自己注意よりも効率的です。BERT on GLUE ベンチマークでは、自己注意よりもこれを使用した場合の精度の低下は約 92%

です。

2 つのフーリエ変換によるトークンの混合

フーリエ変換を非表示次元 (埋め込み次元) に沿って適用し、次にシーケンス次元に沿って適用します。

ここで、は埋め込み入力で、フーリエ変換を表し、複素数の実数成分を表します。

これをPyTorchに実装するのはとても簡単です。たった1行のコードです。この論文では、事前に計算されたDFT行列を使用し、行列の乗算を行ってフーリエ変換を行うことを提案しています

以下はFNetベースのモデルを使用してAG Newsを分類するためのトレーニングコードです。

41from typing import Optional
42
43import torch
44from torch import nn

FNet-ミックストークン

このモジュールは単純に実装します

このモジュールの構造は、標準的なアテンションモジュールと同様の構造になっているため、簡単に交換できます。

47class FNetMix(nn.Module):

通常のアテンションモジュールには、、、マスクにさまざまなトークンを埋め込むことができます。

同じ関数シグネチャに従うので、直接置換できます。

FNetミキシングの場合、マスキングはできません。query (keyvalue ) の形はです[seq_len, batch_size, d_model]

60    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: Optional[torch.Tensor] = None):

そしてトークンのミキシングではすべてが等しくなければなりません

72        assert query is key and key is value

トークンのミキシングはマスキングをサポートしていません。つまり、すべてのトークンに他のすべてのトークンの埋め込みが表示されます。

74        assert mask is None

x わかりやすいように割り当てる

77        x = query

フーリエ変換を非表示 (埋め込み) 次元に沿って適用します

フーリエ変換の出力は複素数のテンソルです

84        fft_hidden = torch.fft.fft(x, dim=2)

シーケンスの次元に沿ってフーリエ変換を適用します

87        fft_seq = torch.fft.fft(fft_hidden, dim=0)

本物のコンポーネントを手に入れよう

91        return torch.real(fft_seq)