深層畳み込み型敵対的生成ネットワーク (DCGAN)

これは、深層畳み込み生成型敵対的ネットワークを用いた教師なし表現学習のPyTorch実装です

この実装は PyTorch DCGAN チュートリアルに基づいています。

15import torch.nn as nn
16
17from labml import experiment
18from labml.configs import calculate
19from labml_helpers.module import Module
20from labml_nn.gan.original.experiment import Configs

畳み込みジェネレータネットワーク

これは CeleBA フェースに使用されているデコンボリューショナルネットワークに似ていますが、MNIST イメージ用に変更されています。

DCGan Architecture

23class Generator(Module):
33    def __init__(self):
34        super().__init__()

入力は100チャンネル

36        self.layers = nn.Sequential(

これにより出力が得られます

38            nn.ConvTranspose2d(100, 1024, 3, 1, 0, bias=False),
39            nn.BatchNorm2d(1024),
40            nn.ReLU(True),

これにより

42            nn.ConvTranspose2d(1024, 512, 3, 2, 0, bias=False),
43            nn.BatchNorm2d(512),
44            nn.ReLU(True),

これにより

46            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
47            nn.BatchNorm2d(256),
48            nn.ReLU(True),

これにより

50            nn.ConvTranspose2d(256, 1, 4, 2, 1, bias=False),
51            nn.Tanh()
52        )
53
54        self.apply(_weights_init)
56    def forward(self, x):

[batch_size, 100] 形状を次のように変更 [batch_size, 100, 1, 1]

58        x = x.unsqueeze(-1).unsqueeze(-1)
59        x = self.layers(x)
60        return x

畳み込み弁別ネットワーク

63class Discriminator(Module):
68    def __init__(self):
69        super().__init__()

入力は1チャンネルです

71        self.layers = nn.Sequential(

これにより

73            nn.Conv2d(1, 256, 4, 2, 1, bias=False),
74            nn.LeakyReLU(0.2, inplace=True),

これにより

76            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
77            nn.BatchNorm2d(512),
78            nn.LeakyReLU(0.2, inplace=True),

これにより

80            nn.Conv2d(512, 1024, 3, 2, 0, bias=False),
81            nn.BatchNorm2d(1024),
82            nn.LeakyReLU(0.2, inplace=True),

これにより

84            nn.Conv2d(1024, 1, 3, 1, 0, bias=False),
85        )
86        self.apply(_weights_init)
88    def forward(self, x):
89        x = self.layers(x)
90        return x.view(x.shape[0], -1)
93def _weights_init(m):
94    classname = m.__class__.__name__
95    if classname.find('Conv') != -1:
96        nn.init.normal_(m.weight.data, 0.0, 0.02)
97    elif classname.find('BatchNorm') != -1:
98        nn.init.normal_(m.weight.data, 1.0, 0.02)
99        nn.init.constant_(m.bias.data, 0)

簡単なGAN実験をインポートして、ジェネレータとディスクリミネータのネットワークを変更します

104calculate(Configs.generator, 'cnn', lambda c: Generator().to(c.device))
105calculate(Configs.discriminator, 'cnn', lambda c: Discriminator().to(c.device))
108def main():
109    conf = Configs()
110    experiment.create(name='mnist_dcgan')
111    experiment.configs(conf,
112                       {'discriminator': 'cnn',
113                        'generator': 'cnn',
114                        'label_smoothing': 0.01})
115    with experiment.start():
116        conf.run()
117
118
119if __name__ == '__main__':
120    main()