これは、ヴァッサースタインGANの改良型トレーニングの実装です。
WGANは、ディスクリミネーター・ネットワークにリップシッツ制約を適用するためにウェイトをクリッピングすることを提案している(評論家)。これに加えて、L2 ノルムクリッピング、ウェイト正規化、L1、L2 ウェイト減衰などの他のウェイト制約には問題があります
。1。ディスクリミネーターの容量制限 2.グラデーションが爆発したり消えたりする (バッチ正規化なし)
論文「Wasserstein GANのトレーニングの改善」は、勾配ペナルティであるリップシッツ制約を改善するより良い方法を提案しています。
ペナルティウェイトはどこで
つまり、勾配のノルムを近くに保つようにしています。
この実装では設定しました。
46import torch
47import torch.autograd
48
49from labml_helpers.module import Module
52class GradientPenalty(Module):
57 def forward(self, x: torch.Tensor, f: torch.Tensor):
バッチサイズを取得
67 batch_size = x.shape[0]
を基準とした勾配を計算します。grad_outputs
の勾配を求めているので1に設定し、この損失の重みに対する勾配を計算する必要があるため、グラフを作成して保持する必要があります
73 gradients, *_ = torch.autograd.grad(outputs=f,
74 inputs=x,
75 grad_outputs=f.new_ones(f.shape),
76 create_graph=True)
グラデーションの形を変えてノルムを計算しよう
79 gradients = gradients.reshape(batch_size, -1)
ノルムの計算
81 norm = gradients.norm(2, dim=-1)
損失を返す
83 return torch.mean((norm - 1) ** 2)