レイヤー正規化

これはレイヤー正規化の PyTorch 実装です。

バッチ正規化の制限事項

  • ランニング手段を維持する必要があります。
  • RNNにとっては扱いにくい。ステップごとに異なる正規化が必要ですか
  • ?
  • 小さなバッチサイズでは機能しません。大規模なNLPモデルは通常、小さなバッチサイズでトレーニングされます。
  • 分散型トレーニングでは、デバイス間の平均と分散を計算する必要があります。
  • レイヤー正規化

    レイヤー正規化は、より幅広い設定に適用できる、より単純な正規化方法です。層の正規化により、入力は特徴全体で平均がゼロで単位分散がなくなるように変換されます。バッチ正規化では、各要素のゼロ平均と単位分散が固定されることに注意してください。レイヤーの正規化は、すべての要素のバッチごとに正規化を行います

    レイヤー正規化は通常、NLP タスクに使用されます。

    ほとんどのトランスフォーマー実装で層の正規化を使用しています

    35from typing import Union, List
    36
    37import torch
    38from torch import nn, Size
    39
    40from labml_helpers.module import Module

    レイヤー正規化

    層の正規化は、入力を次のように正規化します。

    入力が埋め込みのバッチの場合、はバッチサイズ、はフィーチャの数です。

    入力が埋め込みシーケンスのバッチである場合、はバッチサイズ、はチャネル数、はシーケンスの長さです。

    入力がイメージ表現のバッチの場合、はバッチサイズ、はチャネル数、は高さ、は幅です。これはあまり使われていないシナリオです。

    43class LayerNorm(Module):
    • normalized_shape 要素の形状です (バッチは除く)。その場合、入力は次のようになります。
    • eps 数値の安定性のために使用されます
    • elementwise_affine 正規化された値をスケーリングしてシフトするかどうかです

    引数には PyTorch LayerNorm 実装と同じ名前を使用しようとしました。

    72    def __init__(self, normalized_shape: Union[int, List[int], Size], *,
    73                 eps: float = 1e-5,
    74                 elementwise_affine: bool = True):
    84        super().__init__()

    normalized_shape に変換 torch.Size

    87        if isinstance(normalized_shape, int):
    88            normalized_shape = torch.Size([normalized_shape])
    89        elif isinstance(normalized_shape, list):
    90            normalized_shape = torch.Size(normalized_shape)
    91        assert isinstance(normalized_shape, torch.Size)

    94        self.normalized_shape = normalized_shape
    95        self.eps = eps
    96        self.elementwise_affine = elementwise_affine

    ゲインとバイアスのパラメーターとパラメーターの作成

    98        if self.elementwise_affine:
    99            self.gain = nn.Parameter(torch.ones(normalized_shape))
    100            self.bias = nn.Parameter(torch.zeros(normalized_shape))

    x [*, S[0], S[1], ..., S[n]] 形状のテンソルです。* 次元はいくつでもかまいません。たとえば、NLP タスクでは、次のようになります

    [seq_len, batch_size, features]
    102    def forward(self, x: torch.Tensor):

    形状が合っているか確認するサニティチェック

    110        assert self.normalized_shape == x.shape[-len(self.normalized_shape):]

    平均と分散を計算する対象のディメンション

    113        dims = [-(i + 1) for i in range(len(self.normalized_shape))]

    すべての要素の平均、つまり各要素の平均を計算します

    117        mean = x.mean(dim=dims, keepdim=True)

    すべての要素の二乗平均、つまり各要素の平均を計算します

    120        mean_x2 = (x ** 2).mean(dim=dims, keepdim=True)

    全要素の差異

    122        var = mean_x2 - mean ** 2

    ノーマライズ

    125        x_norm = (x - mean) / torch.sqrt(var + self.eps)

    スケールとシフト

    127        if self.elementwise_affine:
    128            x_norm = self.gain * x_norm + self.bias

    131        return x_norm

    簡単なテスト

    134def _test():
    138    from labml.logger import inspect
    139
    140    x = torch.zeros([2, 3, 2, 4])
    141    inspect(x.shape)
    142    ln = LayerNorm(x.shape[2:])
    143
    144    x = ln(x)
    145    inspect(x.shape)
    146    inspect(ln.gain.shape)

    150if __name__ == '__main__':
    151    _test()