"""
This is a PyTorch implementation of the paper
*Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks*.

This implementation is based on the PyTorch DCGAN Tutorial.
"""
import torch.nn as nn

from labml import experiment
from labml.configs import calculate
from labml_helpers.module import Module
from labml_nn.gan.original.experiment import Configs
class Generator(Module):
    """
    ## Convolutional Generator Network

    This is similar to the de-convolutional network used for CelebA faces,
    but modified for MNIST images.

    It maps a batch of 100-dimensional latent vectors to single-channel
    28x28 images. The final layer is `Tanh`, so pixel values lie in [-1, 1].
    """

    def __init__(self):
        super().__init__()
        # The input is $1 \times 1$ with 100 channels
        self.layers = nn.Sequential(
            # This gives $3 \times 3$ output
            nn.ConvTranspose2d(100, 1024, 3, 1, 0, bias=False),
            nn.BatchNorm2d(1024),
            nn.ReLU(True),
            # This gives $7 \times 7$
            nn.ConvTranspose2d(1024, 512, 3, 2, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            # This gives $14 \times 14$
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            # This gives $28 \times 28$
            nn.ConvTranspose2d(256, 1, 4, 2, 1, bias=False),
            nn.Tanh()
        )

        # Apply the DCGAN weight initialization to every sub-module
        self.apply(_weights_init)

    def forward(self, x):
        """
        Generate images from latent vectors.

        `x` is expected to have shape `[batch_size, 100]`; the result has
        shape `[batch_size, 1, 28, 28]`.
        """
        # Change from shape `[batch_size, 100]` to `[batch_size, 100, 1, 1]`
        # so the latent vector is a 1x1 "image" with 100 channels
        x = x.unsqueeze(-1).unsqueeze(-1)
        x = self.layers(x)
        return x
class Discriminator(Module):
    """
    ## Convolutional Discriminator Network

    Scores a batch of single-channel images (28x28 for MNIST). No activation
    follows the last convolution, so the scores are unnormalized.
    """

    def __init__(self):
        super().__init__()
        # The input is $28 \times 28$ with one channel
        self.layers = nn.Sequential(
            # This gives $14 \times 14$
            nn.Conv2d(1, 256, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # This gives $7 \times 7$
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # This gives $3 \times 3$
            nn.Conv2d(512, 1024, 3, 2, 0, bias=False),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(0.2, inplace=True),
            # This gives $1 \times 1$
            nn.Conv2d(1024, 1, 3, 1, 0, bias=False),
        )
        # Apply the DCGAN weight initialization to every sub-module
        self.apply(_weights_init)

    def forward(self, x):
        """
        Score a batch of images.

        For `x` of shape `[batch_size, 1, 28, 28]` the convolutions produce
        `[batch_size, 1, 1, 1]`, which is flattened to `[batch_size, 1]`.
        """
        x = self.layers(x)
        # Flatten everything except the batch dimension
        return x.view(x.shape[0], -1)
93def _weights_init(m):
94 classname = m.__class__.__name__
95 if classname.find('Conv') != -1:
96 nn.init.normal_(, 0.0, 0.02)
97 elif classname.find('BatchNorm') != -1:
98 nn.init.normal_(, 1.0, 0.02)
99 nn.init.constant_(, 0)
# We import the simple GAN experiment and change the generator and
# discriminator networks: register the convolutional networks as the
# `'cnn'` option of the `generator` and `discriminator` configs, each moved
# to the configured device.
calculate(Configs.generator, 'cnn', lambda c: Generator().to(c.device))
calculate(Configs.discriminator, 'cnn', lambda c: Discriminator().to(c.device))
def main():
    """Create and run the MNIST DCGAN experiment with the convolutional networks."""
    conf = Configs()
    experiment.create(name='mnist_dcgan')
    # Select the `'cnn'` option for both networks and enable label smoothing
    experiment.configs(conf,
                       {'discriminator': 'cnn',
                        'generator': 'cnn',
                        'label_smoothing': 0.01})
    with experiment.start():
        # NOTE(review): assumes the imported `Configs` exposes `run()` to
        # execute the training loop — confirm against the GAN experiment.
        conf.run()


if __name__ == '__main__':
    main()