ワッサースタイン GAN (WGAN)

これはワッサースタイン GAN の実装です。

元のGAN損失は、実際の分布と生成された分布の間のジェンセン・シャノン(JS)の相違に基づいています。 Wasserstein GANは、これらの分布間のアースムーバーの距離に基づいています

は、限界確率が以下のすべての共同分布の集合です。

は特定の関節分布におけるアースムーバー距離 (および確率) です。

したがって実際の分布と生成された分布の間の任意のジョイント分布における最小地球移動距離に等しくなります。

この論文は、2つの確率分布の差に対するジェンセン・シャノン(JS)ダイバージェンスやその他の測定値がスムーズではないことを示しています。したがって、(パラメータ化された)確率分布の1つで勾配降下を行っても収束しません

カントロヴィッチとルビンスタインの二元性に基づいて、

1-Lipschitz関数はどこにありますか。

つまり、すべての1-Lipschitz関数の中で最大の差に等しいということです。

-リップシッツ機能の場合、

すべての -Lipschitz 関数を where を次のようにパラメーター化して表現できるとしたら、

ジェネレータで表され、既知の分布からのものであれば

これで収束させるには、上記の式を最小化するように勾配降下できます。

同様に、境界を保ちながら上昇することで見つけることができます。

境界を維持する1つの方法は、ある範囲内でクリッピングを定義するニューラルネットワーク内のすべてのウェイトをクリップすることです。

これを簡単な MNIST 生成実験で試すコードを次に示します

Open In Colab

87import torch.utils.data
88from torch.nn import functional as F
89
90from labml_helpers.module import Module

ディスクリミネーターロス

最大化したいから最小化する

93class DiscriminatorLoss(Module):
  • f_real
  • f_fake
  • これにより、とが失われたタプルが返されます。このタプルは後で追加されます。これらはロギング用に別々に保管されます。

    104    def forward(self, f_real: torch.Tensor, f_fake: torch.Tensor):

    RelUSを使って損失をクリッピングしてレンジをキープします。

    115        return F.relu(1 - f_real).mean(), F.relu(1 + f_fake).mean()

    発電機損失

    最小化したい最初のコンポーネントは独立しているので、最小化します。

    118class GeneratorLoss(Module):
    • f_fake
    130    def forward(self, f_fake: torch.Tensor):
    134        return -f_fake.mean()