使用 MNIST 进行生成对抗网络实验

from typing import Any

import torch
import torch.nn as nn
import torch.utils.data
from torchvision import transforms

from labml import tracker, monit, experiment
from labml.configs import option, calculate
from labml_helpers.datasets.mnist import MNISTConfigs
from labml_helpers.device import DeviceConfigs
from labml_helpers.module import Module
from labml_helpers.optimizer import OptimizerConfigs
from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
from labml_nn.gan.original import DiscriminatorLogitsLoss, GeneratorLogitsLoss
def weights_init(m):
    """Initialize the weights of a module, DCGAN-style.

    Linear layer weights are drawn from N(0, 0.02). BatchNorm layer
    weights are drawn from N(1, 0.02) and biases are set to zero.
    Other module types are left untouched. Intended to be used via
    ``model.apply(weights_init)``.
    """
    classname = m.__class__.__name__
    if 'Linear' in classname:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

简单的 MLP 生成器

它有三个线性层,层的大小逐层增加,每层之后接 LeakyReLU 激活;最后一层使用 tanh 激活。

class Generator(Module):
    """Simple MLP generator.

    Three linear layers of increasing size (256, 512, 1024), each
    followed by a LeakyReLU activation, mapping a 100-dimensional
    latent vector to a 28x28 image through a final tanh activation.
    """

    def __init__(self):
        super().__init__()
        layer_sizes = [256, 512, 1024]
        layers = []
        # Latent vector size
        d_prev = 100
        for size in layer_sizes:
            layers += [nn.Linear(d_prev, size), nn.LeakyReLU(0.2)]
            d_prev = size

        # Final projection to pixel space with tanh, so outputs are in [-1, 1]
        self.layers = nn.Sequential(*layers, nn.Linear(d_prev, 28 * 28), nn.Tanh())

        self.apply(weights_init)

    def forward(self, x):
        # Reshape the flat output to a single-channel 28x28 image
        return self.layers(x).view(x.shape[0], 1, 28, 28)

简单的 MLP 鉴别器

它有三个线性层,层的大小逐层减小,每层之后接 LeakyReLU 激活。最后一层只有一个输出,给出输入为真实图像(而非生成图像)的 logit;对该 logit 取 sigmoid 即可得到概率。

class Discriminator(Module):
    """Simple MLP discriminator.

    Three linear layers of decreasing size (1024, 512, 256), each
    followed by a LeakyReLU activation. The final layer has a single
    output: the logit of the input being real. Apply a sigmoid to it
    to get the probability.
    """

    def __init__(self):
        super().__init__()
        layer_sizes = [1024, 512, 256]
        layers = []
        # Input is a flattened 28x28 image
        d_prev = 28 * 28
        for size in layer_sizes:
            layers += [nn.Linear(d_prev, size), nn.LeakyReLU(0.2)]
            d_prev = size

        self.layers = nn.Sequential(*layers, nn.Linear(d_prev, 1))
        self.apply(weights_init)

    def forward(self, x):
        # Flatten images before the linear stack
        return self.layers(x.view(x.shape[0], -1))

配置

这扩展了 MNIST 配置以获取数据加载器以及训练和验证循环配置,从而简化了我们的实现。

class Configs(MNISTConfigs, TrainValidConfigs):
    """Experiment configurations.

    Extends the MNIST configurations (for the data loaders) and the
    train/validation loop configurations to simplify the implementation.
    """
    device: torch.device = DeviceConfigs()
    dataset_transforms = 'mnist_gan_transforms'
    epochs: int = 10

    is_save_models = True
    discriminator: Module = 'mlp'
    generator: Module = 'mlp'
    generator_optimizer: torch.optim.Adam
    discriminator_optimizer: torch.optim.Adam
    generator_loss: GeneratorLogitsLoss = 'original'
    discriminator_loss: DiscriminatorLogitsLoss = 'original'
    # One-sided label smoothing applied in the logits losses
    label_smoothing: float = 0.2
    # Train the generator once per `discriminator_k` discriminator updates
    discriminator_k: int = 1

初始化

108    def init(self):
112        self.state_modules = []
113
114        hook_model_outputs(self.mode, self.generator, 'generator')
115        hook_model_outputs(self.mode, self.discriminator, 'discriminator')
116        tracker.set_scalar("loss.generator.*", True)
117        tracker.set_scalar("loss.discriminator.*", True)
118        tracker.set_image("generated", True, 1 / 100)

120    def sample_z(self, batch_size: int):
124        return torch.randn(batch_size, 100, device=self.device)

迈出训练一步

126    def step(self, batch: Any, batch_idx: BatchIndex):

设置模型状态

132        self.generator.train(self.mode.is_train)
133        self.discriminator.train(self.mode.is_train)

获取 MNIST 图片

136        data = batch[0].to(self.device)

在训练模式中增加步数

139        if self.mode.is_train:
140            tracker.add_global_step(len(data))

训练鉴别器

143        with monit.section("discriminator"):

获得鉴别器损失

145            loss = self.calc_discriminator_loss(data)

火车

148            if self.mode.is_train:
149                self.discriminator_optimizer.zero_grad()
150                loss.backward()
151                if batch_idx.is_last:
152                    tracker.add('discriminator', self.discriminator)
153                self.discriminator_optimizer.step()

每隔一次训练发电机discriminator_k

156        if batch_idx.is_interval(self.discriminator_k):
157            with monit.section("generator"):
158                loss = self.calc_generator_loss(data.shape[0])

火车

161                if self.mode.is_train:
162                    self.generator_optimizer.zero_grad()
163                    loss.backward()
164                    if batch_idx.is_last:
165                        tracker.add('generator', self.generator)
166                    self.generator_optimizer.step()
167
168        tracker.save()

计算鉴别器损失

170    def calc_discriminator_loss(self, data):
174        latent = self.sample_z(data.shape[0])
175        logits_true = self.discriminator(data)
176        logits_false = self.discriminator(self.generator(latent).detach())
177        loss_true, loss_false = self.discriminator_loss(logits_true, logits_false)
178        loss = loss_true + loss_false

日志的东西

181        tracker.add("loss.discriminator.true.", loss_true)
182        tracker.add("loss.discriminator.false.", loss_false)
183        tracker.add("loss.discriminator.", loss)
184
185        return loss

计算生成器损失

187    def calc_generator_loss(self, batch_size: int):
191        latent =  self.sample_z(batch_size)
192        generated_images = self.generator(latent)
193        logits = self.discriminator(generated_images)
194        loss = self.generator_loss(logits)

日志的东西

197        tracker.add('generated', generated_images[0:6])
198        tracker.add("loss.generator.", loss)
199
200        return loss
@option(Configs.dataset_transforms)
def mnist_gan_transforms():
    """Normalize MNIST images to [-1, 1] to match the generator's tanh output."""
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
211
212
@option(Configs.discriminator_optimizer)
def _discriminator_optimizer(c: Configs):
    """Create the Adam optimizer for the discriminator."""
    opt_conf = OptimizerConfigs()
    opt_conf.optimizer = 'Adam'
    opt_conf.parameters = c.discriminator.parameters()
    opt_conf.learning_rate = 2.5e-4
    # Setting the exponential decay rate for the first moment of the
    # gradient to 0.5 is important; the default of 0.9 fails.
    opt_conf.betas = (0.5, 0.999)
    return opt_conf
@option(Configs.generator_optimizer)
def _generator_optimizer(c: Configs):
    """Create the Adam optimizer for the generator."""
    opt_conf = OptimizerConfigs()
    opt_conf.optimizer = 'Adam'
    opt_conf.parameters = c.generator.parameters()
    opt_conf.learning_rate = 2.5e-4
    # Setting the exponential decay rate for the first moment of the
    # gradient to 0.5 is important; the default of 0.9 fails.
    opt_conf.betas = (0.5, 0.999)
    return opt_conf
237
238
# Register the 'mlp' model options and 'original' GAN loss options
calculate(Configs.generator, 'mlp', lambda c: Generator().to(c.device))
calculate(Configs.discriminator, 'mlp', lambda c: Discriminator().to(c.device))
calculate(Configs.generator_loss, 'original', lambda c: GeneratorLogitsLoss(c.label_smoothing).to(c.device))
calculate(Configs.discriminator_loss, 'original', lambda c: DiscriminatorLogitsLoss(c.label_smoothing).to(c.device))
def main():
    """Create and run the MNIST GAN experiment."""
    conf = Configs()
    experiment.create(name='mnist_gan', comment='test')
    experiment.configs(conf,
                       {'label_smoothing': 0.01})
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()