from typing import Any

import torch
import torch.nn as nn
import torch.utils.data
from torchvision import transforms

from labml import tracker, monit, experiment
from labml.configs import option
from labml_helpers.datasets.mnist import MNISTConfigs
from labml_helpers.device import DeviceConfigs
from labml_helpers.module import Module
from labml_helpers.optimizer import OptimizerConfigs
from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
from labml_nn.gan import DiscriminatorLogitsLoss, GeneratorLogitsLoss
def weights_init(m):
    # DCGAN-style initialization: linear weights from N(0, 0.02),
    # batch-norm weights from N(1, 0.02) with zero bias
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
The generator has three linear layers of increasing size with LeakyReLU
activations.
The final layer has a $\tanh$ activation.
class Generator(Module):
    def __init__(self):
        super().__init__()
        layer_sizes = [256, 512, 1024]
        layers = []
        d_prev = 100
        for size in layer_sizes:
            layers = layers + [nn.Linear(d_prev, size), nn.LeakyReLU(0.2)]
            d_prev = size

        self.layers = nn.Sequential(*layers, nn.Linear(d_prev, 28 * 28), nn.Tanh())

        self.apply(weights_init)

    def forward(self, x):
        return self.layers(x).view(x.shape[0], 1, 28, 28)
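A quick shape check (a minimal sketch, not part of the experiment): a batch of $100$-dimensional latent vectors should map to $1 \times 28 \times 28$ images, with values in $[-1, 1]$ from the $\tanh$. `_check_generator_shape` is a hypothetical helper added only for illustration.

def _check_generator_shape():
    # Illustrative only: four random latent vectors -> four MNIST-sized images
    g = Generator()
    z = torch.randn(4, 100)
    images = g(z)
    assert images.shape == (4, 1, 28, 28)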
The discriminator has three linear layers of decreasing size with LeakyReLU
activations.
The final layer has a single output that gives the logit of whether the input
is real or fake. You can get the probability by taking the sigmoid of it.
class Discriminator(Module):
    def __init__(self):
        super().__init__()
        layer_sizes = [1024, 512, 256]
        layers = []
        d_prev = 28 * 28
        for size in layer_sizes:
            layers = layers + [nn.Linear(d_prev, size), nn.LeakyReLU(0.2)]
            d_prev = size

        self.layers = nn.Sequential(*layers, nn.Linear(d_prev, 1))
        self.apply(weights_init)

    def forward(self, x):
        return self.layers(x.view(x.shape[0], -1))
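Since the discriminator outputs logits, the probability that an input is real can be recovered with a sigmoid. A minimal sketch; `_real_probability` is a hypothetical helper, not part of the experiment.

def _real_probability(discriminator: Discriminator, images: torch.Tensor) -> torch.Tensor:
    # Illustrative only: convert the discriminator's logit into a probability
    return torch.sigmoid(discriminator(images))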
class Configs(MNISTConfigs, TrainValidConfigs):
    device: torch.device = DeviceConfigs()
    epochs: int = 10

    is_save_models = True
    discriminator: Module
    generator: Module
    generator_optimizer: torch.optim.Adam
    discriminator_optimizer: torch.optim.Adam
    generator_loss: GeneratorLogitsLoss
    discriminator_loss: DiscriminatorLogitsLoss
    # Label smoothing used by both logits losses
    label_smoothing: float = 0.2
    # Number of discriminator updates per generator update
    discriminator_k: int = 1
    def init(self):
        self.state_modules = []
        self.generator = Generator().to(self.device)
        self.discriminator = Discriminator().to(self.device)
        self.generator_loss = GeneratorLogitsLoss(self.label_smoothing).to(self.device)
        self.discriminator_loss = DiscriminatorLogitsLoss(self.label_smoothing).to(self.device)

        hook_model_outputs(self.mode, self.generator, 'generator')
        hook_model_outputs(self.mode, self.discriminator, 'discriminator')
        tracker.set_scalar("loss.generator.*", True)
        tracker.set_scalar("loss.discriminator.*", True)
        tracker.set_image("generated", True, 1 / 100)
    def step(self, batch: Any, batch_idx: BatchIndex):
        self.generator.train(self.mode.is_train)
        self.discriminator.train(self.mode.is_train)

        data, target = batch[0].to(self.device), batch[1].to(self.device)
Increment step in training mode
        if self.mode.is_train:
            tracker.add_global_step(len(data))
Train the discriminator
        with monit.section("discriminator"):
            for _ in range(self.discriminator_k):
                latent = torch.randn(data.shape[0], 100, device=self.device)
                logits_true = self.discriminator(data)
                logits_false = self.discriminator(self.generator(latent).detach())
                loss_true, loss_false = self.discriminator_loss(logits_true, logits_false)
                loss = loss_true + loss_false
Log the losses
                tracker.add("loss.discriminator.true.", loss_true)
                tracker.add("loss.discriminator.false.", loss_false)
                tracker.add("loss.discriminator.", loss)
Take a gradient step in training mode
                if self.mode.is_train:
                    self.discriminator_optimizer.zero_grad()
                    loss.backward()
                    if batch_idx.is_last:
                        tracker.add('discriminator', self.discriminator)
                    self.discriminator_optimizer.step()
Train the generator
        with monit.section("generator"):
            latent = torch.randn(data.shape[0], 100, device=self.device)
            generated_images = self.generator(latent)
            logits = self.discriminator(generated_images)
            loss = self.generator_loss(logits)
Log sample generated images and the loss
            tracker.add('generated', generated_images[0:5])
            tracker.add("loss.generator.", loss)
Take a gradient step in training mode
            if self.mode.is_train:
                self.generator_optimizer.zero_grad()
                loss.backward()
                if batch_idx.is_last:
                    tracker.add('generator', self.generator)
                self.generator_optimizer.step()

        tracker.save()
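For reference, the logits losses used in `step` amount to binary cross-entropy with logits against label-smoothed real/fake targets. The sketch below illustrates that idea with fixed smoothed targets; it is an assumption for illustration, not the `labml_nn.gan` implementation.

def _gan_logits_losses_sketch(logits_true: torch.Tensor, logits_false: torch.Tensor,
                              logits_generated: torch.Tensor, smoothing: float = 0.2):
    # Illustrative only
    bce = nn.functional.binary_cross_entropy_with_logits
    # Discriminator: push real logits towards the smoothed real label,
    # and logits of generated samples towards the smoothed fake label
    loss_true = bce(logits_true, torch.full_like(logits_true, 1. - smoothing))
    loss_false = bce(logits_false, torch.full_like(logits_false, smoothing))
    # Generator: try to make the discriminator label generated samples as real
    loss_generator = bce(logits_generated, torch.full_like(logits_generated, 1. - smoothing))
    return loss_true, loss_false, loss_generator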
@option(Configs.dataset_transforms)
def mnist_transforms():
    # Normalize pixel values to [-1, 1] to match the generator's tanh output range
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
@option(Configs.discriminator_optimizer)
def _discriminator_optimizer(c: Configs):
    opt_conf = OptimizerConfigs()
    opt_conf.optimizer = 'Adam'
    opt_conf.parameters = c.discriminator.parameters()
    opt_conf.learning_rate = 2.5e-4
Setting the exponential decay rate for the first moment of the gradient,
$\beta_1$, to $0.5$ is important.
The default of $0.9$ fails.
    opt_conf.betas = (0.5, 0.999)
    return opt_conf
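`OptimizerConfigs` presumably wires these settings into a standard PyTorch optimizer; the equivalent direct construction would look roughly like the sketch below (an assumption for illustration, not how the config framework actually builds it).

def _make_adam(params):
    # beta_1 = 0.5 instead of the default 0.9; beta_2 keeps its default 0.999
    return torch.optim.Adam(params, lr=2.5e-4, betas=(0.5, 0.999))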
@option(Configs.generator_optimizer)
def _generator_optimizer(c: Configs):
    opt_conf = OptimizerConfigs()
    opt_conf.optimizer = 'Adam'
    opt_conf.parameters = c.generator.parameters()
    opt_conf.learning_rate = 2.5e-4
Setting the exponential decay rate for the first moment of the gradient,
$\beta_1$, to $0.5$ is important.
The default of $0.9$ fails.
    opt_conf.betas = (0.5, 0.999)
    return opt_conf
def main():
    conf = Configs()
    experiment.create(name='mnist_gan', comment='test')
    experiment.configs(conf,
                       {'label_smoothing': 0.01})
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()