25import torch
26from torch import nn
27
28from labml_helpers.module import Module
29from labml_nn.normalization.batch_norm import BatchNorm

バッチチャネル正規化

これは最初にバッチ正規化を行います。通常のバッチノルムか、推定平均と分散 (複数のバッチにわたる指数関数的平均/分散) を含むバッチノルムのどちらかです。次に、チャネル正規化が実行されました

32class BatchChannelNorm(Module):
  • channels は入力内の特徴の数です
  • groups フィーチャが分割されているグループの数です
  • eps 数値の安定性のために使用されます
  • momentum 指数移動平均を取るときの勢いです
  • estimate バッチノルムに実行平均と分散を使用するかどうかです
42    def __init__(self, channels: int, groups: int,
43                 eps: float = 1e-5, momentum: float = 0.1, estimate: bool = True):
51        super().__init__()

推定バッチノルムまたは通常のバッチノルムを使用してください。

54        if estimate:
55            self.batch_norm = EstimatedBatchNorm(channels,
56                                                 eps=eps, momentum=momentum)
57        else:
58            self.batch_norm = BatchNorm(channels,
59                                        eps=eps, momentum=momentum)

チャンネル正規化

62        self.channel_norm = ChannelNorm(channels, groups, eps)
64    def forward(self, x):
65        x = self.batch_norm(x)
66        return self.channel_norm(x)

推定バッチ正規化

入力がイメージ表現のバッチの場合、はバッチサイズ、はチャネル数、は高さ、は幅です。

どこ、

は移動平均と分散です。指数平均を計算するための運動量です。

69class EstimatedBatchNorm(Module):
  • channels は入力内の特徴の数です
  • eps 数値の安定性のために使用されます
  • momentum 指数移動平均を取るときの勢いです
  • estimate バッチノルムに実行平均と分散を使用するかどうかです
90    def __init__(self, channels: int,
91                 eps: float = 1e-5, momentum: float = 0.1, affine: bool = True):
98        super().__init__()
99
100        self.eps = eps
101        self.momentum = momentum
102        self.affine = affine
103        self.channels = channels

チャンネル単位の変換パラメーター

106        if self.affine:
107            self.scale = nn.Parameter(torch.ones(channels))
108            self.shift = nn.Parameter(torch.zeros(channels))

およびのテンソル

111        self.register_buffer('exp_mean', torch.zeros(channels))
112        self.register_buffer('exp_var', torch.ones(channels))

x [batch_size, channels, *] 形状のテンソルです。* 任意の数 (0 の場合もあります) の次元を示します。たとえば、画像 (2D) のコンボリューションでは、次のようになります

[batch_size, channels, height, width]
114    def forward(self, x: torch.Tensor):

古い形を保つ

122        x_shape = x.shape

バッチサイズを取得

124        batch_size = x_shape[0]

機能の数が正しいことを確認するためのサニティチェック

127        assert self.channels == x.shape[1]

形を変えて [batch_size, channels, n]

130        x = x.view(batch_size, self.channels, -1)

アップデートおよびトレーニングモードのみ

133        if self.training:

およびを介したバックプロパゲーションなし

135            with torch.no_grad():

最初の次元と最後の次元の平均を計算します。

138                mean = x.mean(dim=[0, 2])

最初の次元と最後の次元の二乗平均を計算します。

141                mean_x2 = (x ** 2).mean(dim=[0, 2])

各機能の差異

144                var = mean_x2 - mean ** 2

指数移動平均の更新

152                self.exp_mean = (1 - self.momentum) * self.exp_mean + self.momentum * mean
153                self.exp_var = (1 - self.momentum) * self.exp_var + self.momentum * var

ノーマライズ

157        x_norm = (x - self.exp_mean.view(1, -1, 1)) / torch.sqrt(self.exp_var + self.eps).view(1, -1, 1)

スケールとシフト

162        if self.affine:
163            x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)

元の形に戻して戻す

166        return x_norm.view(x_shape)

チャンネル正規化

これはグループ正規化に似ていますが、アフィン変換はグループごとに行われます。

169class ChannelNorm(Module):
  • groups フィーチャが分割されているグループの数です
  • channels は入力内の特徴の数です
  • eps 数値の安定性のために使用されます
  • affine 正規化された値をスケーリングしてシフトするかどうかです
176    def __init__(self, channels, groups,
177                 eps: float = 1e-5, affine: bool = True):
184        super().__init__()
185        self.channels = channels
186        self.groups = groups
187        self.eps = eps
188        self.affine = affine

アフィン変換のパラメーター。

これらの変換は、チャネルごとに変換されるグループノルムとは異なり、グループごとに行われることに注意してください。

193        if self.affine:
194            self.scale = nn.Parameter(torch.ones(groups))
195            self.shift = nn.Parameter(torch.zeros(groups))

x [batch_size, channels, *] 形状のテンソルです。* 任意の数 (0 の場合もあります) の次元を示します。たとえば、画像 (2D) のコンボリューションでは、次のようになります

[batch_size, channels, height, width]
197    def forward(self, x: torch.Tensor):

元の形を保つ

206        x_shape = x.shape

バッチサイズを取得

208        batch_size = x_shape[0]

機能の数が同じであることを確認するためのサニティチェック

210        assert self.channels == x.shape[1]

形を変えて [batch_size, groups, n]

213        x = x.view(batch_size, self.groups, -1)

最後の次元の平均、つまり各サンプルとチャネルグループの平均を計算します

217        mean = x.mean(dim=[-1], keepdim=True)

最後の次元の二乗平均、つまり各サンプルとチャネルグループの平均を計算します

220        mean_x2 = (x ** 2).mean(dim=[-1], keepdim=True)

各サンプルと機能グループの差異

223        var = mean_x2 - mean ** 2

ノーマライズ

228        x_norm = (x - mean) / torch.sqrt(var + self.eps)

グループごとのスケーリングとシフト

232        if self.affine:
233            x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)

元の形に戻して戻す

236        return x_norm.view(x_shape)