ファジータイリングアクティベーション (FTA)

Open In Colab

これは、ファジータイリングアクティベーション:スパース表現をオンラインで学習するための簡単なアプローチの PyTorch 実装/チュートリアルです。

ファジータイリングアクティベーションは、ビニングに基づくスパースアクティベーションの一種です。

ビニングとは、間隔に基づいてスカラー値をビンに分類することです。ビニングの問題の 1 つは、ほとんどの値 (ビンの境界を除く) でグラデーションがゼロになることです。もう1つは、ビンの間隔が大きいとビニングの精度が低下することです

FTAはこれらの欠点を克服します。FTAはタイリングアクティベーションのようなハードバウンダリーの代わりに、ビンの間にソフトバウンダリーを使います。これにより、すべてまたは広範囲の値に対してゼロ以外のグラデーションが得られます。また、部分的な値でキャプチャされるため、精度が失われることはありません。

タイリングアクティベーション

はタイリングベクトル、

ここで、は入力範囲、はビンのサイズ、で割り切れます。

タイリングアクティベーションは、

ここで、入力が正のかどうかを示すインジケーター関数と、そうでないかどうかを示すインジケーター関数があります。

タイリングを有効にすると、境界が固いため、グラデーションはゼロになることに注意してください。

ファジータイリングアクティベーション

ファジーインジケーター機能、

これはからいつまで直線的に増加し、 for と等しくなります。ハイパーパラメータです

FTA はこれを使ってビンの間にソフトな境界線を作ります。

これは、変圧器でFTAを使用する簡単な実験です

61import torch
62from torch import nn

ファジータイリングアクティベーション (FTA)

65class FTA(nn.Module):
  • lower_limit 下限です
  • upper_limit 上限です
  • delta はビンのサイズです
  • eta 境界の柔らかさを決定するパラメータです。
  • 70    def __init__(self, lower_limit: float, upper_limit: float, delta: float, eta: float):
    77        super().__init__()

    タイリングベクトルを初期化

    80        self.c = nn.Parameter(torch.arange(lower_limit, upper_limit, delta), requires_grad=False)

    入力ベクトルは、ビンの数と同じ係数だけ拡大されます。

    82        self.expansion_factor = len(self.c)

    84        self.delta = delta

    86        self.eta = eta

    ファジーインジケーター機能

    88    def fuzzy_i_plus(self, x: torch.Tensor):
    94        return (x <= self.eta) * x + (x > self.eta)
    96    def forward(self, z: torch.Tensor):

    サイズをもう1つ追加してください。これをビンに拡張します

    99        z = z.view(*z.shape, 1)

    102        z = 1. - self.fuzzy_i_plus(torch.clip(self.c - z, min=0.) + torch.clip(z - self.delta - self.c, min=0.))

    元の寸法数に戻します。最後のディメンションサイズはビンの数だけ拡張されます

    106        return z.view(*z.shape[:-2], -1)

    FTA モジュールをテストするコード

    109def _test():
    113    from labml.logger import inspect

    [初期化]

    116    a = FTA(-10, 10, 2., 0.5)

    プリント

    118    inspect(a.c)

    ビンの数を印刷

    120    inspect(a.expansion_factor)

    [入力]

    123    z = torch.tensor([1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9., 10., 11.])

    プリント

    125    inspect(z)

    プリント

    127    inspect(a(z))
    128
    129
    130if __name__ == '__main__':
    131    _test()