これは、論文「FNet: トークンをフーリエ変換と混合する」をPyTorchで実装したものです。
この論文では、セルフアテンション層を2つのフーリエ変換に置き換えてトークンを混合します。これは自己注意よりも効率的です。BERT on GLUE ベンチマークでは、自己注意よりもこれを使用した場合の精度の低下は約 92%
です。フーリエ変換を非表示次元 (埋め込み次元) に沿って適用し、次にシーケンス次元に沿って適用します。
ここで、は埋め込み入力で、フーリエ変換を表し、複素数の実数成分を表します。
これをPyTorchに実装するのはとても簡単です。たった1行のコードです。この論文では、事前に計算されたDFT行列を使用し、行列の乗算を行ってフーリエ変換を行うことを提案しています
。以下は、FNetベースのモデルを使用してAG Newsを分類するためのトレーニングコードです。
41from typing import Optional
42
43import torch
44from torch import nn
47class FNetMix(nn.Module):
通常のアテンションモジュールには、、、、マスクにさまざまなトークンを埋め込むことができます。
同じ関数シグネチャに従うので、直接置換できます。
FNetミキシングの場合、マスキングはできません。query
(key
とvalue
) の形はです[seq_len, batch_size, d_model]
。
60 def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: Optional[torch.Tensor] = None):
、、そしてトークンのミキシングではすべてが等しくなければなりません
72 assert query is key and key is value
トークンのミキシングはマスキングをサポートしていません。つまり、すべてのトークンに他のすべてのトークンの埋め込みが表示されます。
74 assert mask is None
x
わかりやすいように割り当てる
77 x = query
84 fft_hidden = torch.fft.fft(x, dim=2)
シーケンスの次元に沿ってフーリエ変換を適用します
87 fft_seq = torch.fft.fft(fft_hidden, dim=0)
本物のコンポーネントを手に入れよう
91 return torch.real(fft_seq)