Deep Convolutional Generative Adversarial Networks (DCGAN)

This is a PyTorch implementation of the paper Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks.

This implementation is based on the PyTorch DCGAN tutorial.

import torch.nn as nn

from labml import experiment
from labml.configs import calculate
from labml_helpers.module import Module
from labml_nn.gan.original.experiment import Configs

Convolutional Generator Network

This is similar to the de-convolutional network used for CelebA faces, but modified for MNIST images.

DCGAN architecture
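As a quick reference for the layer-by-layer output sizes noted below: the spatial size of a transposed-convolution output (with no output padding) is

$$H_{out} = (H_{in} - 1) \cdot s - 2 \cdot p + k$$

and the spatial size of an ordinary convolution output is

$$H_{out} = \left\lfloor \frac{H_{in} + 2p - k}{s} \right\rfloor + 1$$

where $s$, $p$ and $k$ are the stride, padding and kernel size.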

class Generator(Module):
    def __init__(self):
        super().__init__()

The input has 100 channels (the 100-dimensional latent vector reshaped to 100 × 1 × 1)

        self.layers = nn.Sequential(

This gives a 1024 × 3 × 3 output

            nn.ConvTranspose2d(100, 1024, 3, 1, 0, bias=False),
            nn.BatchNorm2d(1024),
            nn.ReLU(True),

This gives a 512 × 7 × 7 output

            nn.ConvTranspose2d(1024, 512, 3, 2, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),

This gives a 256 × 14 × 14 output

            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),

This gives a 1 × 28 × 28 output, the MNIST image size (see the shape check after this class)

            nn.ConvTranspose2d(256, 1, 4, 2, 1, bias=False),
            nn.Tanh()
        )

        self.apply(_weights_init)

    def forward(self, x):

Change the shape from [batch_size, 100] to [batch_size, 100, 1, 1]

        x = x.unsqueeze(-1).unsqueeze(-1)
        x = self.layers(x)
        return x
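A minimal sanity-check sketch (an illustrative standalone usage, not part of the experiment code) of how the generator maps a batch of 100-dimensional latent vectors to 1 × 28 × 28 MNIST-sized images:

import torch

# Illustrative check only: sample a batch of 8 latent vectors
generator = Generator()
latent = torch.randn(8, 100)
# The generator reshapes to [8, 100, 1, 1] and upsamples to [8, 1, 28, 28]
images = generator(latent)
assert images.shape == (8, 1, 28, 28)

The final Tanh keeps the generated pixel values in [-1, 1].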

Convolutional Discriminator Network

class Discriminator(Module):
    def __init__(self):
        super().__init__()

The input is a single-channel (1 × 28 × 28) image

        self.layers = nn.Sequential(

This gives a 256 × 14 × 14 output

            nn.Conv2d(1, 256, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

This gives a 512 × 7 × 7 output

            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),

This gives a 1024 × 3 × 3 output

            nn.Conv2d(512, 1024, 3, 2, 0, bias=False),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(0.2, inplace=True),

This gives a 1 × 1 × 1 output, a single value per image (see the sketch after this class)

            nn.Conv2d(1024, 1, 3, 1, 0, bias=False),
        )
        self.apply(_weights_init)

    def forward(self, x):
        x = self.layers(x)
        return x.view(x.shape[0], -1)
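A similar illustrative sketch (again an assumption about standalone usage, not part of the experiment code) showing the discriminator reducing a batch of MNIST-sized images to one raw logit per image:

import torch

# Illustrative check only: a batch of 8 single-channel 28x28 images
discriminator = Discriminator()
images = torch.randn(8, 1, 28, 28)
# Successive convolutions reduce 28x28 -> 14x14 -> 7x7 -> 3x3 -> 1x1,
# and the final view gives shape [8, 1]
logits = discriminator(images)
assert logits.shape == (8, 1)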
def _weights_init(m):
    classname = m.__class__.__name__
    # Initialize convolutional weights from N(0, 0.02)
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    # Initialize batch normalization weights from N(1, 0.02) and biases to zero
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

We import the simple GAN experiment and change the generator and discriminator networks.

calculate(Configs.generator, 'cnn', lambda c: Generator().to(c.device))
calculate(Configs.discriminator, 'cnn', lambda c: Discriminator().to(c.device))
def main():
    conf = Configs()
    experiment.create(name='mnist_dcgan')
    experiment.configs(conf,
                       {'discriminator': 'cnn',
                        'generator': 'cnn',
                        'label_smoothing': 0.01})
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()