This is a PyTorch implementation of the paper *Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks* (DCGAN).
This implementation is based on the PyTorch DCGAN tutorial.
15import torch.nn as nn
16
17from labml import experiment
18from labml.configs import calculate
19from labml_nn.gan.original.experiment import Configs
This generator is similar to the de-convolutional (transposed-convolution) network used for CelebA faces, but modified to produce single-channel 28×28 MNIST images.
class Generator(nn.Module):
    """
    ### Convolutional Generator Network

    Maps latent vectors of shape `[batch_size, 100]` to images of shape
    `[batch_size, 1, 28, 28]`. The final `Tanh` squashes pixel values into
    $[-1, 1]$.
    """

    def __init__(self):
        super().__init__()
        # The input is treated as a 100-channel 1x1 feature map.
        # Shapes in the comments below are per sample: channels x height x width.
        stages = [
            # 100 x 1 x 1  ->  1024 x 3 x 3
            nn.ConvTranspose2d(100, 1024, 3, 1, 0, bias=False),
            nn.BatchNorm2d(1024),
            nn.ReLU(inplace=True),
            # 1024 x 3 x 3  ->  512 x 7 x 7
            nn.ConvTranspose2d(1024, 512, 3, 2, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            # 512 x 7 x 7  ->  256 x 14 x 14
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            # 256 x 14 x 14  ->  1 x 28 x 28
            nn.ConvTranspose2d(256, 1, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]
        self.layers = nn.Sequential(*stages)

        # DCGAN-style initialisation of conv / batch-norm parameters
        self.apply(_weights_init)

    def forward(self, x):
        # Append two trailing unit dimensions:
        # [batch_size, 100] -> [batch_size, 100, 1, 1]
        z = x[..., None, None]
        return self.layers(z)
class Discriminator(nn.Module):
    """
    ### Convolutional Discriminator Network

    Maps single-channel images of shape `[batch_size, 1, 28, 28]` to one raw
    score per sample, shape `[batch_size, 1]`. There is no final activation;
    the last layer is a convolution.
    """

    def __init__(self):
        super().__init__()
        # Shapes in the comments below are per sample: channels x height x width.
        blocks = [
            # 1 x 28 x 28  ->  256 x 14 x 14 (no batch norm on the first layer)
            nn.Conv2d(1, 256, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 256 x 14 x 14  ->  512 x 7 x 7
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # 512 x 7 x 7  ->  1024 x 3 x 3
            nn.Conv2d(512, 1024, 3, 2, 0, bias=False),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(0.2, inplace=True),
            # 1024 x 3 x 3  ->  1 x 1 x 1
            nn.Conv2d(1024, 1, 3, 1, 0, bias=False),
        ]
        self.layers = nn.Sequential(*blocks)

        # DCGAN-style initialisation of conv / batch-norm parameters
        self.apply(_weights_init)

    def forward(self, x):
        scores = self.layers(x)
        # [batch_size, 1, 1, 1] -> [batch_size, 1]
        return scores.view(scores.size(0), -1)
92def _weights_init(m):
93 classname = m.__class__.__name__
94 if classname.find('Conv') != -1:
95 nn.init.normal_(m.weight.data, 0.0, 0.02)
96 elif classname.find('BatchNorm') != -1:
97 nn.init.normal_(m.weight.data, 1.0, 0.02)
98 nn.init.constant_(m.bias.data, 0)
We import the configurations of the simple GAN experiment and replace its generator and discriminator with the convolutional networks defined above.
# Register the convolutional networks as the `'cnn'` option of the base GAN
# experiment's `generator` and `discriminator` configs; each model is moved
# to the configured `c.device`.
calculate(Configs.generator, 'cnn', lambda c: Generator().to(c.device))
calculate(Configs.discriminator, 'cnn', lambda c: Discriminator().to(c.device))
def main():
    """
    Run the MNIST DCGAN experiment.

    Creates the `mnist_dcgan` experiment, selects the `'cnn'` generator and
    discriminator options, sets label smoothing to 0.01, and then calls
    `conf.run()` inside the experiment context.
    """
    conf = Configs()
    experiment.create(name='mnist_dcgan')
    # Overrides applied on top of the base GAN experiment's defaults
    overrides = {
        'discriminator': 'cnn',
        'generator': 'cnn',
        'label_smoothing': 0.01,
    }
    experiment.configs(conf, overrides)
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()