from typing import Any

import torch
import torch.nn as nn
import torch.utils.data
from torchvision import transforms

from labml import tracker, monit, experiment
from labml.configs import option, calculate
from labml_helpers.datasets.mnist import MNISTConfigs
from labml_helpers.device import DeviceConfigs
from labml_helpers.module import Module
from labml_helpers.optimizer import OptimizerConfigs
from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
from labml_nn.gan.original import DiscriminatorLogitsLoss, GeneratorLogitsLoss


# Initialize weights: linear layers get N(0, 0.02) weights; batch-norm
# layers get N(1, 0.02) weights and zero bias
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)


# MLP generator: three linear layers of increasing size with LeakyReLU
# activations, followed by a final linear layer with a tanh, mapping a
# 100-dimensional latent vector to a 1x28x28 image
class Generator(Module):
    def __init__(self):
        super().__init__()
        layer_sizes = [256, 512, 1024]
        layers = []
        d_prev = 100
        for size in layer_sizes:
            layers = layers + [nn.Linear(d_prev, size), nn.LeakyReLU(0.2)]
            d_prev = size

        self.layers = nn.Sequential(*layers, nn.Linear(d_prev, 28 * 28), nn.Tanh())

        self.apply(weights_init)

    def forward(self, x):
        return self.layers(x).view(x.shape[0], 1, 28, 28)


# MLP discriminator: three linear layers of decreasing size with LeakyReLU
# activations. The final layer has a single output, the logit of whether the
# input is real or fake; the probability is obtained by taking its sigmoid
# (see the sketch after this class).
class Discriminator(Module):
    def __init__(self):
        super().__init__()
        layer_sizes = [1024, 512, 256]
        layers = []
        d_prev = 28 * 28
        for size in layer_sizes:
            layers = layers + [nn.Linear(d_prev, size), nn.LeakyReLU(0.2)]
            d_prev = size

        self.layers = nn.Sequential(*layers, nn.Linear(d_prev, 1))
        self.apply(weights_init)

    def forward(self, x):
        return self.layers(x.view(x.shape[0], -1))
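

# A minimal smoke test, not part of the original experiment (the helper name
# `_smoke_test` is illustrative): push random latent vectors through the
# generator and score the images with the discriminator, taking a sigmoid to
# turn logits into probabilities.
def _smoke_test():
    z = torch.randn(4, 100)
    # Generated images are in [-1, 1] with shape [batch, 1, 28, 28]
    images = Generator()(z)
    assert images.shape == (4, 1, 28, 28)
    # One logit per image; sigmoid gives the probability of being real
    probs = torch.sigmoid(Discriminator()(images))
    assert probs.shape == (4, 1)

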
# Experiment configurations, combining the MNIST dataset configs with the
# training/validation loop configs
class Configs(MNISTConfigs, TrainValidConfigs):
    device: torch.device = DeviceConfigs()
    dataset_transforms = 'mnist_gan_transforms'
    epochs: int = 10

    is_save_models = True
    discriminator: Module = 'mlp'
    generator: Module = 'mlp'
    generator_optimizer: torch.optim.Adam
    discriminator_optimizer: torch.optim.Adam
    generator_loss: GeneratorLogitsLoss = 'original'
    discriminator_loss: DiscriminatorLogitsLoss = 'original'
    label_smoothing: float = 0.2
    # Number of discriminator updates per generator update
    discriminator_k: int = 1
    # Initialization
    def init(self):
        self.state_modules = []

        hook_model_outputs(self.mode, self.generator, 'generator')
        hook_model_outputs(self.mode, self.discriminator, 'discriminator')
        tracker.set_scalar("loss.generator.*", True)
        tracker.set_scalar("loss.discriminator.*", True)
        tracker.set_image("generated", True, 1 / 100)
    def sample_z(self, batch_size: int):
        # Sample 100-dimensional latent vectors z ~ N(0, I)
        return torch.randn(batch_size, 100, device=self.device)
    # Take a training step
    def step(self, batch: Any, batch_idx: BatchIndex):
        # Set model states
        self.generator.train(self.mode.is_train)
        self.discriminator.train(self.mode.is_train)

        # Get MNIST images
        data = batch[0].to(self.device)

        # Increment step in training mode
        if self.mode.is_train:
            tracker.add_global_step(len(data))

        # Train the discriminator
        with monit.section("discriminator"):
            # Get discriminator loss
            loss = self.calc_discriminator_loss(data)

            # Train
            if self.mode.is_train:
                self.discriminator_optimizer.zero_grad()
                loss.backward()
                if batch_idx.is_last:
                    tracker.add('discriminator', self.discriminator)
                self.discriminator_optimizer.step()

        # Train the generator once in every `discriminator_k` steps
        if batch_idx.is_interval(self.discriminator_k):
            with monit.section("generator"):
                loss = self.calc_generator_loss(data.shape[0])

                # Train
                if self.mode.is_train:
                    self.generator_optimizer.zero_grad()
                    loss.backward()
                    if batch_idx.is_last:
                        tracker.add('generator', self.generator)
                    self.generator_optimizer.step()

        tracker.save()
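
    # For reference, the two updates above alternate the original GAN
    # minimax game (Goodfellow et al., 2014),
    #
    #   min_G max_D  E_{x ~ p_data}[log D(x)] + E_{z ~ p_z}[log(1 - D(G(z)))]
    #
    # where, as is common in practice, the generator step typically maximizes
    # log D(G(z)) instead of minimizing log(1 - D(G(z))), to avoid vanishing
    # gradients early in training.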
    # Calculate the discriminator loss
    def calc_discriminator_loss(self, data):
        latent = self.sample_z(data.shape[0])
        logits_true = self.discriminator(data)
        # Detach generated images so no generator gradients are computed here
        logits_false = self.discriminator(self.generator(latent).detach())
        loss_true, loss_false = self.discriminator_loss(logits_true, logits_false)
        loss = loss_true + loss_false

        # Log stuff
        tracker.add("loss.discriminator.true.", loss_true)
        tracker.add("loss.discriminator.false.", loss_false)
        tracker.add("loss.discriminator.", loss)

        return loss
    # Calculate the generator loss
    def calc_generator_loss(self, batch_size: int):
        latent = self.sample_z(batch_size)
        generated_images = self.generator(latent)
        logits = self.discriminator(generated_images)
        loss = self.generator_loss(logits)

        # Log stuff
        tracker.add('generated', generated_images[0:6])
        tracker.add("loss.generator.", loss)

        return loss
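

# A rough sketch of what the `original` logits losses likely compute, assuming
# they are built on binary cross-entropy with logits (the actual implementation
# lives in `labml_nn.gan.original`); `s` denotes `label_smoothing`:
#
#     bce = nn.BCEWithLogitsLoss()
#     # Discriminator: real images target ~(1 - s), generated images target ~s
#     loss_true = bce(logits_true, torch.full_like(logits_true, 1.0 - s))
#     loss_false = bce(logits_false, torch.full_like(logits_false, s))
#     # Generator: non-saturating loss, push D's logits on fakes towards 1
#     generator_loss = bce(logits, torch.ones_like(logits))
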
@option(Configs.dataset_transforms)
def mnist_gan_transforms():
    # Scale images to [-1, 1] to match the generator's tanh output range
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])


@option(Configs.discriminator_optimizer)
def _discriminator_optimizer(c: Configs):
    opt_conf = OptimizerConfigs()
    opt_conf.optimizer = 'Adam'
    opt_conf.parameters = c.discriminator.parameters()
    opt_conf.learning_rate = 2.5e-4
    # Setting the exponential decay rate for the first moment of gradients,
    # beta_1, to 0.5 is important; the default of 0.9 fails.
    opt_conf.betas = (0.5, 0.999)
    return opt_conf
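

# Why beta_1 matters: in Adam, beta_1 controls the exponential moving
# average of gradients,
#
#   m_t = beta_1 * m_{t-1} + (1 - beta_1) * g_t
#
# so beta_1 = 0.5 keeps a much shorter gradient memory than the default of
# 0.9, which helps when the loss surface shifts rapidly as the two networks
# train against each other.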


@option(Configs.generator_optimizer)
def _generator_optimizer(c: Configs):
    opt_conf = OptimizerConfigs()
    opt_conf.optimizer = 'Adam'
    opt_conf.parameters = c.generator.parameters()
    opt_conf.learning_rate = 2.5e-4
    # As above, beta_1 = 0.5 is important; the default of 0.9 fails.
    opt_conf.betas = (0.5, 0.999)
    return opt_conf


# Register the components so the string values in `Configs`
# ('mlp', 'original') resolve to these constructors
calculate(Configs.generator, 'mlp', lambda c: Generator().to(c.device))
calculate(Configs.discriminator, 'mlp', lambda c: Discriminator().to(c.device))
calculate(Configs.generator_loss, 'original', lambda c: GeneratorLogitsLoss(c.label_smoothing).to(c.device))
calculate(Configs.discriminator_loss, 'original', lambda c: DiscriminatorLogitsLoss(c.label_smoothing).to(c.device))


def main():
    conf = Configs()
    experiment.create(name='mnist_gan', comment='test')
    experiment.configs(conf,
                       {'label_smoothing': 0.01})
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()