Generative Adversarial Networks experiment with MNIST

from typing import Any

import torch
import torch.nn as nn
import torch.utils.data
from torchvision import transforms

from labml import tracker, monit, experiment
from labml.configs import option
from labml_helpers.datasets.mnist import MNISTConfigs
from labml_helpers.device import DeviceConfigs
from labml_helpers.module import Module
from labml_helpers.optimizer import OptimizerConfigs
from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
from labml_nn.gan import DiscriminatorLogitsLoss, GeneratorLogitsLoss
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        # Initialize linear layer weights from a normal distribution N(0, 0.02)
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        # Initialize batch norm scale from N(1, 0.02) and bias to zero
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

Simple MLP Generator

This has three linear layers of increasing size with LeakyReLU activations. The final layer has a $\tanh$ activation.

class Generator(Module):
    def __init__(self):
        super().__init__()
        layer_sizes = [256, 512, 1024]
        layers = []
        d_prev = 100
        for size in layer_sizes:
            layers = layers + [nn.Linear(d_prev, size), nn.LeakyReLU(0.2)]
            d_prev = size

        self.layers = nn.Sequential(*layers, nn.Linear(d_prev, 28 * 28), nn.Tanh())

        self.apply(weights_init)

    def forward(self, x):
        return self.layers(x).view(x.shape[0], 1, 28, 28)
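A quick usage sketch, not part of the experiment code: sample a batch of latent vectors and run them through the generator. The batch size of 16 here is only illustrative; the latent dimension of 100 matches the first layer above.

generator = Generator()
latent = torch.randn(16, 100)   # 16 latent vectors of dimension 100
fake = generator(latent)        # shape [16, 1, 28, 28], values in [-1, 1] from the Tanh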

Simple MLP Discriminator

This has three linear layers of decreasing size with LeakyReLU activations. The final layer has a single output that gives the logit of whether the input is real or fake. You can get the probability by applying a sigmoid to it, as in the short sketch after the class.

class Discriminator(Module):
    def __init__(self):
        super().__init__()
        layer_sizes = [1024, 512, 256]
        layers = []
        d_prev = 28 * 28
        for size in layer_sizes:
            layers = layers + [nn.Linear(d_prev, size), nn.LeakyReLU(0.2)]
            d_prev = size

        self.layers = nn.Sequential(*layers, nn.Linear(d_prev, 1))
        self.apply(weights_init)

    def forward(self, x):
        return self.layers(x.view(x.shape[0], -1))
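A minimal sketch of turning the discriminator's logit into a probability, as noted above; the random tensor is only a stand-in for a batch of images.

discriminator = Discriminator()
images = torch.randn(16, 1, 28, 28)   # stand-in for a batch of 28x28 grayscale images
logits = discriminator(images)        # shape [16, 1]
prob_real = torch.sigmoid(logits)     # probability that each image is real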
class Configs(MNISTConfigs, TrainValidConfigs):
    device: torch.device = DeviceConfigs()
    epochs: int = 10

    is_save_models = True
    discriminator: Module
    generator: Module
    generator_optimizer: torch.optim.Adam
    discriminator_optimizer: torch.optim.Adam
    generator_loss: GeneratorLogitsLoss
    discriminator_loss: DiscriminatorLogitsLoss
    label_smoothing: float = 0.2
    # Number of discriminator updates per generator update
    discriminator_k: int = 1

    def init(self):
        self.state_modules = []
        self.generator = Generator().to(self.device)
        self.discriminator = Discriminator().to(self.device)
        self.generator_loss = GeneratorLogitsLoss(self.label_smoothing).to(self.device)
        self.discriminator_loss = DiscriminatorLogitsLoss(self.label_smoothing).to(self.device)

        hook_model_outputs(self.mode, self.generator, 'generator')
        hook_model_outputs(self.mode, self.discriminator, 'discriminator')
        tracker.set_scalar("loss.generator.*", True)
        tracker.set_scalar("loss.discriminator.*", True)
        tracker.set_image("generated", True, 1 / 100)

    def step(self, batch: Any, batch_idx: BatchIndex):
        self.generator.train(self.mode.is_train)
        self.discriminator.train(self.mode.is_train)

        data, target = batch[0].to(self.device), batch[1].to(self.device)

Increment step in training mode

        if self.mode.is_train:
            tracker.add_global_step(len(data))

Train the discriminator

        with monit.section("discriminator"):
            for _ in range(self.discriminator_k):
                latent = torch.randn(data.shape[0], 100, device=self.device)
                logits_true = self.discriminator(data)
                # detach() so gradients from the discriminator loss don't flow into the generator
                logits_false = self.discriminator(self.generator(latent).detach())
                loss_true, loss_false = self.discriminator_loss(logits_true, logits_false)
                loss = loss_true + loss_false

Log stuff

                tracker.add("loss.discriminator.true.", loss_true)
                tracker.add("loss.discriminator.false.", loss_false)
                tracker.add("loss.discriminator.", loss)

Train

                if self.mode.is_train:
                    self.discriminator_optimizer.zero_grad()
                    loss.backward()
                    if batch_idx.is_last:
                        tracker.add('discriminator', self.discriminator)
                    self.discriminator_optimizer.step()

Train the generator

        with monit.section("generator"):
            latent = torch.randn(data.shape[0], 100, device=self.device)
            generated_images = self.generator(latent)
            logits = self.discriminator(generated_images)
            loss = self.generator_loss(logits)

Log stuff

            tracker.add('generated', generated_images[0:5])
            tracker.add("loss.generator.", loss)

Train

            if self.mode.is_train:
                self.generator_optimizer.zero_grad()
                loss.backward()
                if batch_idx.is_last:
                    tracker.add('generator', self.generator)
                self.generator_optimizer.step()

        tracker.save()
@option(Configs.dataset_transforms)
def mnist_transforms():
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
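A small standalone sanity check, just to illustrate the choice of normalization: `Normalize((0.5,), (0.5,))` maps pixel values from $[0, 1]$ to $[-1, 1]$, matching the generator's $\tanh$ output range. The all-zero array is only a stand-in for a MNIST digit.

import numpy as np

t = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
img = np.zeros((28, 28, 1), dtype=np.uint8)   # stand-in for a 28x28 grayscale image
x = t(img)                                    # shape [1, 28, 28]; all zeros map to -1.0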
@option(Configs.discriminator_optimizer)
def _discriminator_optimizer(c: Configs):
    opt_conf = OptimizerConfigs()
    opt_conf.optimizer = 'Adam'
    opt_conf.parameters = c.discriminator.parameters()
    opt_conf.learning_rate = 2.5e-4

Setting the exponential decay rate for the first moment of the gradient, $\beta_1$, to $0.5$ is important. The default of $0.9$ fails.

    opt_conf.betas = (0.5, 0.999)
    return opt_conf
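For reference, a rough sketch of the plain PyTorch optimizer this configuration is assumed to resolve to, bypassing `OptimizerConfigs`:

optimizer = torch.optim.Adam(c.discriminator.parameters(),
                             lr=2.5e-4, betas=(0.5, 0.999))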
@option(Configs.generator_optimizer)
def _generator_optimizer(c: Configs):
    opt_conf = OptimizerConfigs()
    opt_conf.optimizer = 'Adam'
    opt_conf.parameters = c.generator.parameters()
    opt_conf.learning_rate = 2.5e-4

Setting the exponential decay rate for the first moment of the gradient, $\beta_1$, to $0.5$ is important. The default of $0.9$ fails.

    opt_conf.betas = (0.5, 0.999)
    return opt_conf
def main():
    conf = Configs()
    experiment.create(name='mnist_gan', comment='test')
    experiment.configs(conf,
                       {'label_smoothing': 0.01})
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()