Gradient Penalty for Wasserstein GAN (WGAN-GP)

This is an implementation of Improved Training of Wasserstein GANs.

WGAN suggests clipping weights to enforce the Lipschitz constraint on the discriminator network (critic). This and other weight constraints, such as L2 norm clipping, weight normalization, and L1 and L2 weight decay, have problems:

  1. Limiting the capacity of the discriminator.
  2. Exploding and vanishing gradients (without Batch Normalization).
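For contrast, the weight clipping that WGAN relies on can be sketched in a few lines. This is a minimal sketch, assuming a toy `nn.Sequential` critic and a clip value of $c = 0.01$ (the value used in the original WGAN paper); `clip_weights` is a hypothetical helper, not part of this implementation. It would be called after every critic update, forcing every weight into the box $[-c, c]$ regardless of what the objective needs.

```python
import torch
import torch.nn as nn


def clip_weights(critic: nn.Module, c: float = 0.01) -> None:
    """Clamp every critic parameter to [-c, c] in place (original WGAN scheme)."""
    with torch.no_grad():
        for p in critic.parameters():
            p.clamp_(-c, c)


# Toy critic, assumed here purely for illustration.
critic = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 1))
clip_weights(critic, c=0.01)
max_abs = max(p.abs().max().item() for p in critic.parameters())
print(max_abs <= 0.01)  # True
```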

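The paper Improved Training of Wasserstein GANs proposes a better way to enforce the Lipschitz constraint: a gradient penalty added to the critic loss,

$$\mathcal{L}_{GP} = \lambda \underset{\hat{x} \sim \mathbb{P}_{\hat{x}}}{\mathbb{E}} \Big[ \big(\Vert \nabla_{\hat{x}} D(\hat{x}) \Vert_2 - 1\big)^2 \Big]$$

where $\lambda$ is the penalty weight and

$$x \sim \mathbb{P}_r, \quad z \sim p(z), \quad \epsilon \sim U[0,1], \quad \tilde{x} \leftarrow G_\theta(z), \quad \hat{x} \leftarrow \epsilon x + (1 - \epsilon) \tilde{x}$$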

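That is, we try to keep the gradient norm $\Vert \nabla_{\hat{x}} D(\hat{x}) \Vert_2$ close to $1$. A differentiable function is 1-Lipschitz if and only if its gradient norm is at most $1$ everywhere, so penalizing deviations from $1$ acts as a soft version of the constraint.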
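As a worked example, take a linear critic $D(x) = w^\top x$ (an assumption for illustration only). Its gradient with respect to $x$ is $w$ at every point, so the penalty is $(\Vert w \Vert_2 - 1)^2$; with $w = (3, 4)$ that is $(5 - 1)^2 = 16$, and it would vanish only for a unit-norm $w$.

```python
import torch

# Linear critic D(x) = w . x, so grad_x D(x) = w for every x.
w = torch.tensor([3.0, 4.0])
x = torch.tensor([[1.0, 2.0], [-0.5, 0.7]], requires_grad=True)
d = x @ w  # shape (2,): one critic value per sample

grad, = torch.autograd.grad(outputs=d, inputs=x, grad_outputs=torch.ones_like(d))
norm = grad.norm(2, dim=-1)          # ||w||_2 = 5 for both samples
penalty = ((norm - 1) ** 2).mean()   # (5 - 1)^2 = 16
print(norm.tolist(), penalty.item())  # [5.0, 5.0] 16.0
```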

In this implementation, we set $\epsilon = 1$, so $\hat{x} = x$ and the penalty is evaluated at real samples only.
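For reference, the general case with $\epsilon \sim U[0,1]$ can be sketched as follows. This is a minimal sketch, not part of this implementation: `interpolate` is a hypothetical helper, and detaching both inputs (so that $\hat{x}$ is a fresh leaf tensor whose gradient the penalty can take) is a design choice of the sketch. Passing a tensor of ones as `epsilon` reproduces the $\epsilon = 1$ setting used here.

```python
from typing import Optional

import torch


def interpolate(x_real: torch.Tensor, x_fake: torch.Tensor,
                epsilon: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Return x_hat = epsilon * x + (1 - epsilon) * x_tilde, one epsilon per sample."""
    if epsilon is None:
        # epsilon ~ U[0, 1], shaped (batch, 1, ..., 1) to broadcast over features.
        shape = (x_real.shape[0],) + (1,) * (x_real.dim() - 1)
        epsilon = torch.rand(shape, device=x_real.device, dtype=x_real.dtype)
    # Detach both inputs so x_hat is a fresh leaf, then track gradients w.r.t. it.
    x_hat = epsilon * x_real.detach() + (1 - epsilon) * x_fake.detach()
    return x_hat.requires_grad_(True)


x_real = torch.randn(8, 3, 4, 4)
x_fake = torch.randn(8, 3, 4, 4)
x_hat = interpolate(x_real, x_fake)
print(x_hat.shape, x_hat.requires_grad)  # torch.Size([8, 3, 4, 4]) True

# epsilon = 1 recovers the real batch, matching this implementation's choice.
same = torch.equal(interpolate(x_real, x_fake, torch.ones(8, 1, 1, 1)), x_real)
print(same)  # True
```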

Here is the code for an experiment that uses gradient penalty.

import torch
import torch.autograd
from labml_helpers.module import Module

Gradient Penalty

class GradientPenalty(Module):
  • x is $x \sim \mathbb{P}_r$
  • f is $D(x)$

$\hat{x} \leftarrow x$ since we set $\epsilon = 1$ for this implementation.

    def __call__(self, x: torch.Tensor, f: torch.Tensor):

Get batch size

        batch_size = x.shape[0]

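Calculate the gradients of $D(x)$ with respect to $x$. grad_outputs is set to ones since we want the gradients of $D(x)$ itself rather than of a weighted combination; because each $D(x_i)$ depends only on its own $x_i$ (as long as the critic has no cross-sample layers such as batch normalization), the $i$-th row of the result is $\nabla_{x_i} D(x_i)$. We need to create (and therefore retain) the graph since we have to compute gradients of this loss with respect to the critic's weights.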

        gradients, *_ = torch.autograd.grad(outputs=f,
                                            inputs=x,
                                            grad_outputs=f.new_ones(f.shape),
                                            create_graph=True)
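Both properties of the call above can be checked on a toy critic. In this sketch (the network and sizes are assumptions for illustration), the batched call with ones as grad_outputs matches per-sample differentiation, and because of create_graph=True the penalty built from `gradients` can be back-propagated into the critic's weights. Without create_graph, the returned gradients would not require grad and `penalty.backward()` would raise an error.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Toy critic assumed for illustration; no batch normalization, so each
# output D(x_i) depends only on its own input x_i.
critic = nn.Sequential(nn.Linear(3, 8), nn.Tanh(), nn.Linear(8, 1))
x = torch.randn(4, 3, requires_grad=True)
f = critic(x)  # shape (4, 1)

# Batched call as in GradientPenalty: with grad_outputs of ones this is the
# gradient of sum_i D(x_i), whose i-th row equals the gradient of D(x_i) w.r.t. x_i.
gradients, = torch.autograd.grad(outputs=f, inputs=x,
                                 grad_outputs=f.new_ones(f.shape),
                                 create_graph=True)

# Reference: differentiate each sample on its own and keep row i.
per_sample = torch.stack([
    torch.autograd.grad(critic(x[i:i + 1]).sum(), x)[0][i]
    for i in range(x.shape[0])
])
matches = torch.allclose(gradients, per_sample, atol=1e-6)

# create_graph=True keeps `gradients` differentiable, so the penalty
# reaches the critic's weights.
penalty = ((gradients.norm(2, dim=-1) - 1) ** 2).mean()
penalty.backward()
print(matches, critic[0].weight.grad is not None)  # True True
```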

Flatten the gradients of each sample to calculate the norm

        gradients = gradients.reshape(batch_size, -1)

Calculate the norm $\Vert \nabla_{\hat{x}} D(\hat{x}) \Vert_2$

        norm = gradients.norm(2, dim=-1)
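The flattening matters because the penalty takes one norm per sample over all of that sample's input dimensions (channels, height and width for images). A small numeric check, with made-up gradient values for two $1 \times 2 \times 2$ inputs:

```python
import torch

# Hypothetical gradients for a batch of two 1x2x2 inputs.
gradients = torch.stack([torch.ones(1, 2, 2), 2 * torch.ones(1, 2, 2)])  # (2, 1, 2, 2)
flat = gradients.reshape(gradients.shape[0], -1)  # (2, 4): one row per sample
norm = flat.norm(2, dim=-1)                       # sqrt(4 * 1^2) = 2, sqrt(4 * 2^2) = 4
loss = torch.mean((norm - 1) ** 2)                # ((2 - 1)^2 + (4 - 1)^2) / 2 = 5
print(norm.tolist(), loss.item())  # [2.0, 4.0] 5.0
```

In recent PyTorch versions, `torch.linalg.vector_norm(flat, ord=2, dim=-1)` computes the same per-sample norm.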

Return the loss $\big(\Vert \nabla_{\hat{x}} D(\hat{x}) \Vert_2 - 1\big)^2$

        return torch.mean((norm - 1) ** 2)
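To show where the penalty enters training, here is a sketch of one critic update. It re-declares the same computation on plain `torch.nn.Module` with `forward` (instead of `labml_helpers`' `Module` and `__call__`) so it runs without labml; the toy generator and critic are assumptions, while $\lambda = 10$ and Adam with `lr=1e-4`, `betas=(0.0, 0.9)` follow Algorithm 1 of the paper. With $\epsilon = 1$, `x_real` must require gradients before it is passed through the critic.

```python
import torch
import torch.nn as nn


class GradientPenalty(nn.Module):
    """Same computation as the class above, on plain torch.nn.Module."""

    def forward(self, x: torch.Tensor, f: torch.Tensor) -> torch.Tensor:
        batch_size = x.shape[0]
        gradients, *_ = torch.autograd.grad(outputs=f, inputs=x,
                                            grad_outputs=f.new_ones(f.shape),
                                            create_graph=True)
        norm = gradients.reshape(batch_size, -1).norm(2, dim=-1)
        return torch.mean((norm - 1) ** 2)


torch.manual_seed(0)
# Toy 2-D generator and critic, assumed for illustration only.
generator = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))
critic = nn.Sequential(nn.Linear(2, 16), nn.LeakyReLU(0.2), nn.Linear(16, 1))
opt_d = torch.optim.Adam(critic.parameters(), lr=1e-4, betas=(0.0, 0.9))
gradient_penalty = GradientPenalty()
lambda_gp = 10.0  # penalty weight lambda

# epsilon = 1: the penalty is evaluated at the real samples, which need gradients.
x_real = torch.randn(64, 2).requires_grad_()
with torch.no_grad():
    x_fake = generator(torch.randn(64, 8))  # generator is not updated in this step

f_real = critic(x_real)
f_fake = critic(x_fake)
# WGAN critic loss plus the weighted gradient penalty.
loss_d = f_fake.mean() - f_real.mean() + lambda_gp * gradient_penalty(x_real, f_real)

w_before = critic[0].weight.detach().clone()
opt_d.zero_grad()
loss_d.backward()
opt_d.step()
print(torch.isfinite(loss_d).item())  # True
```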