MNISTによる敵対的生成ネットワークの実験

10from typing import Any
11
12import torch
13import torch.nn as nn
14import torch.utils.data
15from torchvision import transforms
16
17from labml import tracker, monit, experiment
18from labml.configs import option, calculate
19from labml_helpers.datasets.mnist import MNISTConfigs
20from labml_helpers.device import DeviceConfigs
21from labml_helpers.module import Module
22from labml_helpers.optimizer import OptimizerConfigs
23from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
24from labml_nn.gan.original import DiscriminatorLogitsLoss, GeneratorLogitsLoss
27def weights_init(m):
28    classname = m.__class__.__name__
29    if classname.find('Linear') != -1:
30        nn.init.normal_(m.weight.data, 0.0, 0.02)
31    elif classname.find('BatchNorm') != -1:
32        nn.init.normal_(m.weight.data, 1.0, 0.02)
33        nn.init.constant_(m.bias.data, 0)

シンプルな MLP ジェネレータ

これには、LeakyReLU アクティベーションを行うとサイズが大きくなる3つの線形レイヤーがあります。最後のレイヤーにはアクティベーションがあります。

36class Generator(Module):
44    def __init__(self):
45        super().__init__()
46        layer_sizes = [256, 512, 1024]
47        layers = []
48        d_prev = 100
49        for size in layer_sizes:
50            layers = layers + [nn.Linear(d_prev, size), nn.LeakyReLU(0.2)]
51            d_prev = size
52
53        self.layers = nn.Sequential(*layers, nn.Linear(d_prev, 28 * 28), nn.Tanh())
54
55        self.apply(weights_init)
57    def forward(self, x):
58        return self.layers(x).view(x.shape[0], 1, 28, 28)

シンプルな MLP ディスクリミネーター

これには、LeakyReLU アクティベーションを行うとサイズが小さくなる3つの線形レイヤーがあります。最後のレイヤーには、入力が本物か偽物かをロジットで示す出力が 1 つあります。確率は、そのシグモイドを計算することで求めることができます

61class Discriminator(Module):
70    def __init__(self):
71        super().__init__()
72        layer_sizes = [1024, 512, 256]
73        layers = []
74        d_prev = 28 * 28
75        for size in layer_sizes:
76            layers = layers + [nn.Linear(d_prev, size), nn.LeakyReLU(0.2)]
77            d_prev = size
78
79        self.layers = nn.Sequential(*layers, nn.Linear(d_prev, 1))
80        self.apply(weights_init)
82    def forward(self, x):
83        return self.layers(x.view(x.shape[0], -1))

コンフィギュレーション

これにより、MNIST の構成が拡張され、データローダーやトレーニングおよび検証ループの構成が可能になり、実装が簡単になります。

86class Configs(MNISTConfigs, TrainValidConfigs):
94    device: torch.device = DeviceConfigs()
95    dataset_transforms = 'mnist_gan_transforms'
96    epochs: int = 10
97
98    is_save_models = True
99    discriminator: Module = 'mlp'
100    generator: Module = 'mlp'
101    generator_optimizer: torch.optim.Adam
102    discriminator_optimizer: torch.optim.Adam
103    generator_loss: GeneratorLogitsLoss = 'original'
104    discriminator_loss: DiscriminatorLogitsLoss = 'original'
105    label_smoothing: float = 0.2
106    discriminator_k: int = 1

初期化

108    def init(self):
112        self.state_modules = []
113
114        hook_model_outputs(self.mode, self.generator, 'generator')
115        hook_model_outputs(self.mode, self.discriminator, 'discriminator')
116        tracker.set_scalar("loss.generator.*", True)
117        tracker.set_scalar("loss.discriminator.*", True)
118        tracker.set_image("generated", True, 1 / 100)

120    def sample_z(self, batch_size: int):
124        return torch.randn(batch_size, 100, device=self.device)

トレーニングの一歩を踏み出す

126    def step(self, batch: Any, batch_idx: BatchIndex):

モデル状態の設定

132        self.generator.train(self.mode.is_train)
133        self.discriminator.train(self.mode.is_train)

MNIST の画像を取得

136        data = batch[0].to(self.device)

トレーニングモードでのインクリメントステップ

139        if self.mode.is_train:
140            tracker.add_global_step(len(data))

ディスクリミネーターのトレーニング

143        with monit.section("discriminator"):

ディスクリミネーター損失を取得

145            loss = self.calc_discriminator_loss(data)

列車

148            if self.mode.is_train:
149                self.discriminator_optimizer.zero_grad()
150                loss.backward()
151                if batch_idx.is_last:
152                    tracker.add('discriminator', self.discriminator)
153                self.discriminator_optimizer.step()

ジェネレータを毎回 1 回トレーニングします discriminator_k

156        if batch_idx.is_interval(self.discriminator_k):
157            with monit.section("generator"):
158                loss = self.calc_generator_loss(data.shape[0])

列車

161                if self.mode.is_train:
162                    self.generator_optimizer.zero_grad()
163                    loss.backward()
164                    if batch_idx.is_last:
165                        tracker.add('generator', self.generator)
166                    self.generator_optimizer.step()
167
168        tracker.save()

ディスクリミネーター損失の計算

170    def calc_discriminator_loss(self, data):
174        latent = self.sample_z(data.shape[0])
175        logits_true = self.discriminator(data)
176        logits_false = self.discriminator(self.generator(latent).detach())
177        loss_true, loss_false = self.discriminator_loss(logits_true, logits_false)
178        loss = loss_true + loss_false

ログのもの

181        tracker.add("loss.discriminator.true.", loss_true)
182        tracker.add("loss.discriminator.false.", loss_false)
183        tracker.add("loss.discriminator.", loss)
184
185        return loss

発電機損失の計算

187    def calc_generator_loss(self, batch_size: int):
191        latent =  self.sample_z(batch_size)
192        generated_images = self.generator(latent)
193        logits = self.discriminator(generated_images)
194        loss = self.generator_loss(logits)

ログのもの

197        tracker.add('generated', generated_images[0:6])
198        tracker.add("loss.generator.", loss)
199
200        return loss
205@option(Configs.dataset_transforms)
206def mnist_gan_transforms():
207    return transforms.Compose([
208        transforms.ToTensor(),
209        transforms.Normalize((0.5,), (0.5,))
210    ])
211
212
213@option(Configs.discriminator_optimizer)
214def _discriminator_optimizer(c: Configs):
215    opt_conf = OptimizerConfigs()
216    opt_conf.optimizer = 'Adam'
217    opt_conf.parameters = c.discriminator.parameters()
218    opt_conf.learning_rate = 2.5e-4

勾配の最初の瞬間に指数減衰率を設定することは重要です。 0.5 0.9 デフォルトは失敗です。

222    opt_conf.betas = (0.5, 0.999)
223    return opt_conf
226@option(Configs.generator_optimizer)
227def _generator_optimizer(c: Configs):
228    opt_conf = OptimizerConfigs()
229    opt_conf.optimizer = 'Adam'
230    opt_conf.parameters = c.generator.parameters()
231    opt_conf.learning_rate = 2.5e-4

勾配の最初の瞬間に指数減衰率を設定することは重要です。 0.5 0.9 デフォルトは失敗です。

235    opt_conf.betas = (0.5, 0.999)
236    return opt_conf
237
238
239calculate(Configs.generator, 'mlp', lambda c: Generator().to(c.device))
240calculate(Configs.discriminator, 'mlp', lambda c: Discriminator().to(c.device))
241calculate(Configs.generator_loss, 'original', lambda c: GeneratorLogitsLoss(c.label_smoothing).to(c.device))
242calculate(Configs.discriminator_loss, 'original', lambda c: DiscriminatorLogitsLoss(c.label_smoothing).to(c.device))
245def main():
246    conf = Configs()
247    experiment.create(name='mnist_gan', comment='test')
248    experiment.configs(conf,
249                       {'label_smoothing': 0.01})
250    with experiment.start():
251        conf.run()
252
253
254if __name__ == '__main__':
255    main()