バッチ正規化はバッチ全体で正規化されるため、バッチサイズが十分大きい場合はうまく機能しますが、小さなバッチサイズには適していません。デバイスのメモリ容量により、バッチサイズの大きい大規模モデルのトレーニングは不可能です。
本稿では、一連の特徴をグループとしてまとめて正規化するグループ正規化について紹介します。これは、SIFTやHOGなどの古典的特徴はグループごとの特徴であるという観察に基づいています。この論文では、フィーチャチャネルをグループに分割し、各グループ内のすべてのチャネルを個別に正規化することを提案しています
。すべての正規化層は、次の計算で定義できます。
ここで、はバッチを表すテンソルで、は単一値のインデックスです。たとえば、 2D画像がバッチ、フィーチャチャンネル、垂直座標、水平座標内の画像をインデックスするための4次元ベクトルである場合。平均と標準偏差です。
は、インデックスの平均と標準偏差を計算するインデックスのセットです。セットのサイズはすべて同じです。
の定義は、バッチ正規化、レイヤー正規化、インスタンス正規化では異なります。
同じフィーチャチャネルを共有する値はまとめて正規化されます。
バッチ内の同じサンプルの値はまとめて正規化されます。
同じサンプルと同じ特徴チャンネルの値が一緒に正規化されます。
ここで、はグループ数、はチャネル数です。
グループ正規化は、同じサンプルと同じチャネルグループの値をまとめて正規化します。
インスタンスの正規化を使用する CIFAR 10 分類モデルを次に示します。
84import torch
85from torch import nn
86
87from labml_helpers.module import Module
90class GroupNorm(Module):
groups
フィーチャが分割されているグループの数ですchannels
は入力内の特徴の数ですeps
数値の安定性のために使用されます affine
正規化された値をスケーリングしてシフトするかどうかです95 def __init__(self, groups: int, channels: int, *,
96 eps: float = 1e-5, affine: bool = True):
103 super().__init__()
104
105 assert channels % groups == 0, "Number of channels should be evenly divisible by the number of groups"
106 self.groups = groups
107 self.channels = channels
108
109 self.eps = eps
110 self.affine = affine
スケールとシフトのパラメータとパラメータの作成
112 if self.affine:
113 self.scale = nn.Parameter(torch.ones(channels))
114 self.shift = nn.Parameter(torch.zeros(channels))
x
[batch_size, channels, *]
形状のテンソルです。*
任意の数 (0 の場合もあります) の次元を示します。たとえば、画像 (2D) のコンボリューションでは、次のようになります
[batch_size, channels, height, width]
116 def forward(self, x: torch.Tensor):
元の形を保つ
124 x_shape = x.shape
バッチサイズを取得
126 batch_size = x_shape[0]
機能の数が同じであることを確認するためのサニティチェック
128 assert self.channels == x.shape[1]
形を変えて [batch_size, groups, n]
131 x = x.view(batch_size, self.groups, -1)
最後の次元の平均、つまり各サンプルとチャネルグループの平均を計算します
135 mean = x.mean(dim=[-1], keepdim=True)
最後の次元の二乗平均、つまり各サンプルとチャネルグループの平均を計算します
138 mean_x2 = (x ** 2).mean(dim=[-1], keepdim=True)
各サンプルと機能グループの差異
141 var = mean_x2 - mean ** 2
ノーマライズ
146 x_norm = (x - mean) / torch.sqrt(var + self.eps)
チャンネル単位のスケーリングとシフト
150 if self.affine:
151 x_norm = x_norm.view(batch_size, self.channels, -1)
152 x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)
元の形に戻して戻す
155 return x_norm.view(x_shape)
簡単なテスト
158def _test():
162 from labml.logger import inspect
163
164 x = torch.zeros([2, 6, 2, 4])
165 inspect(x.shape)
166 bn = GroupNorm(2, 6)
167
168 x = bn(x)
169 inspect(x.shape)
173if __name__ == '__main__':
174 _test()