Generative Adversarial Networks (GAN)

This is an implementation of Generative Adversarial Networks.

The generator, $G(\pmb{z}; \theta_g)$, generates samples that should match the distribution of the data. The discriminator, $D(\pmb{x}; \theta_d)$, gives the probability that $\pmb{x}$ came from the data rather than from $G$.

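We train $D$ and $G$ simultaneously on a two-player min-max game with value function $V(G, D)$:

$$\min_G \max_D V(G, D) = \mathbb{E}_{\pmb{x} \sim p_{data}(\pmb{x})}\Big[\log D(\pmb{x})\Big] + \mathbb{E}_{\pmb{z} \sim p_{\pmb{z}}(\pmb{z})}\Big[\log \Big(1 - D\big(G(\pmb{z})\big)\Big)\Big]$$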

$p_{data}(\pmb{x})$ is the probability distribution over the data, while $p_{\pmb{z}}(\pmb{z})$ is the probability distribution of $\pmb{z}$, which is set to Gaussian noise.
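To make the value function concrete, here is a minimal sketch that estimates $V(G, D)$ from one mini-batch by replacing each expectation with a batch mean. The toy linear generator, the random stand-in for data, and the hypothetical ConstantHalfD discriminator are assumptions for illustration only. With $D = 1/2$ everywhere, both log terms equal $\log \frac{1}{2}$, so the estimate is $-\log 4$.

```python
import math
import torch
import torch.nn as nn


def value_function(d: nn.Module, g: nn.Module, x: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    """Mini-batch estimate of V(G, D) = E_x[log D(x)] + E_z[log(1 - D(G(z)))].

    Assumes d returns probabilities in (0, 1) with shape (batch, 1).
    """
    real_term = torch.log(d(x)).mean()         # estimate of E_x[log D(x)]
    fake_term = torch.log(1 - d(g(z))).mean()  # estimate of E_z[log(1 - D(G(z)))]
    return real_term + fake_term


class ConstantHalfD(nn.Module):
    """Hypothetical discriminator that outputs D = 1/2 for every input."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.full((x.shape[0], 1), 0.5)


torch.manual_seed(0)
g = nn.Linear(8, 4)     # hypothetical toy generator: 8-dim noise -> 4-dim sample
x = torch.randn(16, 4)  # stand-in for a mini-batch from p_data
z = torch.randn(16, 8)  # z drawn from Gaussian noise, as p_z above
v = value_function(ConstantHalfD(), g, x, z)
print(v.item(), -math.log(4))  # both approximately -1.3863
```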

This file defines the loss functions. There is also an MNIST example that uses two multilayer perceptrons, one for the generator and one for the discriminator.

import torch
import torch.nn as nn
from labml_helpers.module import Module

Discriminator Loss

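The discriminator should ascend on the gradient,

$$\nabla_{\theta_d} \frac{1}{m} \sum_{i=1}^m \Bigg[\log D\Big(\pmb{x}^{(i)}\Big) + \log \Big(1 - D\Big(G\Big(\pmb{z}^{(i)}\Big)\Big)\Big)\Bigg]$$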

$m$ is the mini-batch size and $(i)$ is used to index samples in the mini-batch. $\pmb{x}$ are samples from $p_{data}$ and $\pmb{z}$ are samples from $p_{\pmb{z}}$.

class DiscriminatorLogitsLoss(Module):
    def __init__(self, smoothing: float = 0.2):
        super().__init__()

We use PyTorch's binary cross entropy loss, which with its default reduction='mean' is $-\frac{1}{m}\sum_{i=1}^m\Big[y_i \log(\hat{y}_i) + (1 - y_i) \log(1 - \hat{y}_i)\Big]$, where $y$ are the labels and $\hat{y}$ are the predictions. Note the negative sign. We use labels equal to $1$ for $\pmb{x}$ from $p_{data}$ and labels equal to $0$ for $\pmb{x}$ from $p_{G}$. The sum of these two losses is then the negative of the discriminator objective, so descending on it is the same as ascending on the gradient above.
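A quick numerical check of that equivalence, as a sketch: random logits stand in for real discriminator outputs, and hard labels $1$ and $0$ are used (no smoothing). Here $D = \sigma(\text{logit})$, and $\log(1 - \sigma(b))$ is computed as logsigmoid(-b), which is the same quantity in a numerically stable form.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits_true = torch.randn(32, 1)   # hypothetical D logits on samples from p_data
logits_false = torch.randn(32, 1)  # hypothetical D logits on samples from p_G

# Discriminator objective: (1/m) sum [log D(x) + log(1 - D(G(z)))], with D = sigmoid(logit)
objective = (F.logsigmoid(logits_true) + F.logsigmoid(-logits_false)).mean()

bce = nn.BCEWithLogitsLoss()  # default reduction='mean'
loss = bce(logits_true, torch.ones_like(logits_true)) + bce(logits_false, torch.zeros_like(logits_false))

# The loss is the negated objective, so minimizing one maximizes the other
print(loss.item(), -objective.item())
print(torch.allclose(loss, -objective, atol=1e-5))
```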

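BCEWithLogitsLoss combines a sigmoid and binary cross entropy loss in a single function, which is more numerically stable than applying a sigmoid followed by BCELoss.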

        self.loss_true = nn.BCEWithLogitsLoss()
        self.loss_false = nn.BCEWithLogitsLoss()
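A small sketch confirming that BCEWithLogitsLoss on logits gives the same value as BCELoss on sigmoid outputs. The soft targets in $[0, 1)$ are an illustrative assumption, mirroring smoothed labels; both losses accept non-binary targets.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(8, 1)
targets = torch.rand(8, 1)  # soft targets in [0, 1), similar to smoothed labels

fused = nn.BCEWithLogitsLoss()(logits, targets)              # sigmoid + BCE in one stable op
separate = nn.BCELoss()(torch.sigmoid(logits), targets)      # explicit sigmoid, then BCE
print(fused.item(), separate.item())
```

For moderate logits the two agree to float precision; the fused version matters when logits are large in magnitude, where the separate sigmoid saturates to exactly 0 or 1.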

We use label smoothing because it seems to work better in some cases: true labels are drawn uniformly from [1 - smoothing, 1] and false labels from [0, smoothing].

        self.smoothing = smoothing

Labels are registered as buffers with persistent set to False, so they move with the module across devices and dtypes but are not saved in its state_dict.

        self.register_buffer('labels_true', _create_labels(256, 1.0 - smoothing, 1.0), False)
        self.register_buffer('labels_false', _create_labels(256, 0.0, smoothing), False)
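A minimal sketch of what persistent=False means in practice. LabelHolder and its running_count buffer are hypothetical, used only to contrast a non-persistent buffer with a default (persistent) one.

```python
import torch
import torch.nn as nn


class LabelHolder(nn.Module):
    """Hypothetical module that registers a buffer the same way as the loss above."""

    def __init__(self):
        super().__init__()
        # Third positional argument is persistent=False: excluded from state_dict
        self.register_buffer('labels_true', torch.ones(4, 1), False)
        # Default persistent=True, shown for comparison
        self.register_buffer('running_count', torch.zeros(1))


holder = LabelHolder()
print(list(holder.state_dict().keys()))              # ['running_count']
print([name for name, _ in holder.named_buffers()])  # ['labels_true', 'running_count']

# Buffers still follow module-wide conversions such as .to(), .cuda() or .double()
holder.double()
print(holder.labels_true.dtype)                      # torch.float64
```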

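logits_true are the logits for $D(\pmb{x}^{(i)})$ and logits_false are the logits for $D(G(\pmb{z}^{(i)}))$. If a mini-batch is larger than the cached label buffers (256 by default), the buffers are regenerated at the larger size on the logits' device.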

    def __call__(self, logits_true: torch.Tensor, logits_false: torch.Tensor):
        if len(logits_true) > len(self.labels_true):
            self.register_buffer("labels_true",
                                 _create_labels(len(logits_true), 1.0 - self.smoothing, 1.0, logits_true.device), False)
        if len(logits_false) > len(self.labels_false):
            self.register_buffer("labels_false",
                                 _create_labels(len(logits_false), 0.0, self.smoothing, logits_false.device), False)

        return (self.loss_true(logits_true, self.labels_true[:len(logits_true)]),
                self.loss_false(logits_false, self.labels_false[:len(logits_false)]))
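For context, a minimal sketch of one discriminator update that uses the same two-term smoothed loss. Assumptions: small hypothetical MLPs, random tensors in place of MNIST batches, and Adam with lr=2e-4. The smoothed labels are inlined rather than importing DiscriminatorLogitsLoss, since that class depends on labml_helpers. G(z) is detached so this step does not compute gradients for the generator.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
latent_dim, data_dim, m = 16, 32, 64  # hypothetical sizes
generator = nn.Sequential(nn.Linear(latent_dim, 64), nn.ReLU(), nn.Linear(64, data_dim))
# The discriminator outputs logits; the sigmoid lives inside BCEWithLogitsLoss
discriminator = nn.Sequential(nn.Linear(data_dim, 64), nn.LeakyReLU(0.2), nn.Linear(64, 1))
opt_d = torch.optim.Adam(discriminator.parameters(), lr=2e-4)
bce = nn.BCEWithLogitsLoss()
smoothing = 0.2

x = torch.randn(m, data_dim)    # stand-in for a real mini-batch from p_data
z = torch.randn(m, latent_dim)  # z drawn from Gaussian noise

logits_true = discriminator(x)
# detach() stops gradients from flowing into the generator during the D step
logits_false = discriminator(generator(z).detach())

# Smoothed labels, drawn the same way as the loss class above
labels_true = torch.empty(m, 1).uniform_(1.0 - smoothing, 1.0)
labels_false = torch.empty(m, 1).uniform_(0.0, smoothing)
loss_d = bce(logits_true, labels_true) + bce(logits_false, labels_false)

opt_d.zero_grad()
loss_d.backward()
opt_d.step()
print(loss_d.item())
```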

Generator Loss

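The generator should descend on the gradient,

$$\nabla_{\theta_g} \frac{1}{m} \sum_{i=1}^m \log \Big(1 - D\Big(G\Big(\pmb{z}^{(i)}\Big)\Big)\Big)$$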

class GeneratorLogitsLoss(Module):
    def __init__(self, smoothing: float = 0.2):
        super().__init__()
        self.loss_true = nn.BCEWithLogitsLoss()
        self.smoothing = smoothing

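We use labels equal to $1$ for $\pmb{x}$ from $p_{G}$. Strictly, descending on this loss minimizes $-\log D(G(\pmb{z}))$ rather than $\log(1 - D(G(\pmb{z})))$. This is the non-saturating alternative suggested in the original GAN paper: both push $D(G(\pmb{z}))$ toward $1$, but it gives much stronger gradients early in training, when $D$ easily rejects generated samples.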

        self.register_buffer('fake_labels', _create_labels(256, 1.0 - smoothing, 1.0), False)

    def __call__(self, logits: torch.Tensor):
        if len(logits) > len(self.fake_labels):
            self.register_buffer("fake_labels",
                                 _create_labels(len(logits), 1.0 - self.smoothing, 1.0, logits.device), False)

        return self.loss_true(logits, self.fake_labels[:len(logits)])
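The generator loss above uses labels near $1$, so it minimizes $-\log D(G(\pmb{z}))$ instead of $\log(1 - D(G(\pmb{z})))$. A small sketch comparing the gradients of the two with respect to the discriminator's logit $a$, assuming $a = -5$ (the discriminator is confident the sample is fake). Analytically the gradients are $-\sigma(a)$ and $-(1 - \sigma(a))$.

```python
import torch
import torch.nn.functional as F

# Hypothetical logit of D on a generated sample; -5 means D is confident it is fake
a = torch.tensor(-5.0, requires_grad=True)

saturating = torch.log1p(-torch.sigmoid(a))  # log(1 - D(G(z))), the minimax form
(grad_sat,) = torch.autograd.grad(saturating, a)

non_saturating = -F.logsigmoid(a)            # -log D(G(z)), what BCE with label 1 computes
(grad_ns,) = torch.autograd.grad(non_saturating, a)

print(grad_sat.item(), grad_ns.item())  # approximately -0.0067 and -0.9933
```

Both gradients are negative, so descending increases $a$ in either case, but the non-saturating form's gradient is roughly 150 times larger at this point.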

Create an $n \times 1$ tensor of smoothed labels drawn uniformly from $[r_1, r_2)$

def _create_labels(n: int, r1: float, r2: float, device: torch.device = None):
    return torch.empty(n, 1, requires_grad=False, device=device).uniform_(r1, r2)
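A short sketch showing the ranges produced with the default smoothing of 0.2. create_labels here simply repeats the body of _create_labels so the snippet runs on its own without labml_helpers.

```python
import torch


def create_labels(n: int, r1: float, r2: float, device=None) -> torch.Tensor:
    # Same body as _create_labels: n x 1 values drawn uniformly from [r1, r2)
    return torch.empty(n, 1, requires_grad=False, device=device).uniform_(r1, r2)


torch.manual_seed(0)
smoothing = 0.2
true_labels = create_labels(256, 1.0 - smoothing, 1.0)  # labels for samples from p_data
false_labels = create_labels(256, 0.0, smoothing)       # labels for samples from p_G
print(true_labels.shape)                                # torch.Size([256, 1])
print(true_labels.min().item(), true_labels.max().item())
print(false_labels.min().item(), false_labels.max().item())
```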