10from typing import Any
11
12import torch
13import torch.nn as nn
14import torch.utils.data
15from torchvision import transforms
16
17from labml import tracker, monit, experiment
18from labml.configs import option, calculate
19from labml_helpers.datasets.mnist import MNISTConfigs
20from labml_helpers.device import DeviceConfigs
21from labml_helpers.module import Module
22from labml_helpers.optimizer import OptimizerConfigs
23from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
24from labml_nn.gan.original import DiscriminatorLogitsLoss, GeneratorLogitsLoss
def weights_init(m):
    """DCGAN-style initialization: Linear weights ~ N(0, 0.02);
    BatchNorm weights ~ N(1, 0.02) with zero bias. Applied via `Module.apply`."""
    name = type(m).__name__
    if 'Linear' in name:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif 'BatchNorm' in name:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
class Generator(Module):
    """MLP generator: maps a 100-d latent vector to a 1x28x28 image in [-1, 1]."""

    def __init__(self):
        super().__init__()
        # Hidden layer widths; the input is the 100-d latent vector.
        hidden_sizes = [256, 512, 1024]
        blocks = []
        in_features = 100
        for out_features in hidden_sizes:
            blocks.append(nn.Linear(in_features, out_features))
            blocks.append(nn.LeakyReLU(0.2))
            in_features = out_features

        # Final projection to 28*28 pixels, squashed to [-1, 1] by Tanh.
        self.layers = nn.Sequential(*blocks, nn.Linear(in_features, 28 * 28), nn.Tanh())

        self.apply(weights_init)

    def forward(self, x):
        # Reshape the flat pixel output to (batch, 1, 28, 28).
        return self.layers(x).view(x.shape[0], 1, 28, 28)
class Discriminator(Module):
    """MLP discriminator: maps a 1x28x28 image to a single real/fake logit."""

    def __init__(self):
        super().__init__()
        # Hidden layer widths; the input is the flattened 28*28 image.
        hidden_sizes = [1024, 512, 256]
        blocks = []
        in_features = 28 * 28
        for out_features in hidden_sizes:
            blocks.append(nn.Linear(in_features, out_features))
            blocks.append(nn.LeakyReLU(0.2))
            in_features = out_features

        # Final projection to a single logit (no sigmoid; losses work on logits).
        self.layers = nn.Sequential(*blocks, nn.Linear(in_features, 1))
        self.apply(weights_init)

    def forward(self, x):
        # Flatten the image before feeding the MLP.
        return self.layers(x.view(x.shape[0], -1))
class Configs(MNISTConfigs, TrainValidConfigs):
    """
    Configurations for training a simple MLP GAN on MNIST.

    Combines the MNIST dataset configs with the generic train/valid loop configs.
    """
    # Device to run the models on (resolved by `DeviceConfigs`)
    device: torch.device = DeviceConfigs()
    dataset_transforms = 'mnist_gan_transforms'
    epochs: int = 10

    is_save_models = True
    discriminator: Module = 'mlp'
    generator: Module = 'mlp'
    generator_optimizer: torch.optim.Adam
    discriminator_optimizer: torch.optim.Adam
    generator_loss: GeneratorLogitsLoss = 'original'
    discriminator_loss: DiscriminatorLogitsLoss = 'original'
    label_smoothing: float = 0.2
    # Train the generator once every `discriminator_k` batches
    discriminator_k: int = 1

    def init(self):
        """Initialize tracking and hook model outputs for monitoring."""
        self.state_modules = []

        hook_model_outputs(self.mode, self.generator, 'generator')
        hook_model_outputs(self.mode, self.discriminator, 'discriminator')
        tracker.set_scalar("loss.generator.*", True)
        tracker.set_scalar("loss.discriminator.*", True)
        # Log generated image samples at a reduced rate (1 in 100)
        tracker.set_image("generated", True, 1 / 100)

    def sample_z(self, batch_size: int):
        """Sample `batch_size` latent vectors of dimension 100 from a standard normal."""
        return torch.randn(batch_size, 100, device=self.device)

    def step(self, batch: Any, batch_idx: BatchIndex):
        """Take a single training/validation step on one batch."""
        # Set model train/eval state to match the current mode
        self.generator.train(self.mode.is_train)
        self.discriminator.train(self.mode.is_train)

        # Get the MNIST images
        data = batch[0].to(self.device)

        # Increment the global step (number of samples seen) in training mode
        if self.mode.is_train:
            tracker.add_global_step(len(data))

        # Train the discriminator
        with monit.section("discriminator"):
            # Get the discriminator loss
            loss = self.calc_discriminator_loss(data)

            # Train
            if self.mode.is_train:
                self.discriminator_optimizer.zero_grad()
                loss.backward()
                if batch_idx.is_last:
                    tracker.add('discriminator', self.discriminator)
                self.discriminator_optimizer.step()

        # Train the generator once every `discriminator_k` batches
        if batch_idx.is_interval(self.discriminator_k):
            with monit.section("generator"):
                loss = self.calc_generator_loss(data.shape[0])

                # Train
                if self.mode.is_train:
                    self.generator_optimizer.zero_grad()
                    loss.backward()
                    if batch_idx.is_last:
                        tracker.add('generator', self.generator)
                    self.generator_optimizer.step()

        tracker.save()

    def calc_discriminator_loss(self, data):
        """Calculate the discriminator loss on real `data` and freshly generated samples."""
        latent = self.sample_z(data.shape[0])
        logits_true = self.discriminator(data)
        # Detach generated images so no generator gradients flow through this loss
        logits_false = self.discriminator(self.generator(latent).detach())
        loss_true, loss_false = self.discriminator_loss(logits_true, logits_false)
        loss = loss_true + loss_false

        # Log things
        tracker.add("loss.discriminator.true.", loss_true)
        tracker.add("loss.discriminator.false.", loss_false)
        tracker.add("loss.discriminator.", loss)

        return loss

    def calc_generator_loss(self, batch_size: int):
        """Calculate the generator loss from a new batch of latent samples."""
        latent = self.sample_z(batch_size)
        generated_images = self.generator(latent)
        logits = self.discriminator(generated_images)
        loss = self.generator_loss(logits)

        # Log things (a few sample images plus the scalar loss)
        tracker.add('generated', generated_images[0:6])
        tracker.add("loss.generator.", loss)

        return loss
@option(Configs.dataset_transforms)
def mnist_gan_transforms():
    """MNIST transforms: convert to tensor and normalize pixels to [-1, 1]."""
    pipeline = [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
    return transforms.Compose(pipeline)
211
212
@option(Configs.discriminator_optimizer)
def _discriminator_optimizer(c: Configs):
    """Create the Adam optimizer configuration for the discriminator."""
    opt_conf = OptimizerConfigs()
    opt_conf.optimizer = 'Adam'
    opt_conf.parameters = c.discriminator.parameters()
    opt_conf.learning_rate = 2.5e-4
    # Set beta1 (exponential decay rate of the gradient's first moment) to 0.5.
    # This is very important; with the default of 0.9 GAN training fails.
    opt_conf.betas = (0.5, 0.999)
    return opt_conf
@option(Configs.generator_optimizer)
def _generator_optimizer(c: Configs):
    """Create the Adam optimizer configuration for the generator."""
    opt_conf = OptimizerConfigs()
    opt_conf.optimizer = 'Adam'
    opt_conf.parameters = c.generator.parameters()
    opt_conf.learning_rate = 2.5e-4
    # Set beta1 (exponential decay rate of the gradient's first moment) to 0.5.
    # This is very important; with the default of 0.9 GAN training fails.
    opt_conf.betas = (0.5, 0.999)
    return opt_conf
237
238
# Register how the 'mlp' models and 'original' GAN losses are built
# (each is moved to the configured device when created).
calculate(Configs.generator, 'mlp', lambda c: Generator().to(c.device))
calculate(Configs.discriminator, 'mlp', lambda c: Discriminator().to(c.device))
calculate(Configs.generator_loss, 'original', lambda c: GeneratorLogitsLoss(c.label_smoothing).to(c.device))
calculate(Configs.discriminator_loss, 'original', lambda c: DiscriminatorLogitsLoss(c.label_smoothing).to(c.device))
def main():
    """Set up and run the MNIST GAN experiment."""
    configuration = Configs()
    experiment.create(name='mnist_gan', comment='test')
    overrides = {'label_smoothing': 0.01}
    experiment.configs(configuration, overrides)
    with experiment.start():
        configuration.run()


if __name__ == '__main__':
    main()