from typing import Any

from torchvision import transforms

import torch
import torch.nn as nn
import torch.utils.data
from labml import tracker, monit, experiment
from labml.configs import option, calculate
from labml_nn.gan.original import DiscriminatorLogitsLoss, GeneratorLogitsLoss
from labml_nn.helpers.datasets import MNISTConfigs
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.helpers.optimizer import OptimizerConfigs
from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex
Initialize the weights of linear layers from a normal distribution with standard deviation 0.02, and batch normalization weights from a normal distribution with mean 1.0 and standard deviation 0.02, with zero bias.
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
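A quick sanity check of this initializer (an illustrative sketch, not part of the module); the models below install it with self.apply(weights_init), which calls it on every submodule:

layer = nn.Linear(8, 8)
weights_init(layer)
print(layer.weight.std())    # roughly 0.02
print(layer.weight.mean())   # roughly 0.0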
The generator has three linear layers of increasing size with LeakyReLU
activations. The final layer has a tanh activation.
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        layer_sizes = [256, 512, 1024]
        layers = []
        d_prev = 100
        for size in layer_sizes:
            layers = layers + [nn.Linear(d_prev, size), nn.LeakyReLU(0.2)]
            d_prev = size

        self.layers = nn.Sequential(*layers, nn.Linear(d_prev, 28 * 28), nn.Tanh())

        self.apply(weights_init)

    def forward(self, x):
        return self.layers(x).view(x.shape[0], 1, 28, 28)
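A quick shape check (an illustrative sketch, not part of the module): a batch of 100-dimensional latent vectors maps to a batch of 1x28x28 images, and the final tanh keeps pixel values in (-1, 1).

generator = Generator()
z = torch.randn(16, 100)
images = generator(z)
print(images.shape)                                # torch.Size([16, 1, 28, 28])
print(images.min().item(), images.max().item())    # both inside (-1, 1)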
The discriminator has three linear layers of decreasing size with LeakyReLU
activations. The final layer has a single output that gives the logit of whether the input is real or fake. You can get the probability by taking the sigmoid of it.
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        layer_sizes = [1024, 512, 256]
        layers = []
        d_prev = 28 * 28
        for size in layer_sizes:
            layers = layers + [nn.Linear(d_prev, size), nn.LeakyReLU(0.2)]
            d_prev = size

        self.layers = nn.Sequential(*layers, nn.Linear(d_prev, 1))
        self.apply(weights_init)

    def forward(self, x):
        return self.layers(x.view(x.shape[0], -1))
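An illustrative sketch (not part of the module) of turning the discriminator's logit into a probability:

discriminator = Discriminator()
images = torch.randn(4, 1, 28, 28)       # stand-in for a batch of MNIST images
logits = discriminator(images)           # shape (4, 1): one logit per image
probabilities = torch.sigmoid(logits)    # probability that each image is real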
This extends the MNIST configurations, which provide the data loaders, and the training and validation loop configurations, to simplify our implementation.
class Configs(MNISTConfigs, TrainValidConfigs):
    device: torch.device = DeviceConfigs()
    dataset_transforms = 'mnist_gan_transforms'
    epochs: int = 10

    is_save_models = True
    discriminator: nn.Module = 'mlp'
    generator: nn.Module = 'mlp'
    generator_optimizer: torch.optim.Adam
    discriminator_optimizer: torch.optim.Adam
    generator_loss: GeneratorLogitsLoss = 'original'
    discriminator_loss: DiscriminatorLogitsLoss = 'original'
    label_smoothing: float = 0.2
    discriminator_k: int = 1
Initializations
    def init(self):
        self.state_modules = []

        tracker.set_scalar("loss.generator.*", True)
        tracker.set_scalar("loss.discriminator.*", True)
        tracker.set_image("generated", True, 1 / 100)
Sample random latent vectors of size 100 from a unit normal distribution.
    def sample_z(self, batch_size: int):
        return torch.randn(batch_size, 100, device=self.device)
Take a training step
    def step(self, batch: Any, batch_idx: BatchIndex):
Set model states
        self.generator.train(self.mode.is_train)
        self.discriminator.train(self.mode.is_train)
Get MNIST images
        data = batch[0].to(self.device)
Increment the global step in training mode
        if self.mode.is_train:
            tracker.add_global_step(len(data))
Train the discriminator
        with monit.section("discriminator"):
Get discriminator loss
            loss = self.calc_discriminator_loss(data)
Train
            if self.mode.is_train:
                self.discriminator_optimizer.zero_grad()
                loss.backward()
                if batch_idx.is_last:
                    tracker.add('discriminator', self.discriminator)
                self.discriminator_optimizer.step()
Train the generator once every discriminator_k steps
        if batch_idx.is_interval(self.discriminator_k):
            with monit.section("generator"):
                loss = self.calc_generator_loss(data.shape[0])
Train
                if self.mode.is_train:
                    self.generator_optimizer.zero_grad()
                    loss.backward()
                    if batch_idx.is_last:
                        tracker.add('generator', self.generator)
                    self.generator_optimizer.step()

        tracker.save()
Calculate discriminator loss
    def calc_discriminator_loss(self, data):
        latent = self.sample_z(data.shape[0])
        logits_true = self.discriminator(data)
        logits_false = self.discriminator(self.generator(latent).detach())
        loss_true, loss_false = self.discriminator_loss(logits_true, logits_false)
        loss = loss_true + loss_false
Log stuff
        tracker.add("loss.discriminator.true.", loss_true)
        tracker.add("loss.discriminator.false.", loss_false)
        tracker.add("loss.discriminator.", loss)

        return loss
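DiscriminatorLogitsLoss is provided by labml_nn.gan.original. As a rough stand-alone sketch of the idea (an assumption for illustration, not the library's implementation), it behaves like binary cross-entropy on the logits against smoothed labels:

def smoothed_discriminator_loss(logits_true: torch.Tensor, logits_false: torch.Tensor,
                                smoothing: float = 0.2):
    # Hypothetical equivalent: push real images towards a smoothed label of
    # 1 - smoothing and generated images towards a label of smoothing.
    bce = nn.BCEWithLogitsLoss()
    loss_true = bce(logits_true, torch.full_like(logits_true, 1.0 - smoothing))
    loss_false = bce(logits_false, torch.full_like(logits_false, smoothing))
    return loss_true, loss_false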
Calculate generator loss
    def calc_generator_loss(self, batch_size: int):
        latent = self.sample_z(batch_size)
        generated_images = self.generator(latent)
        logits = self.discriminator(generated_images)
        loss = self.generator_loss(logits)
Log stuff
        tracker.add('generated', generated_images[0:6])
        tracker.add("loss.generator.", loss)

        return loss
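Similarly, GeneratorLogitsLoss comes from labml_nn.gan.original. A rough stand-alone sketch, assuming the common non-saturating form where the generator is rewarded when the discriminator assigns a high 'real' probability to generated images:

def non_saturating_generator_loss(logits_fake: torch.Tensor, smoothing: float = 0.2):
    # Hypothetical equivalent, for illustration only.
    return nn.BCEWithLogitsLoss()(logits_fake, torch.full_like(logits_fake, 1.0 - smoothing))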
@option(Configs.dataset_transforms)
def mnist_gan_transforms():
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
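ToTensor gives pixel values in [0, 1], and Normalize((0.5,), (0.5,)) rescales them to [-1, 1], matching the range of the generator's tanh output. A quick check (illustrative only):

x = torch.tensor([[[0.0, 0.5, 1.0]]])              # a 1x1x3 'image' in [0, 1]
print(transforms.Normalize((0.5,), (0.5,))(x))     # tensor([[[-1., 0., 1.]]])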
@option(Configs.discriminator_optimizer)
def _discriminator_optimizer(c: Configs):
    opt_conf = OptimizerConfigs()
    opt_conf.optimizer = 'Adam'
    opt_conf.parameters = c.discriminator.parameters()
    opt_conf.learning_rate = 2.5e-4
Setting the exponential decay rate for the first moment of the gradient (Adam's beta_1) to 0.5 is important; the default of 0.9 fails to train.
    opt_conf.betas = (0.5, 0.999)
    return opt_conf
@option(Configs.generator_optimizer)
def _generator_optimizer(c: Configs):
    opt_conf = OptimizerConfigs()
    opt_conf.optimizer = 'Adam'
    opt_conf.parameters = c.generator.parameters()
    opt_conf.learning_rate = 2.5e-4
Setting the exponential decay rate for the first moment of the gradient (Adam's beta_1) to 0.5 is important; the default of 0.9 fails to train.
    opt_conf.betas = (0.5, 0.999)
    return opt_conf
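Outside of labml's OptimizerConfigs, these two optimizers correspond roughly to plain PyTorch Adam (a sketch; OptimizerConfigs may set further defaults):

generator = Generator()
discriminator = Discriminator()
generator_optimizer = torch.optim.Adam(generator.parameters(), lr=2.5e-4, betas=(0.5, 0.999))
discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=2.5e-4, betas=(0.5, 0.999))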
calculate(Configs.generator, 'mlp', lambda c: Generator().to(c.device))
calculate(Configs.discriminator, 'mlp', lambda c: Discriminator().to(c.device))
calculate(Configs.generator_loss, 'original', lambda c: GeneratorLogitsLoss(c.label_smoothing).to(c.device))
calculate(Configs.discriminator_loss, 'original', lambda c: DiscriminatorLogitsLoss(c.label_smoothing).to(c.device))
def main():
    conf = Configs()
    experiment.create(name='mnist_gan', comment='test')
    experiment.configs(conf,
                       {'label_smoothing': 0.01})
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()