これは PyTorch の「インスタンス正規化:高速スタイル化に欠けている要素」の実装です。
スタイル転送を改善するためにインスタンスの正規化が導入されました。これは、スタイル設定はコンテンツ画像のコントラストに依存すべきではないという観察に基づいています。「コントラスト正規化」とは
where は、画像インデックス、フィーチャチャネル、および空間位置を含む画像のバッチです。
畳み込みネットワークでは「コントラスト正規化」を学習するのは難しいので、本稿ではそれを行うインスタンス正規化を紹介します。
インスタンスの正規化を使用する CIFAR 10 分類モデルを次に示します。
29import torch
30from torch import nn
31
32from labml_helpers.module import Module
インスタンス正規化レイヤーは、次のように入力を正規化します。
入力がイメージ表現のバッチの場合、はバッチサイズ、はチャネル数、は高さ、は幅です。と。およびを使用したアフィン変換はオプションです
。
35class InstanceNorm(Module):
channels
は入力内の特徴の数ですeps
数値の安定性のために使用されます affine
正規化された値をスケーリングしてシフトするかどうかです51 def __init__(self, channels: int, *,
52 eps: float = 1e-5, affine: bool = True):
58 super().__init__()
59
60 self.channels = channels
61
62 self.eps = eps
63 self.affine = affine
スケールとシフトのパラメータとパラメータの作成
65 if self.affine:
66 self.scale = nn.Parameter(torch.ones(channels))
67 self.shift = nn.Parameter(torch.zeros(channels))
x
[batch_size, channels, *]
形状のテンソルです。*
任意の数 (0 の場合もあります) の次元を示します。たとえば、画像 (2D) のコンボリューションでは、次のようになります
[batch_size, channels, height, width]
69 def forward(self, x: torch.Tensor):
元の形を保つ
77 x_shape = x.shape
バッチサイズを取得
79 batch_size = x_shape[0]
機能の数が同じであることを確認するためのサニティチェック
81 assert self.channels == x.shape[1]
形を変えて [batch_size, channels, n]
84 x = x.view(batch_size, self.channels, -1)
最後の次元の平均、つまり各特徴の平均を計算します
88 mean = x.mean(dim=[-1], keepdim=True)
最初と最後の次元の二乗平均、つまり各特徴の平均を計算します。
91 mean_x2 = (x ** 2).mean(dim=[-1], keepdim=True)
各機能の差異
93 var = mean_x2 - mean ** 2
ノーマライズ
96 x_norm = (x - mean) / torch.sqrt(var + self.eps)
97 x_norm = x_norm.view(batch_size, self.channels, -1)
スケールとシフト
100 if self.affine:
101 x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)
元の形に戻して戻す
104 return x_norm.view(x_shape)
簡単なテスト
107def _test():
111 from labml.logger import inspect
112
113 x = torch.zeros([2, 6, 2, 4])
114 inspect(x.shape)
115 bn = InstanceNorm(6)
116
117 x = bn(x)
118 inspect(x.shape)
122if __name__ == '__main__':
123 _test()