这是 Wasserstein GAN 的实现。
最初的GAN损耗基于实际分布和生成的分布之间的Jensen-Shannon(JS)差异。Wasserstein GAN 基于这些分布之间的 Earth Mover 距离。
是所有联合分布的集合,其边际概率为。
是给定关节分布的地球移动器距离(也是概率)。
因此,等于实际分布和生成的分布之间任何关节分布的最小地球移动器距离。
本文表明,Jensen-Shannon(JS)背离和其他衡量两个概率分布之间差异的度量并不平滑。因此,如果我们对其中一个概率分布(参数化)进行梯度下降,它将不会收敛。
基于坎托罗维奇-鲁宾斯坦二元性,
所有的 1-Lipschitz 函数都在哪里。
也就是说,它等于所有 1-Lipschitz 函数之间的最大差异。
对于-Lipschitz 函数,
如果所有-Lipschitz 函数都可以表示为参数化了哪里,
如果由生成器表示并且来自已知分布,
现在为了收敛,我们可以通过梯度下降来最小化上述公式。
同样,我们可以通过上升来找到,同时保持界限。保持界限的一种方法是裁剪神经网络中定义范围内的裁剪的所有权重。
以下是在一个简单的 MNIST 生成实验中尝试此操作的代码。
87import torch.utils.data
88from torch.nn import functional as F
89
90from labml_helpers.module import Module
93class DiscriminatorLoss(Module):
104 def forward(self, f_real: torch.Tensor, f_fake: torch.Tensor):
我们使用 RELUs 来削减损失以保持射程。
115 return F.relu(1 - f_real).mean(), F.relu(1 + f_fake).mean()
118class GeneratorLoss(Module):
f_fake
是130 def forward(self, f_fake: torch.Tensor):
134 return -f_fake.mean()