バッチ正規化

これは、「バッチ正規化:内部共変量シフトを減らすことによるディープネットワークトレーニングの高速化」という論文からバッチ正規化をPyTorchで実装したものです

内部共変量シフト

この論文では、内部共変量シフトを、トレーニング中のネットワークパラメーターの変化によるネットワークアクティベーションの分布の変化として定義しています。たとえば、との 2 つのレイヤーがあるとします。トレーニングの開始時に、アウトプット(へのインプット)が配布される可能性があります。その後、いくつかのトレーニング手順を実行すると、に移動する可能性がありますこれは内部共変量シフトです

内部共変量シフトは、後の層(上の例)がこのシフトした分布に適応しなければならないため、トレーニング速度に悪影響を及ぼします。

分布を安定させることにより、バッチ正規化は内部共変量シフトを最小限に抑えます。

ノーマライゼーション

ホワイトニングはトレーニングのスピードとコンバージェンスを向上させることが知られています。ホワイトニングとは、入力を平均がゼロ、単位分散、無相関になるように線形に変換することです

外部勾配計算の正規化は機能しません

事前に計算された(分離された)平均と分散を使用して勾配計算の外で正規化することはできません。例えば。(分散は無視)、ここで、 and はトレーニング済みのバイアスで、外部勾配計算 (事前に計算された定数) です

には影響しないことに注意してください。したがって、トレーニングを更新するたびに増加または減少し、無期限に成長し続けます。この論文は、同様の爆発にはばらつきがあると述べています

バッチ正規化

ホワイトニングは、相関をなくす必要があり、勾配がホワイトニングの計算全体を通る必要があるため、計算量が多くなります。

この論文では、バッチ正規化と呼ばれる簡略版を紹介しています。1 つ目の簡略化は、各特徴量を独立して平均が 0、単位分散になるように正規化することです。ここで、は -次元の入力です

2 つ目の簡略化は、データセット全体の平均と分散を計算するのではなく、ミニバッチからの平均と分散の推定値を正規化に使用することです。

各特徴量を平均ゼロと単位分散に正規化すると、レイヤーが表現できる内容に影響する可能性があります。例示しているように、シグモイドへの入力が正規化されると、そのほとんどはシグモイドが線形である範囲内になります。これを解決するために、各機能のスケーリングとシフトを学習済みの 2 つのパラメーターとで調整します。ここで、はバッチ正規化層の出力です

線形変換のような線形変換の後にバッチ正規化を適用すると、正規化によりバイアスパラメータがキャンセルされることに注意してください。そのため、バッチ正規化の直前に線形変換のバイアスパラメータを省略することができ、また省略すべきです

また、バッチ正規化では逆伝播が重みのスケールに対して不変になり、経験的にジェネラライズが改善されるため、正則化効果もあります。

推論

正規化を実行するには、とを知る必要があります。そのため、推論時には、データセットの全体 (または一部) を調べて平均と分散を求めるか、トレーニング中に計算された推定値を使用する必要があります。通常は、トレーニング段階で平均と分散の指数移動平均を計算し、それを推論に使用します

以下は、MNIST データセットのバッチ正規化を使用する CNN 分類器をトレーニングするためのトレーニングコードとノートブックです

Open In Colab

97import torch
98from torch import nn
99
100from labml_helpers.module import Module

バッチ正規化レイヤー

バッチ正規化層は、次のように入力を正規化します。

入力がイメージ表現のバッチの場合、はバッチサイズ、はチャネル数、は高さ、は幅です。

入力が埋め込みのバッチの場合、はバッチサイズ、はフィーチャの数です。

入力がシーケンス埋め込みのバッチの場合、はバッチサイズ、はフィーチャ数、はシーケンスの長さです。

103class BatchNorm(Module):
  • channels は入力内の特徴の数です
  • eps 数値の安定性のために使用されます
  • momentum 指数移動平均を取るときの勢いです
  • affine 正規化された値をスケーリングしてシフトするかどうかです
  • track_running_stats 移動平均を計算するか、平均と分散を計算するかです

引数には PyTorch BatchNorm 実装と同じ名前を使用しようとしました。

131    def __init__(self, channels: int, *,
132                 eps: float = 1e-5, momentum: float = 0.1,
133                 affine: bool = True, track_running_stats: bool = True):
143        super().__init__()
144
145        self.channels = channels
146
147        self.eps = eps
148        self.momentum = momentum
149        self.affine = affine
150        self.track_running_stats = track_running_stats

スケールとシフトのパラメータとパラメータの作成

152        if self.affine:
153            self.scale = nn.Parameter(torch.ones(channels))
154            self.shift = nn.Parameter(torch.zeros(channels))

平均と分散の指数移動平均を格納するバッファーの作成

157        if self.track_running_stats:
158            self.register_buffer('exp_mean', torch.zeros(channels))
159            self.register_buffer('exp_var', torch.ones(channels))

x [batch_size, channels, *] 形状のテンソルです。* 任意の数 (0 の場合もあります) の次元を示します。たとえば、画像 (2D) のコンボリューションでは、次のようになります

[batch_size, channels, height, width]
161    def forward(self, x: torch.Tensor):

元の形を保つ

169        x_shape = x.shape

バッチサイズを取得

171        batch_size = x_shape[0]

機能の数が同じであることを確認するためのサニティチェック

173        assert self.channels == x.shape[1]

形を変えて [batch_size, channels, n]

176        x = x.view(batch_size, self.channels, -1)

トレーニングモードの場合、または指数移動平均を追跡していない場合は、ミニバッチの平均と分散を計算します。

180        if self.training or not self.track_running_stats:

最初のディメンションと最後のディメンションの平均、つまり各フィーチャの平均を計算します

183            mean = x.mean(dim=[0, 2])

最初と最後の次元の二乗平均、つまり各特徴の平均を計算します。

186            mean_x2 = (x ** 2).mean(dim=[0, 2])

各機能の差異

188            var = mean_x2 - mean ** 2

指数移動平均の更新

191            if self.training and self.track_running_stats:
192                self.exp_mean = (1 - self.momentum) * self.exp_mean + self.momentum * mean
193                self.exp_var = (1 - self.momentum) * self.exp_var + self.momentum * var

指数移動平均を推定値として使用

195        else:
196            mean = self.exp_mean
197            var = self.exp_var

ノーマライズ

200        x_norm = (x - mean.view(1, -1, 1)) / torch.sqrt(var + self.eps).view(1, -1, 1)

スケールとシフト

202        if self.affine:
203            x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)

元の形に戻して戻す

206        return x_norm.view(x_shape)

簡単なテスト

209def _test():
213    from labml.logger import inspect
214
215    x = torch.zeros([2, 3, 2, 4])
216    inspect(x.shape)
217    bn = BatchNorm(3)
218
219    x = bn(x)
220    inspect(x.shape)
221    inspect(bn.exp_var.shape)

225if __name__ == '__main__':
226    _test()