これはワッサースタイン GAN の実装です。
元のGAN損失は、実際の分布と生成された分布の間のジェンセン・シャノン(JS)の相違に基づいています。 Wasserstein GANは、これらの分布間のアースムーバーの距離に基づいています
。
は、限界確率が以下のすべての共同分布の集合です。
は特定の関節分布におけるアースムーバー距離 (および確率) です。
したがって、実際の分布と生成された分布の間の任意のジョイント分布における最小地球移動距離に等しくなります。
この論文は、2つの確率分布の差に対するジェンセン・シャノン(JS)ダイバージェンスやその他の測定値がスムーズではないことを示しています。したがって、(パラメータ化された)確率分布の1つで勾配降下を行っても収束しません
。カントロヴィッチとルビンスタインの二元性に基づいて、
1-Lipschitz関数はどこにありますか。
つまり、すべての1-Lipschitz関数の中で最大の差に等しいということです。
-リップシッツ機能の場合、
すべての -Lipschitz 関数を where を次のようにパラメーター化して表現できるとしたら、
ジェネレータで表され、既知の分布からのものであれば、
これで収束させるには、上記の式を最小化するように勾配降下できます。
同様に、境界を保ちながら上昇することで見つけることができます。
境界を維持する1つの方法は、ある範囲内でクリッピングを定義するニューラルネットワーク内のすべてのウェイトをクリップすることです。87import torch.utils.data
88from torch.nn import functional as F
89
90from labml_helpers.module import Module
93class DiscriminatorLoss(Module):
104 def forward(self, f_real: torch.Tensor, f_fake: torch.Tensor):
RelUSを使って損失をクリッピングしてレンジをキープします。
115 return F.relu(1 - f_real).mean(), F.relu(1 + f_fake).mean()
118class GeneratorLoss(Module):
f_fake
は 130 def forward(self, f_fake: torch.Tensor):
134 return -f_fake.mean()