Wasserstein GAN (WGAN-GP) 的梯度惩罚

这是 Wasserstein GAN 改进训练的实现。

WGAN 建议削减权重以对鉴别器网络强制执行 Lipschitz 限制(评论家)。这个和其他权重约束,如L2标准削减、权重标准化、L1、L2权重衰减都有问题:

1.限制鉴别器的容量 2.分解和消失渐变(不带批量归一化)。

论文《改进了 Wasserstein GaN 的训练》提出了改进 Lipschitz 约束的更好方法,即梯度惩罚。

惩罚重量在哪里

也就是说,我们尽量保持梯度范数接近

在这个实现中,我们设置

以下是使用梯度惩罚的实验的代码

46import torch
47import torch.autograd
48
49from labml_helpers.module import Module

梯度惩罚

52class GradientPenalty(Module):
  • x
  • f
  • 因为我们为这个实现做好了准备。

    57    def forward(self, x: torch.Tensor, f: torch.Tensor):

    获取批次大小

    67        batch_size = x.shape[0]

    计算相对于的梯度grad_outputs 设置为 1,因为我们想要梯度,我们需要创建和保留图形,因为我们必须计算相对于此损失的权重的梯度。

    73        gradients, *_ = torch.autograd.grad(outputs=f,
    74                                            inputs=x,
    75                                            grad_outputs=f.new_ones(f.shape),
    76                                            create_graph=True)

    重塑梯度以计算范数

    79        gradients = gradients.reshape(batch_size, -1)

    计算常数

    81        norm = gradients.norm(2, dim=-1)

    退还损失

    83        return torch.mean((norm - 1) ** 2)