这是 Wasserstein GAN 改进训练的实现。
WGAN 建议削减权重以对鉴别器网络强制执行 Lipschitz 限制(评论家)。这个和其他权重约束,如L2标准削减、权重标准化、L1、L2权重衰减都有问题:
1.限制鉴别器的容量 2.分解和消失渐变(不带批量归一化)。
论文《改进了 Wasserstein GaN 的训练》提出了改进 Lipschitz 约束的更好方法,即梯度惩罚。
惩罚重量在哪里
也就是说,我们尽量保持梯度范数接近。
在这个实现中,我们设置。
以下是使用梯度惩罚的实验的代码。
46import torch
47import torch.autograd
48
49from labml_helpers.module import Module
52class GradientPenalty(Module):
57 def forward(self, x: torch.Tensor, f: torch.Tensor):
获取批次大小
67 batch_size = x.shape[0]
计算相对于的梯度。grad_outputs
设置为 1,因为我们想要梯度,我们需要创建和保留图形,因为我们必须计算相对于此损失的权重的梯度。
73 gradients, *_ = torch.autograd.grad(outputs=f,
74 inputs=x,
75 grad_outputs=f.new_ones(f.shape),
76 create_graph=True)
重塑梯度以计算范数
79 gradients = gradients.reshape(batch_size, -1)
计算常数
81 norm = gradients.norm(2, dim=-1)
退还损失
83 return torch.mean((norm - 1) ** 2)