これはレイヤー正規化の PyTorch 実装です。
レイヤー正規化は、より幅広い設定に適用できる、より単純な正規化方法です。層の正規化により、入力は特徴全体で平均がゼロで単位分散がなくなるように変換されます。バッチ正規化では、各要素のゼロ平均と単位分散が固定されることに注意してください。レイヤーの正規化は、すべての要素のバッチごとに正規化を行います
。レイヤー正規化は通常、NLP タスクに使用されます。
35from typing import Union, List
36
37import torch
38from torch import nn, Size
39
40from labml_helpers.module import Module
層の正規化は、入力を次のように正規化します。
入力が埋め込みのバッチの場合、はバッチサイズ、はフィーチャの数です。と。
入力が埋め込みシーケンスのバッチである場合、はバッチサイズ、はチャネル数、はシーケンスの長さです。と。
入力がイメージ表現のバッチの場合、はバッチサイズ、はチャネル数、は高さ、は幅です。これはあまり使われていないシナリオです。と。
43class LayerNorm(Module):
normalized_shape
要素の形状です (バッチは除く)。その場合、入力は次のようになります。eps
数値の安定性のために使用されます elementwise_affine
正規化された値をスケーリングしてシフトするかどうかです引数には PyTorch LayerNorm
実装と同じ名前を使用しようとしました。
72 def __init__(self, normalized_shape: Union[int, List[int], Size], *,
73 eps: float = 1e-5,
74 elementwise_affine: bool = True):
84 super().__init__()
normalized_shape
に変換 torch.Size
87 if isinstance(normalized_shape, int):
88 normalized_shape = torch.Size([normalized_shape])
89 elif isinstance(normalized_shape, list):
90 normalized_shape = torch.Size(normalized_shape)
91 assert isinstance(normalized_shape, torch.Size)
94 self.normalized_shape = normalized_shape
95 self.eps = eps
96 self.elementwise_affine = elementwise_affine
ゲインとバイアスのパラメーターとパラメーターの作成
98 if self.elementwise_affine:
99 self.gain = nn.Parameter(torch.ones(normalized_shape))
100 self.bias = nn.Parameter(torch.zeros(normalized_shape))
x
[*, S[0], S[1], ..., S[n]]
形状のテンソルです。*
次元はいくつでもかまいません。たとえば、NLP タスクでは、次のようになります
[seq_len, batch_size, features]
102 def forward(self, x: torch.Tensor):
形状が合っているか確認するサニティチェック
110 assert self.normalized_shape == x.shape[-len(self.normalized_shape):]
平均と分散を計算する対象のディメンション
113 dims = [-(i + 1) for i in range(len(self.normalized_shape))]
すべての要素の平均、つまり各要素の平均を計算します
117 mean = x.mean(dim=dims, keepdim=True)
すべての要素の二乗平均、つまり各要素の平均を計算します
120 mean_x2 = (x ** 2).mean(dim=dims, keepdim=True)
全要素の差異
122 var = mean_x2 - mean ** 2
ノーマライズ
125 x_norm = (x - mean) / torch.sqrt(var + self.eps)
スケールとシフト
127 if self.elementwise_affine:
128 x_norm = self.gain * x_norm + self.bias
131 return x_norm
簡単なテスト
134def _test():
138 from labml.logger import inspect
139
140 x = torch.zeros([2, 3, 2, 4])
141 inspect(x.shape)
142 ln = LayerNorm(x.shape[2:])
143
144 x = ln(x)
145 inspect(x.shape)
146 inspect(ln.gain.shape)
150if __name__ == '__main__':
151 _test()