Wasserstein GAN (WGAN)

这是 Wasserstein GAN 的实现。

最初的GAN损耗基于实际分布和生成的分布之间的Jensen-Shannon(JS)差异。Wasserstein GAN 基于这些分布之间的 Earth Mover 距离。

是所有联合分布的集合,其边际概率为

是给定关节分布的地球移动器距离(也是概率)。

因此,等于实际分布和生成的分布之间任何关节分布的最小地球移动器距离

本文表明,Jensen-Shannon(JS)背离和其他衡量两个概率分布之间差异的度量并不平滑。因此,如果我们对其中一个概率分布(参数化)进行梯度下降,它将不会收敛。

基于坎托罗维奇-鲁宾斯坦二元性,

所有的 1-Lipschitz 函数都在哪里。

也就是说,它等于所有 1-Lipschitz 函数之间的最大差异。

对于-Lipschitz 函数,

如果所有-Lipschitz 函数都可以表示为参数化了哪里

如果由生成器表示并且来自已知分布

现在为了收敛我们可以通过梯度下降来最小化上述公式。

同样,我们可以通过上升来找到,同时保持界限。保持界限的一种方法是裁剪神经网络中定义范围内的裁剪的所有权重。

以下是在一个简单的 MNIST 生成实验中尝试此操作的代码。

Open In Colab

87import torch.utils.data
88from torch.nn import functional as F
89
90from labml_helpers.module import Module

鉴别器丢失

我们想找到最大化,所以我们最小化,

93class DiscriminatorLoss(Module):
  • f_real
  • f_fake
  • 这将返回带有 and 亏损的 a 元组,稍后会添加这些元组。它们分开存放以进行日志记录。

    104    def forward(self, f_real: torch.Tensor, f_fake: torch.Tensor):

    我们使用 RELUs 来削减损失以保持射程。

    115        return F.relu(1 - f_real).mean(), F.relu(1 + f_fake).mean()

    发电机损失

    我们想找到最小化第一个组件是独立的,所以我们最小化,

    118class GeneratorLoss(Module):
    • f_fake
    130    def forward(self, f_fake: torch.Tensor):
    134        return -f_fake.mean()