インスタンス正規化

これは PyTorch の「インスタンス正規化:高速スタイル化に欠けている要素」の実装です。

スタイル転送を改善するためにインスタンスの正規化が導入されました。これは、スタイル設定はコンテンツ画像のコントラストに依存すべきではないという観察に基づいています。「コントラスト正規化」とは

where は、画像インデックス、フィーチャチャネルおよび空間位置を含む画像のバッチです。

畳み込みネットワークでは「コントラスト正規化」を学習するのは難しいので、本稿ではそれを行うインスタンス正規化を紹介します。

インスタンスの正規化を使用する CIFAR 10 分類モデルを次に示します

29import torch
30from torch import nn
31
32from labml_helpers.module import Module

インスタンス正規化レイヤー

インスタンス正規化レイヤーは、次のように入力を正規化します。

入力がイメージ表現のバッチの場合、はバッチサイズ、はチャネル数、は高さ、は幅です。およびを使用したアフィン変換はオプションです

35class InstanceNorm(Module):
  • channels は入力内の特徴の数です
  • eps 数値の安定性のために使用されます
  • affine 正規化された値をスケーリングしてシフトするかどうかです
51    def __init__(self, channels: int, *,
52                 eps: float = 1e-5, affine: bool = True):
58        super().__init__()
59
60        self.channels = channels
61
62        self.eps = eps
63        self.affine = affine

スケールとシフトのパラメータとパラメータの作成

65        if self.affine:
66            self.scale = nn.Parameter(torch.ones(channels))
67            self.shift = nn.Parameter(torch.zeros(channels))

x [batch_size, channels, *] 形状のテンソルです。* 任意の数 (0 の場合もあります) の次元を示します。たとえば、画像 (2D) のコンボリューションでは、次のようになります

[batch_size, channels, height, width]
69    def forward(self, x: torch.Tensor):

元の形を保つ

77        x_shape = x.shape

バッチサイズを取得

79        batch_size = x_shape[0]

機能の数が同じであることを確認するためのサニティチェック

81        assert self.channels == x.shape[1]

形を変えて [batch_size, channels, n]

84        x = x.view(batch_size, self.channels, -1)

最後の次元の平均、つまり各特徴の平均を計算します

88        mean = x.mean(dim=[-1], keepdim=True)

最初と最後の次元の二乗平均、つまり各特徴の平均を計算します。

91        mean_x2 = (x ** 2).mean(dim=[-1], keepdim=True)

各機能の差異

93        var = mean_x2 - mean ** 2

ノーマライズ

96        x_norm = (x - mean) / torch.sqrt(var + self.eps)
97        x_norm = x_norm.view(batch_size, self.channels, -1)

スケールとシフト

100        if self.affine:
101            x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)

元の形に戻して戻す

104        return x_norm.view(x_shape)

簡単なテスト

107def _test():
111    from labml.logger import inspect
112
113    x = torch.zeros([2, 6, 2, 4])
114    inspect(x.shape)
115    bn = InstanceNorm(6)
116
117    x = bn(x)
118    inspect(x.shape)

122if __name__ == '__main__':
123    _test()