グループ正規化

これはグループ正規化論文のPyTorch実装です

バッチ正規化はバッチ全体で正規化されるため、バッチサイズが十分大きい場合はうまく機能しますが、小さなバッチサイズには適していません。デバイスのメモリ容量により、バッチサイズの大きい大規模モデルのトレーニングは不可能です。

本稿では、一連の特徴をグループとしてまとめて正規化するグループ正規化について紹介します。これは、SIFTやHOGなどの古典的特徴はグループごとの特徴であるという観察に基づいています。この論文では、フィーチャチャネルをグループに分割し、各グループ内のすべてのチャネルを個別に正規化することを提案しています

フォーミュレーション

すべての正規化層は、次の計算で定義できます。

ここで、はバッチを表すテンソルで、は単一値のインデックスです。たとえば、 2D画像がバッチ、フィーチャチャンネル、垂直座標、水平座標内の画像をインデックスするための4次元ベクトルである場合。平均と標準偏差です。

は、インデックスの平均と標準偏差を計算するインデックスのセットです。セットのサイズはすべて同じです

の定義は、バッチ正規化、レイヤー正規化インスタンス正規化では異なります

バッチ正規化

同じフィーチャチャネルを共有する値はまとめて正規化されます。

レイヤー正規化

バッチ内の同じサンプルの値はまとめて正規化されます。

インスタンス正規化

同じサンプルと同じ特徴チャンネルの値が一緒に正規化されます。

グループ正規化

ここで、はグループ数、はチャネル数です。

グループ正規化は、同じサンプルと同じチャネルグループの値をまとめて正規化します。

インスタンスの正規化を使用する CIFAR 10 分類モデルを次に示します

Open In Colab

84import torch
85from torch import nn
86
87from labml_helpers.module import Module

グループ正規化レイヤー

90class GroupNorm(Module):
  • groups フィーチャが分割されているグループの数です
  • channels は入力内の特徴の数です
  • eps 数値の安定性のために使用されます
  • affine 正規化された値をスケーリングしてシフトするかどうかです
95    def __init__(self, groups: int, channels: int, *,
96                 eps: float = 1e-5, affine: bool = True):
103        super().__init__()
104
105        assert channels % groups == 0, "Number of channels should be evenly divisible by the number of groups"
106        self.groups = groups
107        self.channels = channels
108
109        self.eps = eps
110        self.affine = affine

スケールとシフトのパラメータとパラメータの作成

112        if self.affine:
113            self.scale = nn.Parameter(torch.ones(channels))
114            self.shift = nn.Parameter(torch.zeros(channels))

x [batch_size, channels, *] 形状のテンソルです。* 任意の数 (0 の場合もあります) の次元を示します。たとえば、画像 (2D) のコンボリューションでは、次のようになります

[batch_size, channels, height, width]
116    def forward(self, x: torch.Tensor):

元の形を保つ

124        x_shape = x.shape

バッチサイズを取得

126        batch_size = x_shape[0]

機能の数が同じであることを確認するためのサニティチェック

128        assert self.channels == x.shape[1]

形を変えて [batch_size, groups, n]

131        x = x.view(batch_size, self.groups, -1)

最後の次元の平均、つまり各サンプルとチャネルグループの平均を計算します

135        mean = x.mean(dim=[-1], keepdim=True)

最後の次元の二乗平均、つまり各サンプルとチャネルグループの平均を計算します

138        mean_x2 = (x ** 2).mean(dim=[-1], keepdim=True)

各サンプルと機能グループの差異

141        var = mean_x2 - mean ** 2

ノーマライズ

146        x_norm = (x - mean) / torch.sqrt(var + self.eps)

チャンネル単位のスケーリングとシフト

150        if self.affine:
151            x_norm = x_norm.view(batch_size, self.channels, -1)
152            x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)

元の形に戻して戻す

155        return x_norm.view(x_shape)

簡単なテスト

158def _test():
162    from labml.logger import inspect
163
164    x = torch.zeros([2, 6, 2, 4])
165    inspect(x.shape)
166    bn = GroupNorm(2, 6)
167
168    x = bn(x)
169    inspect(x.shape)

173if __name__ == '__main__':
174    _test()