これは、「バッチ正規化:内部共変量シフトを減らすことによるディープネットワークトレーニングの高速化」という論文からバッチ正規化をPyTorchで実装したものです。
この論文では、内部共変量シフトを、トレーニング中のネットワークパラメーターの変化によるネットワークアクティベーションの分布の変化として定義しています。たとえば、との 2 つのレイヤーがあるとします。トレーニングの開始時に、アウトプット(へのインプット)が配布される可能性があります。その後、いくつかのトレーニング手順を実行すると、に移動する可能性があります。これは内部共変量シフトです
。内部共変量シフトは、後の層(上の例)がこのシフトした分布に適応しなければならないため、トレーニング速度に悪影響を及ぼします。
分布を安定させることにより、バッチ正規化は内部共変量シフトを最小限に抑えます。
ホワイトニングはトレーニングのスピードとコンバージェンスを向上させることが知られています。ホワイトニングとは、入力を平均がゼロ、単位分散、無相関になるように線形に変換することです
。事前に計算された(分離された)平均と分散を使用して勾配計算の外で正規化することはできません。例えば。(分散は無視)、ここで、 and はトレーニング済みのバイアスで、外部勾配計算 (事前に計算された定数) です
。には影響しないことに注意してください。したがって、トレーニングを更新するたびに増加または減少し、無期限に成長し続けます。この論文は、同様の爆発にはばらつきがあると述べています
。ホワイトニングは、相関をなくす必要があり、勾配がホワイトニングの計算全体を通る必要があるため、計算量が多くなります。
この論文では、バッチ正規化と呼ばれる簡略版を紹介しています。1 つ目の簡略化は、各特徴量を独立して平均が 0、単位分散になるように正規化することです。ここで、は -次元の入力です
。2 つ目の簡略化は、データセット全体の平均と分散を計算するのではなく、ミニバッチからの平均と分散の推定値を正規化に使用することです。
各特徴量を平均ゼロと単位分散に正規化すると、レイヤーが表現できる内容に影響する可能性があります。例示しているように、シグモイドへの入力が正規化されると、そのほとんどはシグモイドが線形である範囲内になります。これを解決するために、各機能のスケーリングとシフトを学習済みの 2 つのパラメーターとで調整します。ここで、はバッチ正規化層の出力です
。線形変換のような線形変換の後にバッチ正規化を適用すると、正規化によりバイアスパラメータがキャンセルされることに注意してください。そのため、バッチ正規化の直前に線形変換のバイアスパラメータを省略することができ、また省略すべきです
。また、バッチ正規化では逆伝播が重みのスケールに対して不変になり、経験的にジェネラライズが改善されるため、正則化効果もあります。
正規化を実行するには、とを知る必要があります。そのため、推論時には、データセットの全体 (または一部) を調べて平均と分散を求めるか、トレーニング中に計算された推定値を使用する必要があります。通常は、トレーニング段階で平均と分散の指数移動平均を計算し、それを推論に使用します
。以下は、MNIST データセットのバッチ正規化を使用する CNN 分類器をトレーニングするためのトレーニングコードとノートブックです。
97import torch
98from torch import nn
99
100from labml_helpers.module import Module
バッチ正規化層は、次のように入力を正規化します。
入力がイメージ表現のバッチの場合、はバッチサイズ、はチャネル数、は高さ、は幅です。と。
入力が埋め込みのバッチの場合、はバッチサイズ、はフィーチャの数です。と。
入力がシーケンス埋め込みのバッチの場合、はバッチサイズ、はフィーチャ数、はシーケンスの長さです。と。
103class BatchNorm(Module):
channels
は入力内の特徴の数ですeps
数値の安定性のために使用されます momentum
指数移動平均を取るときの勢いですaffine
正規化された値をスケーリングしてシフトするかどうかですtrack_running_stats
移動平均を計算するか、平均と分散を計算するかです引数には PyTorch BatchNorm
実装と同じ名前を使用しようとしました。
131 def __init__(self, channels: int, *,
132 eps: float = 1e-5, momentum: float = 0.1,
133 affine: bool = True, track_running_stats: bool = True):
143 super().__init__()
144
145 self.channels = channels
146
147 self.eps = eps
148 self.momentum = momentum
149 self.affine = affine
150 self.track_running_stats = track_running_stats
スケールとシフトのパラメータとパラメータの作成
152 if self.affine:
153 self.scale = nn.Parameter(torch.ones(channels))
154 self.shift = nn.Parameter(torch.zeros(channels))
平均と分散の指数移動平均を格納するバッファーの作成
157 if self.track_running_stats:
158 self.register_buffer('exp_mean', torch.zeros(channels))
159 self.register_buffer('exp_var', torch.ones(channels))
x
[batch_size, channels, *]
形状のテンソルです。*
任意の数 (0 の場合もあります) の次元を示します。たとえば、画像 (2D) のコンボリューションでは、次のようになります
[batch_size, channels, height, width]
161 def forward(self, x: torch.Tensor):
元の形を保つ
169 x_shape = x.shape
バッチサイズを取得
171 batch_size = x_shape[0]
機能の数が同じであることを確認するためのサニティチェック
173 assert self.channels == x.shape[1]
形を変えて [batch_size, channels, n]
176 x = x.view(batch_size, self.channels, -1)
トレーニングモードの場合、または指数移動平均を追跡していない場合は、ミニバッチの平均と分散を計算します。
180 if self.training or not self.track_running_stats:
最初のディメンションと最後のディメンションの平均、つまり各フィーチャの平均を計算します
183 mean = x.mean(dim=[0, 2])
最初と最後の次元の二乗平均、つまり各特徴の平均を計算します。
186 mean_x2 = (x ** 2).mean(dim=[0, 2])
各機能の差異
188 var = mean_x2 - mean ** 2
指数移動平均の更新
191 if self.training and self.track_running_stats:
192 self.exp_mean = (1 - self.momentum) * self.exp_mean + self.momentum * mean
193 self.exp_var = (1 - self.momentum) * self.exp_var + self.momentum * var
指数移動平均を推定値として使用
195 else:
196 mean = self.exp_mean
197 var = self.exp_var
ノーマライズ
200 x_norm = (x - mean.view(1, -1, 1)) / torch.sqrt(var + self.eps).view(1, -1, 1)
スケールとシフト
202 if self.affine:
203 x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)
元の形に戻して戻す
206 return x_norm.view(x_shape)
簡単なテスト
209def _test():
213 from labml.logger import inspect
214
215 x = torch.zeros([2, 3, 2, 4])
216 inspect(x.shape)
217 bn = BatchNorm(3)
218
219 x = bn(x)
220 inspect(x.shape)
221 inspect(bn.exp_var.shape)
225if __name__ == '__main__':
226 _test()