Generative Adversarial Networks experiment with MNIST

from typing import Any

import torch
import torch.nn as nn
import torch.utils.data
from torchvision import transforms

from labml import tracker, monit, experiment
from labml.configs import option, calculate
from labml_helpers.datasets.mnist import MNISTConfigs
from labml_helpers.device import DeviceConfigs
from labml_helpers.module import Module
from labml_helpers.optimizer import OptimizerConfigs
from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
from labml_nn.gan.original import DiscriminatorLogitsLoss, GeneratorLogitsLoss
def weights_init(m):
    # Initialize linear layer weights from N(0, 0.02) and
    # batch-norm weights from N(1, 0.02), with biases set to zero
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

Simple MLP Generator

This has three linear layers of increasing size with LeakyReLU activations. The final layer has a tanh activation.

class Generator(Module):
    def __init__(self):
        super().__init__()
        layer_sizes = [256, 512, 1024]
        layers = []
        d_prev = 100
        for size in layer_sizes:
            layers = layers + [nn.Linear(d_prev, size), nn.LeakyReLU(0.2)]
            d_prev = size

        self.layers = nn.Sequential(*layers, nn.Linear(d_prev, 28 * 28), nn.Tanh())

        self.apply(weights_init)

    def forward(self, x):
        return self.layers(x).view(x.shape[0], 1, 28, 28)
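As a quick illustration (not part of the experiment code), the generator maps a 100-dimensional latent vector to a 1×28×28 image, with values in [-1, 1] because of the final tanh:

generator = Generator()
z = torch.randn(16, 100)   # a batch of 16 latent vectors
images = generator(z)      # shape (16, 1, 28, 28), values in [-1, 1]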

Simple MLP Discriminator

This has three linear layers of decreasing size with LeakyReLU activations. The final layer has a single output that gives the logit of whether the input is real or fake. You can get the probability by taking the sigmoid of it.

class Discriminator(Module):
    def __init__(self):
        super().__init__()
        layer_sizes = [1024, 512, 256]
        layers = []
        d_prev = 28 * 28
        for size in layer_sizes:
            layers = layers + [nn.Linear(d_prev, size), nn.LeakyReLU(0.2)]
            d_prev = size

        self.layers = nn.Sequential(*layers, nn.Linear(d_prev, 1))
        self.apply(weights_init)

    def forward(self, x):
        return self.layers(x.view(x.shape[0], -1))
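For instance (illustrative only; the losses below operate directly on logits), you could turn the logits into probabilities like this:

discriminator = Discriminator()
x = torch.randn(16, 1, 28, 28)             # a batch of images
p_real = torch.sigmoid(discriminator(x))   # probability of being real, shape (16, 1)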

Configurations

This extends the MNIST configurations (for the data loaders) and the training and validation loop configurations to simplify our implementation.

class Configs(MNISTConfigs, TrainValidConfigs):
    device: torch.device = DeviceConfigs()
    dataset_transforms = 'mnist_gan_transforms'
    epochs: int = 10

    is_save_models = True
    discriminator: Module = 'mlp'
    generator: Module = 'mlp'
    generator_optimizer: torch.optim.Adam
    discriminator_optimizer: torch.optim.Adam
    generator_loss: GeneratorLogitsLoss = 'original'
    discriminator_loss: DiscriminatorLogitsLoss = 'original'
    label_smoothing: float = 0.2
    discriminator_k: int = 1

Initializations

    def init(self):
        self.state_modules = []

        hook_model_outputs(self.mode, self.generator, 'generator')
        hook_model_outputs(self.mode, self.discriminator, 'discriminator')
        tracker.set_scalar("loss.generator.*", True)
        tracker.set_scalar("loss.discriminator.*", True)
        tracker.set_image("generated", True, 1 / 100)

    def sample_z(self, batch_size: int):
        # Sample latent vectors from a standard normal distribution
        return torch.randn(batch_size, 100, device=self.device)

Take a training step

    def step(self, batch: Any, batch_idx: BatchIndex):

Set model states

        self.generator.train(self.mode.is_train)
        self.discriminator.train(self.mode.is_train)

Get MNIST images

        data = batch[0].to(self.device)

Increment step in training mode

        if self.mode.is_train:
            tracker.add_global_step(len(data))

Train the discriminator

        with monit.section("discriminator"):

Get discriminator loss

            loss = self.calc_discriminator_loss(data)

Train

            if self.mode.is_train:
                self.discriminator_optimizer.zero_grad()
                loss.backward()
                if batch_idx.is_last:
                    tracker.add('discriminator', self.discriminator)
                self.discriminator_optimizer.step()

Train the generator once every discriminator_k steps

        if batch_idx.is_interval(self.discriminator_k):
            with monit.section("generator"):
                loss = self.calc_generator_loss(data.shape[0])

Train

                if self.mode.is_train:
                    self.generator_optimizer.zero_grad()
                    loss.backward()
                    if batch_idx.is_last:
                        tracker.add('generator', self.generator)
                    self.generator_optimizer.step()

        tracker.save()

Calculate discriminator loss

    def calc_discriminator_loss(self, data):
        latent = self.sample_z(data.shape[0])
        logits_true = self.discriminator(data)
        logits_false = self.discriminator(self.generator(latent).detach())
        loss_true, loss_false = self.discriminator_loss(logits_true, logits_false)
        loss = loss_true + loss_false

Log stuff

        tracker.add("loss.discriminator.true.", loss_true)
        tracker.add("loss.discriminator.false.", loss_false)
        tracker.add("loss.discriminator.", loss)

        return loss
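Conceptually, loss_true and loss_false are assumed to correspond to the two terms of the original GAN discriminator objective (with label smoothing softening the targets):

$$\mathcal{L}_D = -\mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big] - \mathbb{E}_{z \sim p(z)}\big[\log\big(1 - D(G(z))\big)\big]$$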

Calculate generator loss

    def calc_generator_loss(self, batch_size: int):
        latent = self.sample_z(batch_size)
        generated_images = self.generator(latent)
        logits = self.discriminator(generated_images)
        loss = self.generator_loss(logits)

Log stuff

        tracker.add('generated', generated_images[0:6])
        tracker.add("loss.generator.", loss)

        return loss
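The generator loss is assumed to be the non-saturating form recommended in the original GAN paper; that is, the generator is trained to make the discriminator assign a high real-probability to generated samples:

$$\mathcal{L}_G = -\mathbb{E}_{z \sim p(z)}\big[\log D(G(z))\big]$$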
@option(Configs.dataset_transforms)
def mnist_gan_transforms():
    # Scale images from [0, 1] to [-1, 1] to match the tanh output of the generator
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
@option(Configs.discriminator_optimizer)
def _discriminator_optimizer(c: Configs):
    opt_conf = OptimizerConfigs()
    opt_conf.optimizer = 'Adam'
    opt_conf.parameters = c.discriminator.parameters()
    opt_conf.learning_rate = 2.5e-4

Setting the exponential decay rate for the first moment of the gradient, β₁, to 0.5 is important. The default of 0.9 fails.

    opt_conf.betas = (0.5, 0.999)
    return opt_conf
@option(Configs.generator_optimizer)
def _generator_optimizer(c: Configs):
    opt_conf = OptimizerConfigs()
    opt_conf.optimizer = 'Adam'
    opt_conf.parameters = c.generator.parameters()
    opt_conf.learning_rate = 2.5e-4

Setting the exponential decay rate for the first moment of the gradient, β₁, to 0.5 is important. The default of 0.9 fails.

    opt_conf.betas = (0.5, 0.999)
    return opt_conf
calculate(Configs.generator, 'mlp', lambda c: Generator().to(c.device))
calculate(Configs.discriminator, 'mlp', lambda c: Discriminator().to(c.device))
calculate(Configs.generator_loss, 'original', lambda c: GeneratorLogitsLoss(c.label_smoothing).to(c.device))
calculate(Configs.discriminator_loss, 'original', lambda c: DiscriminatorLogitsLoss(c.label_smoothing).to(c.device))
def main():
    conf = Configs()
    experiment.create(name='mnist_gan', comment='test')
    experiment.configs(conf,
                       {'label_smoothing': 0.01})
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()