This is an implementation of Improved Training of Wasserstein GANs.
WGAN suggests clipping weights to enforce the Lipschitz constraint on the discriminator network (critic). This and other weight constraints, such as L2 norm clipping, weight normalization, and L1 and L2 weight decay, have problems:
1. Limiting the capacity of the discriminator
2. Exploding and vanishing gradients (without Batch Normalization)
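For contrast, the original WGAN weight clipping can be sketched as below. This is only an illustration and not part of this implementation: the helper name `clip_critic_weights`, the toy critic, and the clipping constant `c = 0.01` (the default in the WGAN paper) are assumptions.

```python
import torch
import torch.nn as nn


def clip_critic_weights(critic: nn.Module, c: float = 0.01) -> None:
    """Hypothetical helper: clamp every critic parameter into [-c, c] in place."""
    with torch.no_grad():
        for p in critic.parameters():
            p.clamp_(-c, c)


# A toy critic, clipped after what would be an optimizer step
critic = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
clip_critic_weights(critic, c=0.01)
max_abs = max(p.abs().max().item() for p in critic.parameters())
```

Every weight is forced into a small box regardless of what the loss needs, which is the source of the capacity problem listed above.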
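The paper Improved Training of Wasserstein GANs proposes a better way to enforce the Lipschitz constraint: a gradient penalty.

$$\mathcal{L}_{GP} = \lambda \underset{\hat{x} \sim \mathbb{P}_{\hat{x}}}{\mathbb{E}} \Big[ \big(\Vert \nabla_{\hat{x}} D(\hat{x}) \Vert_2 - 1\big)^2 \Big]$$

where $\lambda$ is the penalty weight and

$$x \sim \mathbb{P}_r, \quad z \sim p(z), \quad \epsilon \sim U[0, 1]$$
$$\tilde{x} \leftarrow G_\theta(z), \quad \hat{x} \leftarrow \epsilon x + (1 - \epsilon) \tilde{x}$$

That is, we try to keep the gradient norm $\Vert \nabla_{\hat{x}} D(\hat{x}) \Vert_2$ close to $1$.

In this implementation we set $\epsilon = 1$.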
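For comparison, the penalty with a random $\epsilon$ per sample (without the $\epsilon = 1$ simplification) might be sketched as follows. The function name `gradient_penalty_interpolated` and the toy linear critic are hypothetical; the returned value still has to be multiplied by $\lambda$ (the paper uses $\lambda = 10$).

```python
import torch
import torch.nn as nn


def gradient_penalty_interpolated(critic: nn.Module, real: torch.Tensor,
                                  fake: torch.Tensor) -> torch.Tensor:
    """Sketch: E[(||grad D(x_hat)||_2 - 1)^2] with x_hat = eps * x + (1 - eps) * x_tilde."""
    batch_size = real.shape[0]
    # One eps ~ U[0, 1] per sample, broadcast over the remaining dimensions
    eps = torch.rand(batch_size, *([1] * (real.dim() - 1)), device=real.device)
    # Build x_hat as a fresh leaf so we can differentiate D with respect to it
    x_hat = (eps * real.detach() + (1 - eps) * fake.detach()).requires_grad_(True)
    d_hat = critic(x_hat)
    grads, = torch.autograd.grad(outputs=d_hat, inputs=x_hat,
                                 grad_outputs=torch.ones_like(d_hat),
                                 create_graph=True)
    # Per-sample L2 norm of the flattened gradient
    norm = grads.reshape(batch_size, -1).norm(2, dim=-1)
    return ((norm - 1) ** 2).mean()


# Toy linear critic D(x) = w.x + b, whose input gradient is w everywhere
critic = nn.Linear(4, 1)
with torch.no_grad():
    critic.weight.copy_(torch.tensor([[3.0, 4.0, 0.0, 0.0]]))  # ||w||_2 = 5
real, fake = torch.randn(6, 4), torch.randn(6, 4)
penalty = gradient_penalty_interpolated(critic, real, fake)  # (5 - 1)^2 = 16
```

With a linear critic the gradient norm is the same at every $\hat{x}$, so the penalty reduces to $(\Vert w \Vert_2 - 1)^2$ regardless of how $\epsilon$ is sampled.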
Here is the code for an experiment that uses gradient penalty.
```python
import torch
import torch.autograd

from labml_helpers.module import Module
```
`x` is $x \sim \mathbb{P}_r$ and `f` is $D(x)$. $\hat{x} \leftarrow x$, since we set $\epsilon = 1$ for this implementation.
```python
    def forward(self, x: torch.Tensor, f: torch.Tensor):
```
Get batch size
```python
        batch_size = x.shape[0]
```
Calculate the gradients of $D(x)$ with respect to $x$.
`grad_outputs` is set to ones since we want the gradients of $D(x)$ itself, and we need to create and retain the graph because we have to compute gradients of this loss with respect to the weights.
```python
        gradients, *_ = torch.autograd.grad(outputs=f,
                                            inputs=x,
                                            grad_outputs=f.new_ones(f.shape),
                                            create_graph=True)
```
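The two arguments can be checked with a small experiment. This sketch uses a made-up toy critic: it shows that a ones vector for `grad_outputs` gives the same per-sample gradients as differentiating `f.sum()` (each row of `f` depends only on its own row of `x`, which holds for this MLP but not, for example, with batch normalization), and that `create_graph=True` lets the penalty be back-propagated into the critic's weights.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
critic = nn.Sequential(nn.Linear(3, 5), nn.Tanh(), nn.Linear(5, 1))
x = torch.randn(4, 3, requires_grad=True)
f = critic(x)

# Vector-Jacobian product with a ones vector ...
g_ones, = torch.autograd.grad(f, x, grad_outputs=torch.ones_like(f), create_graph=True)
# ... matches the gradient of the summed output, row by row
g_sum, = torch.autograd.grad(f.sum(), x, create_graph=True)
same = torch.allclose(g_ones, g_sum)

# create_graph=True makes the gradients themselves differentiable,
# so a penalty built from them reaches the weights
penalty = ((g_ones.norm(2, dim=-1) - 1) ** 2).mean()
penalty.backward()
# Both weight matrices receive gradients (the last layer's bias does not
# appear in dD/dx, so it gets no gradient from the penalty)
has_weight_grads = critic[0].weight.grad is not None and critic[2].weight.grad is not None
```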
Reshape gradients to calculate the norm
```python
        gradients = gradients.reshape(batch_size, -1)
```
Calculate the L2 norm $\Vert \nabla_{x} D(x) \Vert_2$ for each sample
```python
        norm = gradients.norm(2, dim=-1)
```
Return the loss $\big(\Vert \nabla_{x} D(x) \Vert_2 - 1\big)^2$, averaged over the batch
```python
        return torch.mean((norm - 1) ** 2)
```
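Putting the fragments together, a self-contained sketch of the module and its use in a critic loss might look like this. It assumes `torch.nn.Module` in place of `labml_helpers.module.Module`; the class name `GradientPenaltySketch`, the toy critic, the random stand-in for generator output, and `lam = 10.0` (the value used in the paper) are illustrative choices. The critic loss form `f_fake.mean() - f_real.mean()` is the usual WGAN critic objective and is not defined in this section.

```python
import torch
import torch.nn as nn


class GradientPenaltySketch(nn.Module):
    """Stand-alone version of the penalty above (epsilon = 1, so x_hat = x)."""

    def forward(self, x: torch.Tensor, f: torch.Tensor) -> torch.Tensor:
        batch_size = x.shape[0]
        # Gradients of D(x) w.r.t. x, kept differentiable for the weight update
        gradients, *_ = torch.autograd.grad(outputs=f, inputs=x,
                                            grad_outputs=f.new_ones(f.shape),
                                            create_graph=True)
        # Flatten each sample's gradient (e.g. image-shaped) and take its L2 norm
        norm = gradients.reshape(batch_size, -1).norm(2, dim=-1)
        return torch.mean((norm - 1) ** 2)


torch.manual_seed(0)
critic = nn.Sequential(nn.Flatten(), nn.Linear(2 * 3 * 3, 16),
                       nn.LeakyReLU(0.2), nn.Linear(16, 1))
gp = GradientPenaltySketch()
lam = 10.0  # penalty weight lambda (assumed; matches the paper's choice)

real = torch.randn(8, 2, 3, 3).requires_grad_()  # x must require grad
fake = torch.randn(8, 2, 3, 3)                    # stand-in for G(z)
f_real, f_fake = critic(real), critic(fake)

# WGAN critic loss plus the gradient penalty evaluated on real samples
critic_loss = f_fake.mean() - f_real.mean() + lam * gp(real, f_real)
critic_loss.backward()
```

Because the penalty is evaluated only at real samples here, `real` must have `requires_grad` set before it is passed through the critic; `fake` does not need it.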