This is an implementation of the paper Generative Adversarial Networks.

The generator, $G(z; \theta_g)$, generates samples that match the distribution of the data, while the discriminator, $D(x; \theta_d)$, gives the probability that $x$ came from the data rather than from $G$.

We train $D$ and $G$ simultaneously on a two-player min-max game with value function $V(G, D)$:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}(x)} \big[\log D(x)\big] + \mathbb{E}_{z \sim p_z(z)} \big[\log \big(1 - D(G(z))\big)\big]$$

$p_{data}(x)$ is the probability distribution over the data, whilst $p_z(z)$ is the probability distribution of $z$, which is set to Gaussian noise.

This file defines the loss functions. Here is an MNIST example with two multilayer perceptrons for the generator and discriminator.
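To make the value function concrete, here is a minimal sketch of a mini-batch Monte Carlo estimate of $V(D, G)$; the two small MLPs are hypothetical stand-ins, not part of this file:

import torch
import torch.nn as nn

# Hypothetical stand-ins: G maps 100-dimensional Gaussian noise to samples,
# D maps samples to a probability
generator = nn.Sequential(nn.Linear(100, 784), nn.Tanh())
discriminator = nn.Sequential(nn.Linear(784, 1), nn.Sigmoid())

def estimate_value(x: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    # Mini-batch estimate of E[log D(x)] + E[log(1 - D(G(z)))]
    return (torch.log(discriminator(x)).mean()
            + torch.log(1 - discriminator(generator(z))).mean())

The discriminator is trained to maximize this quantity while the generator is trained to minimize it.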
import torch
import torch.nn as nn
import torch.utils.data

from labml_helpers.module import Module
Discriminator should ascend on the gradient,

$$\nabla_{\theta_d} \frac{1}{m} \sum_{i=1}^{m} \Big[ \log D\big(x^{(i)}\big) + \log \Big(1 - D\big(G\big(z^{(i)}\big)\big)\Big) \Big]$$

$m$ is the mini-batch size and $(i)$ is used to index samples in the mini-batch. $x^{(i)}$ are samples from $p_{data}$ and $z^{(i)}$ are samples from $p_z$.
class DiscriminatorLogitsLoss(Module):
    def __init__(self, smoothing: float = 0.2):
        super().__init__()
We use the PyTorch Binary Cross Entropy Loss, which is $-\sum_i \big[ y_i \log \hat{y}_i + (1 - y_i) \log (1 - \hat{y}_i) \big]$, where $y$ are the labels and $\hat{y}$ are the predictions. Note the negative sign. We use labels equal to $1$ for $x$ from $p_{data}$ and labels equal to $0$ for $x$ from $p_G$. Then descending on the sum of these losses is the same as ascending on the above gradient.

BCEWithLogitsLoss combines a sigmoid layer and the binary cross entropy loss.
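Concretely, for a logit $l$ and a target label $y$, BCEWithLogitsLoss computes

$$-\big[\, y \log \sigma(l) + (1 - y) \log \big(1 - \sigma(l)\big) \,\big]$$

where $\sigma$ is the sigmoid function, averaged over the batch by default.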
        self.loss_true = nn.BCEWithLogitsLoss()
        self.loss_false = nn.BCEWithLogitsLoss()
We use label smoothing because it seems to work better in some cases.
        self.smoothing = smoothing
Labels are registered as buffers, and persistence is set to False.
        self.register_buffer('labels_true', _create_labels(256, 1.0 - smoothing, 1.0), False)
        self.register_buffer('labels_false', _create_labels(256, 0.0, smoothing), False)
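Because persistence is False, these buffers follow the module across devices but are excluded from checkpoints. A quick sketch of what that implies, assuming Module behaves like a standard nn.Module here:

loss = DiscriminatorLogitsLoss()
# The label buffers are not part of the state dict, so checkpoints stay
# independent of whatever batch size the buffers grew to during training
assert 'labels_true' not in loss.state_dict()
assert 'labels_false' not in loss.state_dict()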
logits_true are logits from $D(x^{(i)})$ and logits_false are logits from $D(G(z^{(i)}))$.
    def forward(self, logits_true: torch.Tensor, logits_false: torch.Tensor):
        # If the batch is larger than the pre-created label buffers,
        # recreate them at the new size on the right device
        if len(logits_true) > len(self.labels_true):
            self.register_buffer("labels_true",
                                 _create_labels(len(logits_true), 1.0 - self.smoothing, 1.0, logits_true.device), False)
        if len(logits_false) > len(self.labels_false):
            self.register_buffer("labels_false",
                                 _create_labels(len(logits_false), 0.0, self.smoothing, logits_false.device), False)

        # Return the loss for samples from p_data and the loss for
        # samples from p_G separately
        return (self.loss_true(logits_true, self.labels_true[:len(logits_true)]),
                self.loss_false(logits_false, self.labels_false[:len(logits_false)]))
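A short usage sketch of the discriminator loss; the random logits are illustrative stand-ins for $D(x)$ on real samples and $D(G(z))$ on generated ones:

criterion = DiscriminatorLogitsLoss(smoothing=0.2)
# Stand-ins for discriminator outputs on a batch of 64 samples
logits_real = torch.randn(64, 1, requires_grad=True)  # D(x), x from p_data
logits_fake = torch.randn(64, 1, requires_grad=True)  # D(G(z))
loss_true, loss_false = criterion(logits_real, logits_fake)
# Descending on the sum of the two losses ascends the gradient above
(loss_true + loss_false).backward()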
Generator should descend on the gradient,

$$\nabla_{\theta_g} \frac{1}{m} \sum_{i=1}^{m} \log \Big(1 - D\big(G\big(z^{(i)}\big)\big)\Big)$$

class GeneratorLogitsLoss(Module):
    def __init__(self, smoothing: float = 0.2):
        super().__init__()
        self.loss_true = nn.BCEWithLogitsLoss()
        self.smoothing = smoothing
We use labels equal to $1$ for samples from $p_G$. Then descending on this loss moves $G$ in the same direction as descending on the above gradient; this is the non-saturating $-\log D(G(z))$ objective suggested in the paper, which gives stronger gradients early in training.
        self.register_buffer('fake_labels', _create_labels(256, 1.0 - smoothing, 1.0), False)
    def forward(self, logits: torch.Tensor):
        # Recreate the label buffer if the batch is larger than the current one
        if len(logits) > len(self.fake_labels):
            self.register_buffer("fake_labels",
                                 _create_labels(len(logits), 1.0 - self.smoothing, 1.0, logits.device), False)

        return self.loss_true(logits, self.fake_labels[:len(logits)])
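Putting the two losses together, here is a minimal sketch of one alternating training step. The MLPs, optimizers, learning rate, and latent size of 100 are hypothetical stand-ins for what the linked MNIST example sets up properly:

generator = nn.Sequential(nn.Linear(100, 784), nn.Tanh())  # hypothetical G
discriminator = nn.Sequential(nn.Linear(784, 1))  # hypothetical D, outputs a logit
g_optimizer = torch.optim.Adam(generator.parameters(), lr=2e-4)
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=2e-4)
d_criterion = DiscriminatorLogitsLoss()
g_criterion = GeneratorLogitsLoss()

def train_step(x: torch.Tensor):
    z = torch.randn(x.shape[0], 100)  # z sampled from the Gaussian prior p_z(z)

    # Discriminator step: descend the two BCE losses, which ascends
    # log D(x) + log(1 - D(G(z))); detach() keeps gradients out of G
    d_optimizer.zero_grad()
    loss_true, loss_false = d_criterion(discriminator(x),
                                        discriminator(generator(z).detach()))
    (loss_true + loss_false).backward()
    d_optimizer.step()

    # Generator step: descend the generator loss; only G's optimizer steps
    g_optimizer.zero_grad()
    g_criterion(discriminator(generator(z))).backward()
    g_optimizer.step()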
Create smoothed labels, sampled uniformly between $r_1$ and $r_2$.
def _create_labels(n: int, r1: float, r2: float, device: torch.device = None):
    return torch.empty(n, 1, requires_grad=False, device=device).uniform_(r1, r2)
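For instance, a quick check of the defaults used above (an illustrative snippet, not part of this file):

labels = _create_labels(4, 0.8, 1.0)
assert labels.shape == (4, 1)          # one label per sample
assert not labels.requires_grad        # labels are constants
assert bool(((labels >= 0.8) & (labels <= 1.0)).all())  # smoothed toward 1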