This is an implementation of Improved Training of Wasserstein GANs.

WGAN suggests clipping weights to enforce a Lipschitz constraint on the discriminator network (critic). This and other weight constraints, such as L2 norm clipping, weight normalization, and L1 or L2 weight decay, have problems:

- Limiting the capacity of the discriminator
- Exploding and vanishing gradients (without Batch Normalization).
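
For contrast, the weight clipping that WGAN uses can be sketched as below. This is a minimal illustration, not code from this implementation: the toy `critic`, its layer sizes, and the helper `clip_weights` are hypothetical names, and the threshold of `0.01` is the default from the WGAN paper.

```python
import torch
from torch import nn

# Hypothetical toy critic; the layer sizes are arbitrary
critic = nn.Sequential(nn.Linear(4, 8), nn.LeakyReLU(0.2), nn.Linear(8, 1))

# Clipping threshold (0.01 is the default in the WGAN paper)
clip_value = 0.01


def clip_weights(module: nn.Module, c: float) -> None:
    """Clamp every parameter of `module` to the range [-c, c] in place."""
    with torch.no_grad():
        for p in module.parameters():
            p.clamp_(-c, c)


# In WGAN this would run after every critic optimizer step
clip_weights(critic, clip_value)
max_abs = max(p.abs().max().item() for p in critic.parameters())
```

Forcing every weight into such a small box is what limits the capacity of the critic.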

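The paper Improved Training of Wasserstein GANs proposes a better way to enforce the Lipschitz constraint: a gradient penalty.

$$\mathcal{L}_{GP} = \lambda \underset{\hat{x} \sim \mathbb{P}_{\hat{x}}}{\mathbb{E}} \Big[ \big(\Vert \nabla_{\hat{x}} D(\hat{x}) \Vert_2 - 1\big)^2 \Big]$$

where $\lambda$ is the penalty weight and

$$\begin{aligned}
x &\sim \mathbb{P}_r \\
z &\sim p(z) \\
\epsilon &\sim U[0, 1] \\
\tilde{x} &\leftarrow G_\theta(z) \\
\hat{x} &\leftarrow \epsilon x + (1 - \epsilon) \tilde{x}
\end{aligned}$$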

That is we try to keep the gradient norm $\Vert \nabla_{\hat{x}} D(\hat{x}) \Vert_2$ close to $1$.

In this implementation we set $\epsilon = 1$.
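
In the paper, $\hat{x}$ is sampled along straight lines between real and generated samples, $\hat{x} \leftarrow \epsilon x + (1 - \epsilon) \tilde{x}$ with $\epsilon \sim U[0, 1]$; setting $\epsilon = 1$ reduces this to $\hat{x} = x$. The general case could look roughly like the sketch below. The function `interpolate` and the tensor shapes are assumptions made for illustration and are not part of this implementation.

```python
import torch


def interpolate(x_real: torch.Tensor, x_fake: torch.Tensor) -> torch.Tensor:
    """Sample x_hat = eps * x_real + (1 - eps) * x_fake with eps ~ U[0, 1] per sample."""
    # One epsilon per sample, broadcast over all remaining dimensions
    eps_shape = [x_real.shape[0]] + [1] * (x_real.dim() - 1)
    eps = torch.rand(eps_shape, device=x_real.device, dtype=x_real.dtype)
    # Detach both inputs so the penalty does not backpropagate into the generator
    x_hat = eps * x_real.detach() + (1 - eps) * x_fake.detach()
    # The penalty differentiates the critic output with respect to x_hat
    return x_hat.requires_grad_()


# Random stand-ins for a batch of real and generated images
x_real = torch.randn(8, 1, 28, 28)
x_fake = torch.randn(8, 1, 28, 28)
x_hat = interpolate(x_real, x_fake)
```

The result would then be passed through the critic, and `x_hat` together with the critic output would be used in place of `x` and `f` below.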

Here is the code for an experiment that uses gradient penalty.

```
import torch
import torch.autograd

from labml_helpers.module import Module
```

`class GradientPenalty(Module):`

- `x` is $x \sim \mathbb{P}_r$
- `f` is $D(x)$

$\hat{x} \leftarrow x$ since we set $\epsilon = 1$ for this implementation.

`def __call__(self, x: torch.Tensor, f: torch.Tensor):`

Get batch size

`batch_size = x.shape[0]`

Calculate gradients of $D(x)$ with respect to $x$. `grad_outputs` is set to ones since we want the gradients of $D(x)$, and we need to create and retain the graph since we have to compute gradients of this loss with respect to the weights.

```
gradients, *_ = torch.autograd.grad(outputs=f,
                                    inputs=x,
                                    grad_outputs=f.new_ones(f.shape),
                                    create_graph=True)
```
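
To see why `create_graph=True` matters, here is a small worked sketch with a hypothetical linear critic $D(x) = x \cdot w$, chosen only because its input gradient is known in closed form ($\nabla_x D(x) = w$). The weight values and batch size are arbitrary.

```python
import torch

# Hypothetical linear critic D(x) = x @ w, so dD/dx = w for every sample
w = torch.tensor([3.0, 4.0], requires_grad=True)
x = torch.randn(5, 2, requires_grad=True)
f = x @ w

gradients, *_ = torch.autograd.grad(outputs=f,
                                    inputs=x,
                                    grad_outputs=f.new_ones(f.shape),
                                    create_graph=True)

# Every row of `gradients` equals w, so each norm is sqrt(3^2 + 4^2) = 5
norm = gradients.norm(2, dim=-1)
# (5 - 1)^2 = 16
penalty = torch.mean((norm - 1) ** 2)

# Since the graph of the gradient computation was kept,
# the penalty can be backpropagated to the critic weight w
penalty.backward()
# d/dw (||w|| - 1)^2 = 2 (||w|| - 1) w / ||w|| = 2 * 4 * [0.6, 0.8] = [4.8, 6.4]
```

Without `create_graph=True`, `gradients` would be detached from `w` and `penalty.backward()` would raise an error, because nothing in the penalty would require gradients.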

Reshape gradients to calculate the norm

`gradients = gradients.reshape(batch_size, -1)`

Calculate the norm $\Vert \nabla_{\hat{x}} D(\hat{x}) \Vert_2$

`norm = gradients.norm(2, dim=-1)`

Return the loss $\big(\Vert \nabla_{\hat{x}} D(\hat{x}) \Vert_2 - 1\big)^2$

`return torch.mean((norm - 1) ** 2)`
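
A minimal sketch of how the penalty might enter a critic training step is shown below. To stay self-contained it restates the steps above as a plain function, `gradient_penalty`, instead of using the `Module` class. The toy critic, tensor shapes, optimizer settings, and the name `penalty_weight` are assumptions for illustration; $\lambda = 10$ is the value used in the paper. The linked experiment may be organized differently.

```python
import torch
from torch import nn


def gradient_penalty(x: torch.Tensor, f: torch.Tensor) -> torch.Tensor:
    """Same computation as the steps above, written as a plain function."""
    batch_size = x.shape[0]
    gradients, *_ = torch.autograd.grad(outputs=f,
                                        inputs=x,
                                        grad_outputs=f.new_ones(f.shape),
                                        create_graph=True)
    gradients = gradients.reshape(batch_size, -1)
    norm = gradients.norm(2, dim=-1)
    return torch.mean((norm - 1) ** 2)


torch.manual_seed(0)
# Hypothetical toy critic for 1x4x4 inputs
critic = nn.Sequential(nn.Flatten(), nn.Linear(16, 32), nn.LeakyReLU(0.2), nn.Linear(32, 1))
optimizer = torch.optim.Adam(critic.parameters(), lr=1e-4)
penalty_weight = 10.0  # lambda

# x must require gradients so that D(x) can be differentiated with respect to it
x_real = torch.randn(8, 1, 4, 4).requires_grad_()
x_fake = torch.randn(8, 1, 4, 4)  # stand-in for detached generator output

f_real = critic(x_real)
f_fake = critic(x_fake)

# WGAN critic loss plus the weighted gradient penalty (epsilon = 1, so x_hat = x_real)
loss = f_fake.mean() - f_real.mean() + penalty_weight * gradient_penalty(x_real, f_real)

optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Note that the penalty returned above is not multiplied by $\lambda$, so the caller applies the weight.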