from typing import Any

import torch
import torch.nn as nn
import torch.utils.data
from torchvision import transforms

from labml import tracker, monit, experiment
from labml.configs import option, calculate
from labml_helpers.datasets.mnist import MNISTConfigs
from labml_helpers.device import DeviceConfigs
from labml_helpers.module import Module
from labml_helpers.optimizer import OptimizerConfigs
from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
from labml_nn.gan.original import DiscriminatorLogitsLoss, GeneratorLogitsLoss
Initialize linear and batch normalization layers with weights drawn from a normal distribution with a standard deviation of 0.02, as is common for GANs.
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
This has three linear layers of increasing size with LeakyReLU
activations. The final layer has a tanh activation.
class Generator(Module):
    def __init__(self):
        super().__init__()
        layer_sizes = [256, 512, 1024]
        layers = []
        d_prev = 100
        for size in layer_sizes:
            layers = layers + [nn.Linear(d_prev, size), nn.LeakyReLU(0.2)]
            d_prev = size

        self.layers = nn.Sequential(*layers, nn.Linear(d_prev, 28 * 28), nn.Tanh())

        self.apply(weights_init)

    def forward(self, x):
        return self.layers(x).view(x.shape[0], 1, 28, 28)
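A minimal usage sketch (illustrative, not part of the original module): feeding a batch of 100-dimensional latent vectors produces a batch of 1×28×28 images in the range [-1, 1] because of the final tanh.

# Illustrative sketch: generate a batch of images from random latent vectors
generator = Generator()
z = torch.randn(16, 100)
images = generator(z)
assert images.shape == (16, 1, 28, 28)  # values lie in [-1, 1] due to the tanh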
This has three linear layers of decreasing size with LeakyReLU
activations. The final layer has a single output that gives the logit of whether the input is real or fake. You can get the probability by taking the sigmoid of this logit.
class Discriminator(Module):
    def __init__(self):
        super().__init__()
        layer_sizes = [1024, 512, 256]
        layers = []
        d_prev = 28 * 28
        for size in layer_sizes:
            layers = layers + [nn.Linear(d_prev, size), nn.LeakyReLU(0.2)]
            d_prev = size

        self.layers = nn.Sequential(*layers, nn.Linear(d_prev, 1))
        self.apply(weights_init)

    def forward(self, x):
        return self.layers(x.view(x.shape[0], -1))
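Another illustrative sketch (not part of the original module): the discriminator returns raw logits, so apply a sigmoid to read them as the probability that an image is real.

# Illustrative sketch: score a batch of generated images
discriminator = Discriminator()
fake_images = Generator()(torch.randn(4, 100))
logits = discriminator(fake_images)
prob_real = torch.sigmoid(logits)  # shape (4, 1), values in (0, 1)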
This extends the MNIST configurations (which provide the data loaders) and the training and validation loop configurations, to simplify our implementation.
class Configs(MNISTConfigs, TrainValidConfigs):
    device: torch.device = DeviceConfigs()
    dataset_transforms = 'mnist_gan_transforms'
    epochs: int = 10

    is_save_models = True
    discriminator: Module = 'mlp'
    generator: Module = 'mlp'
    generator_optimizer: torch.optim.Adam
    discriminator_optimizer: torch.optim.Adam
    generator_loss: GeneratorLogitsLoss = 'original'
    discriminator_loss: DiscriminatorLogitsLoss = 'original'
    label_smoothing: float = 0.2
    discriminator_k: int = 1
Initializations
    def init(self):
        self.state_modules = []

        hook_model_outputs(self.mode, self.generator, 'generator')
        hook_model_outputs(self.mode, self.discriminator, 'discriminator')
        tracker.set_scalar("loss.generator.*", True)
        tracker.set_scalar("loss.discriminator.*", True)
        tracker.set_image("generated", True, 1 / 100)
Sample 100-dimensional latent vectors from a standard normal distribution
    def sample_z(self, batch_size: int):
        return torch.randn(batch_size, 100, device=self.device)
Take a training step
    def step(self, batch: Any, batch_idx: BatchIndex):
Set model states
        self.generator.train(self.mode.is_train)
        self.discriminator.train(self.mode.is_train)
Get MNIST images
        data = batch[0].to(self.device)
Increment step in training mode
        if self.mode.is_train:
            tracker.add_global_step(len(data))
Train the discriminator
        with monit.section("discriminator"):
Get discriminator loss
            loss = self.calc_discriminator_loss(data)
Train
            if self.mode.is_train:
                self.discriminator_optimizer.zero_grad()
                loss.backward()
                if batch_idx.is_last:
                    tracker.add('discriminator', self.discriminator)
                self.discriminator_optimizer.step()
Train the generator once every discriminator_k steps
        if batch_idx.is_interval(self.discriminator_k):
            with monit.section("generator"):
                loss = self.calc_generator_loss(data.shape[0])
Train
                if self.mode.is_train:
                    self.generator_optimizer.zero_grad()
                    loss.backward()
                    if batch_idx.is_last:
                        tracker.add('generator', self.generator)
                    self.generator_optimizer.step()

        tracker.save()
Calculate discriminator loss
    def calc_discriminator_loss(self, data):
        latent = self.sample_z(data.shape[0])
        logits_true = self.discriminator(data)
        logits_false = self.discriminator(self.generator(latent).detach())
        loss_true, loss_false = self.discriminator_loss(logits_true, logits_false)
        loss = loss_true + loss_false
Log stuff
        tracker.add("loss.discriminator.true.", loss_true)
        tracker.add("loss.discriminator.false.", loss_false)
        tracker.add("loss.discriminator.", loss)

        return loss
Calculate generator loss
    def calc_generator_loss(self, batch_size: int):
        latent = self.sample_z(batch_size)
        generated_images = self.generator(latent)
        logits = self.discriminator(generated_images)
        loss = self.generator_loss(logits)
Log stuff
        tracker.add('generated', generated_images[0:6])
        tracker.add("loss.generator.", loss)

        return loss
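For intuition, here is a rough sketch of what the 'original' logits losses compute: binary cross-entropy on the discriminator logits, with label smoothing pulling the targets away from the hard 0/1 values. This is an assumption-laden illustration (the helper names are hypothetical and the exact way smoothing is applied may differ); see labml_nn.gan.original for the actual implementation.

def _sketch_discriminator_loss(logits_true, logits_false, smoothing: float = 0.2):
    # Real images should score close to 1, generated images close to 0 (smoothed)
    bce = nn.BCEWithLogitsLoss()
    loss_true = bce(logits_true, torch.full_like(logits_true, 1.0 - smoothing))
    loss_false = bce(logits_false, torch.full_like(logits_false, smoothing))
    return loss_true, loss_false


def _sketch_generator_loss(logits_fake, smoothing: float = 0.2):
    # Non-saturating generator loss: generated images should be classified as real
    bce = nn.BCEWithLogitsLoss()
    return bce(logits_fake, torch.full_like(logits_fake, 1.0 - smoothing))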
@option(Configs.dataset_transforms)
def mnist_gan_transforms():
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
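Normalize((0.5,), (0.5,)) maps MNIST pixel values from [0, 1] to [-1, 1], matching the range of the generator's tanh output. A quick illustrative check:

# (x - 0.5) / 0.5 maps 0.0 -> -1.0, 0.5 -> 0.0, 1.0 -> 1.0
example_pixels = torch.tensor([0.0, 0.5, 1.0])
normalized = (example_pixels - 0.5) / 0.5  # tensor([-1., 0., 1.])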


@option(Configs.discriminator_optimizer)
def _discriminator_optimizer(c: Configs):
    opt_conf = OptimizerConfigs()
    opt_conf.optimizer = 'Adam'
    opt_conf.parameters = c.discriminator.parameters()
    opt_conf.learning_rate = 2.5e-4
Setting the exponential decay rate for the first moment of the gradient (Adam's beta_1) to 0.5 is important; the default of 0.9 fails to train.
    opt_conf.betas = (0.5, 0.999)
    return opt_conf
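For reference, this is roughly what the resulting optimizer looks like when built directly with PyTorch (an illustrative sketch; the function name is hypothetical and OptimizerConfigs may apply further defaults of its own):

def _plain_discriminator_adam(c: Configs):
    # Illustrative only: construct Adam directly instead of via OptimizerConfigs
    return torch.optim.Adam(c.discriminator.parameters(), lr=2.5e-4, betas=(0.5, 0.999))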
@option(Configs.generator_optimizer)
def _generator_optimizer(c: Configs):
    opt_conf = OptimizerConfigs()
    opt_conf.optimizer = 'Adam'
    opt_conf.parameters = c.generator.parameters()
    opt_conf.learning_rate = 2.5e-4
Setting the exponential decay rate for the first moment of the gradient (Adam's beta_1) to 0.5 is important here as well; the default of 0.9 fails to train.
    opt_conf.betas = (0.5, 0.999)
    return opt_conf


Register the 'mlp' models and the 'original' logits losses for the string options declared in Configs.
calculate(Configs.generator, 'mlp', lambda c: Generator().to(c.device))
calculate(Configs.discriminator, 'mlp', lambda c: Discriminator().to(c.device))
calculate(Configs.generator_loss, 'original', lambda c: GeneratorLogitsLoss(c.label_smoothing).to(c.device))
calculate(Configs.discriminator_loss, 'original', lambda c: DiscriminatorLogitsLoss(c.label_smoothing).to(c.device))
def main():
    conf = Configs()
    experiment.create(name='mnist_gan', comment='test')
    experiment.configs(conf,
                       {'label_smoothing': 0.01})
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()