Wasserstein GAN (WGAN)

This is an implementation of the paper Wasserstein GAN (Arjovsky et al., 2017, arXiv:1701.07875).

The original GAN loss is based on the Jensen-Shannon (JS) divergence between the real distribution $\mathbb{P}_r$ and the generated distribution $\mathbb{P}_g$. The Wasserstein GAN is based on the Earth Mover (EM) distance between these distributions:

$$W(\mathbb{P}_r, \mathbb{P}_g) = \inf_{\gamma \in \Pi(\mathbb{P}_r, \mathbb{P}_g)} \mathbb{E}_{(x, y) \sim \gamma} \Vert x - y \Vert$$

$\Pi(\mathbb{P}_r, \mathbb{P}_g)$ is the set of all joint distributions $\gamma(x, y)$ whose marginals are $\mathbb{P}_r$ and $\mathbb{P}_g$.

$\mathbb{E}_{(x, y) \sim \gamma} \Vert x - y \Vert$ is the earth mover distance for a given joint distribution $\gamma$ ($x$ and $y$ are samples drawn from $\gamma$).

So $W(\mathbb{P}_r, \mathbb{P}_g)$ is equal to the least earth mover distance over all joint distributions between the real distribution $\mathbb{P}_r$ and the generated distribution $\mathbb{P}_g$.

The paper shows that the Jensen-Shannon (JS) divergence and other such measures of the difference between two probability distributions are not smooth. Therefore, if we do gradient descent on one of the (parameterized) probability distributions, training will not converge.
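A concrete illustration of this (Example 1 in the paper): let $Z \sim U[0, 1]$, let $\mathbb{P}_0$ be the distribution of $(0, Z)$ and $\mathbb{P}_\theta$ the distribution of $(\theta, Z)$, i.e. two parallel vertical segments a distance $|\theta|$ apart. Then

$$W(\mathbb{P}_0, \mathbb{P}_\theta) = |\theta|, \qquad JS(\mathbb{P}_0, \mathbb{P}_\theta) = \begin{cases} \log 2 & \text{if } \theta \neq 0 \\ 0 & \text{if } \theta = 0 \end{cases}$$

The EM distance is continuous in $\theta$ and gives a useful gradient everywhere, while the JS divergence is constant wherever the supports are disjoint, so gradient descent on $\theta$ learns nothing from it.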

Based on the Kantorovich-Rubinstein duality,

$$W(\mathbb{P}_r, \mathbb{P}_g) = \sup_{\Vert f \Vert_L \le 1} \mathbb{E}_{x \sim \mathbb{P}_r} [f(x)] - \mathbb{E}_{x \sim \mathbb{P}_g} [f(x)]$$

where $\Vert f \Vert_L \le 1$ ranges over all 1-Lipschitz functions.

That is, $W(\mathbb{P}_r, \mathbb{P}_g)$ is equal to the greatest value of $\mathbb{E}_{x \sim \mathbb{P}_r} [f(x)] - \mathbb{E}_{x \sim \mathbb{P}_g} [f(x)]$ among all 1-Lipschitz functions.

For $K$-Lipschitz functions,

$$W(\mathbb{P}_r, \mathbb{P}_g) = \sup_{\Vert f \Vert_L \le K} \mathbb{E}_{x \sim \mathbb{P}_r} \Big[ \frac{1}{K} f(x) \Big] - \mathbb{E}_{x \sim \mathbb{P}_g} \Big[ \frac{1}{K} f(x) \Big]$$

If all $K$-Lipschitz functions can be represented as $f_w$, where $f$ is parameterized by $w \in \mathcal{W}$,

$$K \cdot W(\mathbb{P}_r, \mathbb{P}_g) = \max_{w \in \mathcal{W}} \mathbb{E}_{x \sim \mathbb{P}_r} [f_w(x)] - \mathbb{E}_{x \sim \mathbb{P}_g} [f_w(x)]$$

If $\mathbb{P}_g$ is represented by a generator $g_\theta(z)$, where $z$ is from a known distribution $z \sim p(z)$,

$$K \cdot W(\mathbb{P}_r, \mathbb{P}_\theta) = \max_{w \in \mathcal{W}} \mathbb{E}_{x \sim \mathbb{P}_r} [f_w(x)] - \mathbb{E}_{z \sim p(z)} [f_w(g_\theta(z))]$$

Now, to make the generated distribution converge to $\mathbb{P}_r$, we can do gradient descent on $\theta$ to minimize the formula above.

Similarly, we can find the maximizing $f_w$ by gradient ascent on $w$, while keeping $K$ bounded. One way to keep $K$ bounded is to clip all the weights of the neural network that defines $f$ to a fixed range.
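A minimal sketch of this weight clipping in PyTorch, assuming the critic is an `nn.Module`; the clip range $c = 0.01$ is the value used in the paper, and the function name `clip_weights` is illustrative:

```python
import torch
from torch import nn


def clip_weights(f: nn.Module, c: float = 0.01):
    # Clamp every parameter of the critic f to [-c, c] in place.
    # This keeps w in a compact set W, so f_w is K-Lipschitz for some K.
    with torch.no_grad():
        for p in f.parameters():
            p.clamp_(-c, c)
```

This would be applied after each critic update, so the critic stays within the bounded parameter space throughout training.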

Here is the code to try this on a simple MNIST generation experiment.


```python
import torch.utils.data
from torch import nn
from torch.nn import functional as F
```

Discriminator Loss

We want to find $w$ to maximize
$$\mathbb{E}_{x \sim \mathbb{P}_r} [f_w(x)] - \mathbb{E}_{z \sim p(z)} [f_w(g_\theta(z))],$$
so we minimize
$$-\frac{1}{m} \sum_{i=1}^m f_w\big(x^{(i)}\big) + \frac{1}{m} \sum_{i=1}^m f_w\big(g_\theta(z^{(i)})\big)$$

```python
class DiscriminatorLoss(nn.Module):
```

  • `f_real` is $f_w(x)$
  • `f_fake` is $f_w(g_\theta(z))$

This returns a tuple with the losses for $f_w(x)$ and $f_w(g_\theta(z))$, which are later added. They are kept separate for logging.

```python
    def forward(self, f_real: torch.Tensor, f_fake: torch.Tensor):
```

We use ReLUs to clip the loss, keeping $f$ within the $[-1, +1]$ range.

```python
        return F.relu(1 - f_real).mean(), F.relu(1 + f_fake).mean()
```
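For context, here is a hedged sketch of a critic update step using this loss; `discriminator`, `generator`, and `optimizer_d` are hypothetical names, and `clip_weights` is the sketch from earlier:

```python
# Hypothetical critic update step (all names below are illustrative).
discriminator_loss = DiscriminatorLoss()


def discriminator_step(real: torch.Tensor, z: torch.Tensor):
    optimizer_d.zero_grad()
    f_real = discriminator(real)
    # Detach generated samples so this step does not update the generator.
    f_fake = discriminator(generator(z).detach())
    loss_real, loss_fake = discriminator_loss(f_real, f_fake)
    (loss_real + loss_fake).backward()
    optimizer_d.step()
    # Keep the critic Lipschitz-bounded by clipping its weights.
    clip_weights(discriminator)
```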

Generator Loss

We want to find $\theta$ to minimize
$$\mathbb{E}_{x \sim \mathbb{P}_r} [f_w(x)] - \mathbb{E}_{z \sim p(z)} [f_w(g_\theta(z))]$$
The first component is independent of $\theta$, so we minimize
$$-\frac{1}{m} \sum_{i=1}^m f_w\big(g_\theta(z^{(i)})\big)$$

```python
class GeneratorLoss(nn.Module):
```

  • `f_fake` is $f_w(g_\theta(z))$

```python
    def forward(self, f_fake: torch.Tensor):
        return -f_fake.mean()
```
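And a matching sketch of the generator update, with the same hypothetical names plus a generator optimizer `optimizer_g`:

```python
# Hypothetical generator update step (names are illustrative).
generator_loss = GeneratorLoss()


def generator_step(z: torch.Tensor):
    optimizer_g.zero_grad()
    # Score generated samples g_theta(z) with the critic f_w.
    f_fake = discriminator(generator(z))
    loss = generator_loss(f_fake)
    loss.backward()
    optimizer_g.step()
```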