Generative Adversarial Networks experiment with MNIST

from typing import Any

from torchvision import transforms

import torch
import torch.nn as nn
import torch.utils.data
from labml import tracker, monit, experiment
from labml.configs import option, calculate
from labml_nn.gan.original import DiscriminatorLogitsLoss, GeneratorLogitsLoss
from labml_nn.helpers.datasets import MNISTConfigs
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.helpers.optimizer import OptimizerConfigs
from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex
def weights_init(m):
    """Initialize `Linear` and `BatchNorm` layer weights from a narrow normal distribution"""
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

Simple MLP Generator

This has three linear layers of increasing size with LeakyReLU activations. The final layer has a tanh activation, so the generated images are in the range [-1, 1], matching the normalized MNIST data.

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        layer_sizes = [256, 512, 1024]
        layers = []
        d_prev = 100
        for size in layer_sizes:
            layers = layers + [nn.Linear(d_prev, size), nn.LeakyReLU(0.2)]
            d_prev = size

        self.layers = nn.Sequential(*layers, nn.Linear(d_prev, 28 * 28), nn.Tanh())

        self.apply(weights_init)

    def forward(self, x):
        return self.layers(x).view(x.shape[0], 1, 28, 28)
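
As a quick sanity check (a sketch, not part of the module), you can feed the generator a batch of 100-dimensional latent vectors and confirm the output shape:

generator = Generator()
z = torch.randn(16, 100)  # 16 latent vectors of dimension 100
images = generator(z)  # shape (16, 1, 28, 28), values in [-1, 1] from the tanh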

Simple MLP Discriminator

This has three linear layers of decreasing size with LeakyReLU activations. The final layer has a single output that gives the logit of whether the input is real or fake. You can get the probability by taking the sigmoid of it, as sketched after the class below.

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        layer_sizes = [1024, 512, 256]
        layers = []
        d_prev = 28 * 28
        for size in layer_sizes:
            layers = layers + [nn.Linear(d_prev, size), nn.LeakyReLU(0.2)]
            d_prev = size

        self.layers = nn.Sequential(*layers, nn.Linear(d_prev, 1))
        self.apply(weights_init)

    def forward(self, x):
        return self.layers(x.view(x.shape[0], -1))
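
For example (again a sketch, not part of the module), the logits are converted to probabilities of an image being real like this:

discriminator = Discriminator()
images = torch.randn(16, 1, 28, 28)  # stand-in batch of images
logits = discriminator(images)  # shape (16, 1), raw logits
probabilities = torch.sigmoid(logits)  # probability that each image is real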

Configurations

This extends the MNIST configurations to get the data loaders, and the training and validation loop configurations to simplify our implementation.

class Configs(MNISTConfigs, TrainValidConfigs):
    device: torch.device = DeviceConfigs()
    dataset_transforms = 'mnist_gan_transforms'
    epochs: int = 10

    is_save_models = True
    discriminator: nn.Module = 'mlp'
    generator: nn.Module = 'mlp'
    generator_optimizer: torch.optim.Adam
    discriminator_optimizer: torch.optim.Adam
    generator_loss: GeneratorLogitsLoss = 'original'
    discriminator_loss: DiscriminatorLogitsLoss = 'original'
    label_smoothing: float = 0.2
    discriminator_k: int = 1

Initializations

    def init(self):
        self.state_modules = []

        tracker.set_scalar("loss.generator.*", True)
        tracker.set_scalar("loss.discriminator.*", True)
        tracker.set_image("generated", True, 1 / 100)

    def sample_z(self, batch_size: int):
        return torch.randn(batch_size, 100, device=self.device)

Take a training step

    def step(self, batch: Any, batch_idx: BatchIndex):

Set model states

        self.generator.train(self.mode.is_train)
        self.discriminator.train(self.mode.is_train)

Get MNIST images

        data = batch[0].to(self.device)

Increment step in training mode

        if self.mode.is_train:
            tracker.add_global_step(len(data))

Train the discriminator

        with monit.section("discriminator"):

Get discriminator loss

            loss = self.calc_discriminator_loss(data)

Train

            if self.mode.is_train:
                self.discriminator_optimizer.zero_grad()
                loss.backward()
                if batch_idx.is_last:
                    tracker.add('discriminator', self.discriminator)
                self.discriminator_optimizer.step()

Train the generator once every discriminator_k batches. For example, with discriminator_k = 3 the discriminator is updated on every batch while the generator is updated on every third batch.

        if batch_idx.is_interval(self.discriminator_k):
            with monit.section("generator"):
                loss = self.calc_generator_loss(data.shape[0])

Train

                if self.mode.is_train:
                    self.generator_optimizer.zero_grad()
                    loss.backward()
                    if batch_idx.is_last:
                        tracker.add('generator', self.generator)
                    self.generator_optimizer.step()

        tracker.save()

Calculate discriminator loss

    def calc_discriminator_loss(self, data):
        latent = self.sample_z(data.shape[0])
        logits_true = self.discriminator(data)
        logits_false = self.discriminator(self.generator(latent).detach())
        loss_true, loss_false = self.discriminator_loss(logits_true, logits_false)
        loss = loss_true + loss_false

Log stuff

        tracker.add("loss.discriminator.true.", loss_true)
        tracker.add("loss.discriminator.false.", loss_false)
        tracker.add("loss.discriminator.", loss)

        return loss

Calculate generator loss

    def calc_generator_loss(self, batch_size: int):
        latent = self.sample_z(batch_size)
        generated_images = self.generator(latent)
        logits = self.discriminator(generated_images)
        loss = self.generator_loss(logits)

Log stuff

        tracker.add('generated', generated_images[0:6])
        tracker.add("loss.generator.", loss)

        return loss
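
Conceptually these are the original (non-saturating) GAN losses computed on logits with label smoothing. The sketch below is a rough stand-in using BCEWithLogitsLoss; it is only an illustration, not the actual DiscriminatorLogitsLoss and GeneratorLogitsLoss implementations, which may smooth the labels differently:

bce = nn.BCEWithLogitsLoss()

def discriminator_logits_loss_sketch(logits_true, logits_false, smoothing=0.2):
    # Real images should be classified as real (label 1, smoothed down) and
    # generated images as fake (label 0, smoothed up)
    loss_true = bce(logits_true, torch.full_like(logits_true, 1.0 - smoothing))
    loss_false = bce(logits_false, torch.full_like(logits_false, smoothing))
    return loss_true, loss_false

def generator_logits_loss_sketch(logits_fake, smoothing=0.2):
    # Non-saturating loss: the generator wants its images classified as real
    return bce(logits_fake, torch.full_like(logits_fake, 1.0 - smoothing))
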
@option(Configs.dataset_transforms)
def mnist_gan_transforms():
    # Normalize MNIST images to [-1, 1] to match the generator's tanh output
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])


@option(Configs.discriminator_optimizer)
def _discriminator_optimizer(c: Configs):
    opt_conf = OptimizerConfigs()
    opt_conf.optimizer = 'Adam'
    opt_conf.parameters = c.discriminator.parameters()
    opt_conf.learning_rate = 2.5e-4

Setting the exponential decay rate for the first moment of the gradient, β₁, to 0.5 is important. The default of 0.9 fails.

    opt_conf.betas = (0.5, 0.999)
    return opt_conf
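
If you were building this optimizer with plain PyTorch instead of OptimizerConfigs, the rough equivalent would be (a sketch, assuming discriminator is the Discriminator instance):

discriminator_optimizer = torch.optim.Adam(
    discriminator.parameters(), lr=2.5e-4, betas=(0.5, 0.999))
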
@option(Configs.generator_optimizer)
def _generator_optimizer(c: Configs):
    opt_conf = OptimizerConfigs()
    opt_conf.optimizer = 'Adam'
    opt_conf.parameters = c.generator.parameters()
    opt_conf.learning_rate = 2.5e-4

Setting the exponential decay rate for the first moment of the gradient, β₁, to 0.5 is important. The default of 0.9 fails.

    opt_conf.betas = (0.5, 0.999)
    return opt_conf


calculate(Configs.generator, 'mlp', lambda c: Generator().to(c.device))
calculate(Configs.discriminator, 'mlp', lambda c: Discriminator().to(c.device))
calculate(Configs.generator_loss, 'original', lambda c: GeneratorLogitsLoss(c.label_smoothing).to(c.device))
calculate(Configs.discriminator_loss, 'original', lambda c: DiscriminatorLogitsLoss(c.label_smoothing).to(c.device))
def main():
    conf = Configs()
    experiment.create(name='mnist_gan', comment='test')
    experiment.configs(conf,
                       {'label_smoothing': 0.01})
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()