これは、論文「バッチチャネル正規化と重み標準化によるマイクロバッチトレーニング」にあるバッチチャネル正規化のPyTorch実装です。また、重量標準化の注釈付き実装もあります。
バッチチャネル正規化は、バッチ正規化の後にチャネル正規化を行います (グループ正規化と同様)。バッチサイズが小さい場合は、バッチ正規化に実行平均と分散が使用されます
。重み標準化を使用して CIFAR-10 データを分類する VGG ネットワークをトレーニングするためのトレーニングコードを次に示します。
25import torch
26from torch import nn
27
28from labml_helpers.module import Module
29from labml_nn.normalization.batch_norm import BatchNorm
これは最初にバッチ正規化を行います。通常のバッチノルムか、推定平均と分散 (複数のバッチにわたる指数関数的平均/分散) を含むバッチノルムのどちらかです。次に、チャネル正規化が実行されました
。32class BatchChannelNorm(Module):
channels
は入力内の特徴の数ですgroups
フィーチャが分割されているグループの数ですeps
数値の安定性のために使用されます momentum
指数移動平均を取るときの勢いですestimate
バッチノルムに実行平均と分散を使用するかどうかです42 def __init__(self, channels: int, groups: int,
43 eps: float = 1e-5, momentum: float = 0.1, estimate: bool = True):
51 super().__init__()
推定バッチノルムまたは通常のバッチノルムを使用してください。
54 if estimate:
55 self.batch_norm = EstimatedBatchNorm(channels,
56 eps=eps, momentum=momentum)
57 else:
58 self.batch_norm = BatchNorm(channels,
59 eps=eps, momentum=momentum)
チャンネル正規化
62 self.channel_norm = ChannelNorm(channels, groups, eps)
64 def forward(self, x):
65 x = self.batch_norm(x)
66 return self.channel_norm(x)
69class EstimatedBatchNorm(Module):
channels
は入力内の特徴の数ですeps
数値の安定性のために使用されます momentum
指数移動平均を取るときの勢いですestimate
バッチノルムに実行平均と分散を使用するかどうかです90 def __init__(self, channels: int,
91 eps: float = 1e-5, momentum: float = 0.1, affine: bool = True):
98 super().__init__()
99
100 self.eps = eps
101 self.momentum = momentum
102 self.affine = affine
103 self.channels = channels
チャンネル単位の変換パラメーター
106 if self.affine:
107 self.scale = nn.Parameter(torch.ones(channels))
108 self.shift = nn.Parameter(torch.zeros(channels))
およびのテンソル
111 self.register_buffer('exp_mean', torch.zeros(channels))
112 self.register_buffer('exp_var', torch.ones(channels))
x
[batch_size, channels, *]
形状のテンソルです。*
任意の数 (0 の場合もあります) の次元を示します。たとえば、画像 (2D) のコンボリューションでは、次のようになります
[batch_size, channels, height, width]
114 def forward(self, x: torch.Tensor):
古い形を保つ
122 x_shape = x.shape
バッチサイズを取得
124 batch_size = x_shape[0]
機能の数が正しいことを確認するためのサニティチェック
127 assert self.channels == x.shape[1]
形を変えて [batch_size, channels, n]
130 x = x.view(batch_size, self.channels, -1)
アップデートおよびトレーニングモードのみ
133 if self.training:
およびを介したバックプロパゲーションなし
135 with torch.no_grad():
最初の次元と最後の次元の平均を計算します。
138 mean = x.mean(dim=[0, 2])
最初の次元と最後の次元の二乗平均を計算します。
141 mean_x2 = (x ** 2).mean(dim=[0, 2])
各機能の差異
144 var = mean_x2 - mean ** 2
152 self.exp_mean = (1 - self.momentum) * self.exp_mean + self.momentum * mean
153 self.exp_var = (1 - self.momentum) * self.exp_var + self.momentum * var
ノーマライズ
157 x_norm = (x - self.exp_mean.view(1, -1, 1)) / torch.sqrt(self.exp_var + self.eps).view(1, -1, 1)
スケールとシフト
162 if self.affine:
163 x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)
元の形に戻して戻す
166 return x_norm.view(x_shape)
169class ChannelNorm(Module):
groups
フィーチャが分割されているグループの数ですchannels
は入力内の特徴の数ですeps
数値の安定性のために使用されます affine
正規化された値をスケーリングしてシフトするかどうかです176 def __init__(self, channels, groups,
177 eps: float = 1e-5, affine: bool = True):
184 super().__init__()
185 self.channels = channels
186 self.groups = groups
187 self.eps = eps
188 self.affine = affine
193 if self.affine:
194 self.scale = nn.Parameter(torch.ones(groups))
195 self.shift = nn.Parameter(torch.zeros(groups))
x
[batch_size, channels, *]
形状のテンソルです。*
任意の数 (0 の場合もあります) の次元を示します。たとえば、画像 (2D) のコンボリューションでは、次のようになります
[batch_size, channels, height, width]
197 def forward(self, x: torch.Tensor):
元の形を保つ
206 x_shape = x.shape
バッチサイズを取得
208 batch_size = x_shape[0]
機能の数が同じであることを確認するためのサニティチェック
210 assert self.channels == x.shape[1]
形を変えて [batch_size, groups, n]
213 x = x.view(batch_size, self.groups, -1)
最後の次元の平均、つまり各サンプルとチャネルグループの平均を計算します
217 mean = x.mean(dim=[-1], keepdim=True)
最後の次元の二乗平均、つまり各サンプルとチャネルグループの平均を計算します
220 mean_x2 = (x ** 2).mean(dim=[-1], keepdim=True)
各サンプルと機能グループの差異
223 var = mean_x2 - mean ** 2
ノーマライズ
228 x_norm = (x - mean) / torch.sqrt(var + self.eps)
グループごとのスケーリングとシフト
232 if self.affine:
233 x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)
元の形に戻して戻す
236 return x_norm.view(x_shape)