これは、論文「バッチチャネル正規化と体重標準化によるマイクロバッチトレーニング」から引用した重み標準化をPyTorchで実装したものです。また、バッチチャネル正規化の注釈付き実装もあります。
バッチ正規化により、損失の状況がスムーズになり、排除特異点が回避されます。除去特異点とは、ネットワークのノードが役に立たなくなることです(たとえば、常に 0 を返す ReLU など)
。ただし、バッチサイズが小さすぎる場合、バッチ正規化はうまく機能しません。これは、デバイスのメモリ制限のために大規模なネットワークをトレーニングするときに発生します。この論文では、より良い代替手段として、バッチチャネル正規化による重み標準化を紹介しています
。重量標準化:1.グラデーション 2 を正規化します。風景を滑らかにします (リップシッツ定数を減らします) 3.排除特異点を回避
リップシッツ定数は、関数の 2 点間の最大勾配です。つまり、はどこを満たす最小値のリプシッツ定数です
。除去特異点は、出力の統計が入力と同様に保たれるため、回避されます。したがって、入力が正規分布している限り、出力は正常に近いままです。これにより、ノードの出力がアクティベーション関数のアクティブ範囲を常に超えることがなくなります (たとえば、ReLU の場合は常に負の入力になります
)。重み標準化を使用して CIFAR-10 データを分類する VGG ネットワークをトレーニングするためのトレーニングコードを次に示します。これは、重みが標準化された2Dコンボリューション層を使用します
。48import torch
51def weight_standardization(weight: torch.Tensor, eps: float):
Get 、およびカーネルシェイプ
70 c_out, c_in, *kernel_shape = weight.shape
形状を次の形式に変更
72 weight = weight.view(c_out, -1)
79 var, mean = torch.var_mean(weight, dim=1, keepdim=True)
ノーマライズ
82 weight = (weight - mean) / (torch.sqrt(var + eps))
元の形状に戻して戻る
84 return weight.view(c_out, c_in, *kernel_shape)