これは、深層畳み込み生成型敵対的ネットワークを用いた教師なし表現学習のPyTorch実装です。
この実装は PyTorch DCGAN チュートリアルに基づいています。
15import torch.nn as nn
16
17from labml import experiment
18from labml.configs import calculate
19from labml_helpers.module import Module
20from labml_nn.gan.original.experiment import Configs
23class Generator(Module):
33 def __init__(self):
34 super().__init__()
入力は100チャンネル
36 self.layers = nn.Sequential(
これにより出力が得られます
38 nn.ConvTranspose2d(100, 1024, 3, 1, 0, bias=False),
39 nn.BatchNorm2d(1024),
40 nn.ReLU(True),
これにより
42 nn.ConvTranspose2d(1024, 512, 3, 2, 0, bias=False),
43 nn.BatchNorm2d(512),
44 nn.ReLU(True),
これにより
46 nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
47 nn.BatchNorm2d(256),
48 nn.ReLU(True),
これにより
50 nn.ConvTranspose2d(256, 1, 4, 2, 1, bias=False),
51 nn.Tanh()
52 )
53
54 self.apply(_weights_init)
56 def forward(self, x):
[batch_size, 100]
形状を次のように変更 [batch_size, 100, 1, 1]
58 x = x.unsqueeze(-1).unsqueeze(-1)
59 x = self.layers(x)
60 return x
63class Discriminator(Module):
68 def __init__(self):
69 super().__init__()
入力は1チャンネルです
71 self.layers = nn.Sequential(
これにより
73 nn.Conv2d(1, 256, 4, 2, 1, bias=False),
74 nn.LeakyReLU(0.2, inplace=True),
これにより
76 nn.Conv2d(256, 512, 4, 2, 1, bias=False),
77 nn.BatchNorm2d(512),
78 nn.LeakyReLU(0.2, inplace=True),
これにより
80 nn.Conv2d(512, 1024, 3, 2, 0, bias=False),
81 nn.BatchNorm2d(1024),
82 nn.LeakyReLU(0.2, inplace=True),
これにより
84 nn.Conv2d(1024, 1, 3, 1, 0, bias=False),
85 )
86 self.apply(_weights_init)
88 def forward(self, x):
89 x = self.layers(x)
90 return x.view(x.shape[0], -1)
93def _weights_init(m):
94 classname = m.__class__.__name__
95 if classname.find('Conv') != -1:
96 nn.init.normal_(m.weight.data, 0.0, 0.02)
97 elif classname.find('BatchNorm') != -1:
98 nn.init.normal_(m.weight.data, 1.0, 0.02)
99 nn.init.constant_(m.bias.data, 0)
簡単なGAN実験をインポートして、ジェネレータとディスクリミネータのネットワークを変更します
104calculate(Configs.generator, 'cnn', lambda c: Generator().to(c.device))
105calculate(Configs.discriminator, 'cnn', lambda c: Discriminator().to(c.device))
108def main():
109 conf = Configs()
110 experiment.create(name='mnist_dcgan')
111 experiment.configs(conf,
112 {'discriminator': 'cnn',
113 'generator': 'cnn',
114 'label_smoothing': 0.01})
115 with experiment.start():
116 conf.run()
117
118
119if __name__ == '__main__':
120 main()