重量標準化

これは、論文「バッチチャネル正規化と体重標準化によるマイクロバッチトレーニング」から引用した重み標準化をPyTorchで実装したものです。また、バッチチャネル正規化の注釈付き実装もあります。

バッチ正規化により、損失の状況がスムーズになり排除特異点が回避されます。除去特異点とは、ネットワークのノードが役に立たなくなることです(たとえば、常に 0 を返す ReLU など)

ただし、バッチサイズが小さすぎる場合、バッチ正規化はうまく機能しません。これは、デバイスのメモリ制限のために大規模なネットワークをトレーニングするときに発生します。この論文では、より良い代替手段として、バッチチャネル正規化による重み標準化を紹介しています

重量標準化:1.グラデーション 2 を正規化します。風景を滑らかにします (リップシッツ定数を減らします) 3.排除特異点を回避

リップシッツ定数は、関数の 2 点間の最大勾配です。つまり、はどこを満たす最小値のリプシッツ定数です

除去特異点は、出力の統計が入力と同様に保たれるため、回避されます。したがって、入力が正規分布している限り、出力は正常に近いままです。これにより、ノードの出力がアクティベーション関数のアクティブ範囲を常に超えることがなくなります (たとえば、ReLU の場合は常に負の入力になります

)。

証拠については論文を参照してください。

重み標準化を使用して CIFAR-10 データを分類する VGG ネットワークをトレーニングするためのトレーニングコードを次に示します。これは、重みが標準化された2Dコンボリューション層を使用します

Open In Colab

48import torch

重量標準化

どこ、

2Dコンボリューション層の場合、は出力チャネル数 () で、入力チャネル数にカーネルサイズ () を掛けたものです。

51def weight_standardization(weight: torch.Tensor, eps: float):

Get およびカーネルシェイプ

70    c_out, c_in, *kernel_shape = weight.shape

形状を次の形式に変更

72    weight = weight.view(c_out, -1)

計算

79    var, mean = torch.var_mean(weight, dim=1, keepdim=True)

ノーマライズ

82    weight = (weight - mean) / (torch.sqrt(var + eps))

元の形状に戻して戻る

84    return weight.view(c_out, c_in, *kernel_shape)