Cycle GAN

This is a PyTorch implementation/tutorial of the paper Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks.

I’ve taken pieces of code from eriklindernoren/PyTorch-GAN. It is a very good resource if you want to check out other GAN variations too.

Cycle GAN does image-to-image translation. It trains a model to translate an image from a given distribution to another; say, between images of class A and images of class B. A distribution here could be images of a certain style or of a certain kind of scene. The models do not need paired images between A and B; a set of images from each class is enough. This works very well for changing image styles, lighting, patterns, etc.; for example, summer to winter, painting styles to photos, and horses to zebras.

Cycle GAN trains two generator models and two discriminator models. One generator translates images from A to B and the other from B to A. The discriminators test whether the generated images look real.

This file contains the model code as well as the training code. We also have a Google Colab notebook.


import itertools
import random
import zipfile
from typing import Tuple

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import make_grid

from labml import lab, tracker, experiment, monit
from labml.configs import BaseConfigs
from labml.utils.download import download_file
from labml.utils.pytorch import get_modules
from labml_helpers.device import DeviceConfigs
from labml_helpers.module import Module

The generator is a residual network.

class GeneratorResNet(Module):
    def __init__(self, input_channels: int, n_residual_blocks: int):
        super().__init__()

This first block runs a $7\times7$ convolution and maps the image to a feature map. The output feature map has the same height and width because we have a padding of $3$. Reflection padding is used because it gives better image quality at edges.

inplace=True in ReLU saves a little bit of memory.

        out_features = 64
        layers = [
            nn.Conv2d(input_channels, out_features, kernel_size=7, padding=3, padding_mode='reflect'),
            nn.InstanceNorm2d(out_features),
            nn.ReLU(inplace=True),
        ]
        in_features = out_features

We down-sample with two $3 \times 3$ convolutions with a stride of 2.

        for _ in range(2):
            out_features *= 2
            layers += [
                nn.Conv2d(in_features, out_features, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

We take this through n_residual_blocks residual blocks. The ResidualBlock module is defined below.

        for _ in range(n_residual_blocks):
            layers += [ResidualBlock(out_features)]

Then the resulting feature map is up-sampled to match the original image height and width.

        for _ in range(2):
            out_features //= 2
            layers += [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(in_features, out_features, kernel_size=3, stride=1, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

Finally we map the feature map to an RGB image

        layers += [nn.Conv2d(out_features, input_channels, 7, padding=3, padding_mode='reflect'), nn.Tanh()]

Create a sequential module with the layers

        self.layers = nn.Sequential(*layers)

Initialize weights to $\mathcal{N}(0, 0.02)$

        self.apply(weights_init_normal)

    def __call__(self, x):
        return self.layers(x)
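
As a quick sanity check (not part of the original source), the generator should preserve the input's spatial dimensions: the two down-sampling steps are undone by the two up-sampling steps. A minimal sketch, assuming $256 \times 256$ RGB inputs:

gen = GeneratorResNet(input_channels=3, n_residual_blocks=9)
x = torch.randn(1, 3, 256, 256)
# Same shape as the input; values lie in [-1, 1] because of the final tanh
assert gen(x).shape == (1, 3, 256, 256)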

This is the residual block, with two convolution layers.

class ResidualBlock(Module):
    def __init__(self, in_features: int):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflect'),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflect'),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
        )

    def __call__(self, x: torch.Tensor):
        return x + self.block(x)

This is the discriminator.

class Discriminator(Module):
    def __init__(self, input_shape: Tuple[int, int, int]):
        super().__init__()
        channels, height, width = input_shape

The output of the discriminator is a map of scores, indicating whether each region of the image is real or generated.

        self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)

        self.layers = nn.Sequential(

Each of these blocks will shrink the height and width by a factor of 2

            DiscriminatorBlock(channels, 64, normalize=False),
            DiscriminatorBlock(64, 128),
            DiscriminatorBlock(128, 256),
            DiscriminatorBlock(256, 512),

Zero-pad the top and left so that the height and width stay the same after the $4 \times 4$ convolution

            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, kernel_size=4, padding=1)
        )

Initialize weights to $\mathcal{N}(0, 0.02)$

        self.apply(weights_init_normal)

    def forward(self, img):
        return self.layers(img)
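
The four stride-2 blocks divide the height and width by $2^4 = 16$, and the final padded convolution keeps them there, so each output score corresponds to a patch of the input. A small sanity sketch, assuming $256 \times 256$ RGB inputs (not part of the original source):

disc = Discriminator((3, 256, 256))
img = torch.randn(1, 3, 256, 256)
# A 16x16 grid of per-patch scores, matching disc.output_shape
assert disc(img).shape == (1, *disc.output_shape)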

This is the discriminator block module. It does a convolution, an optional normalization, and a leaky ReLU.

It shrinks the height and width of the input feature map by half.

class DiscriminatorBlock(Module):
    def __init__(self, in_filters: int, out_filters: int, normalize: bool = True):
        super().__init__()
        layers = [nn.Conv2d(in_filters, out_filters, kernel_size=4, stride=2, padding=1)]
        if normalize:
            layers.append(nn.InstanceNorm2d(out_filters))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        self.layers = nn.Sequential(*layers)

    def __call__(self, x: torch.Tensor):
        return self.layers(x)
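
Each block therefore maps an $H \times W$ feature map to $H/2 \times W/2$. A one-line check, again just an illustrative sketch:

block = DiscriminatorBlock(3, 64, normalize=False)
# A 256x256 input becomes 128x128
assert block(torch.randn(1, 3, 256, 256)).shape == (1, 64, 128, 128)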

Initialize convolution layer weights to $\mathcal{N}(0, 0.02)$

def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)

Load an image and convert it to RGB if it is grey-scale.

def load_image(path: str):
    image = Image.open(path)
    if image.mode != 'RGB':
        # Convert grey-scale or palette images to RGB
        image = image.convert('RGB')

    return image

Dataset to load images

class ImageDataset(Dataset):

Download dataset and extract data

    @staticmethod
    def download(dataset_name: str):

URL

        url = f'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/{dataset_name}.zip'

Download folder

        root = lab.get_data_path() / 'cycle_gan'
        if not root.exists():
            root.mkdir(parents=True)

Download destination

        archive = root / f'{dataset_name}.zip'

Download file (generally ~100MB)

        download_file(url, archive)

Extract the archive

        with zipfile.ZipFile(archive, 'r') as f:
            f.extractall(root)

Initialize the dataset

  • dataset_name is the name of the dataset
  • transforms_ is the set of image transforms
  • mode is either train or test
    def __init__(self, dataset_name: str, transforms_, mode: str):

Dataset path

        root = lab.get_data_path() / 'cycle_gan' / dataset_name

Download if missing

        if not root.exists():
            self.download(dataset_name)

Image transforms

        self.transform = transforms.Compose(transforms_)

Get image paths

        path_a = root / f'{mode}A'
        path_b = root / f'{mode}B'
        self.files_a = sorted(str(f) for f in path_a.iterdir())
        self.files_b = sorted(str(f) for f in path_b.iterdir())

    def __getitem__(self, index):

Return a pair of images. These images get batched together, but they do not act as pairs in training, so it is fine that the same images keep appearing together.

        return {"x": self.transform(load_image(self.files_a[index % len(self.files_a)])),
                "y": self.transform(load_image(self.files_b[index % len(self.files_b)]))}

    def __len__(self):

Number of images in the dataset

        return max(len(self.files_a), len(self.files_b))
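
Since the two sets can have different sizes, the dataset length is the larger of the two, and indices wrap around the smaller set through the modulo above. An illustrative sketch with made-up sizes:

files_a, files_b = list(range(100)), list(range(120))
length = max(len(files_a), len(files_b))  # 120
# Index 105 pairs files_a[5] (wrapped around) with files_b[105]
print(105 % len(files_a), 105 % len(files_b))  # 5 105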

Replay Buffer

Replay buffer is used to train the discriminator. Generated images are added to the replay buffer and sampled from it.

The replay buffer returns the newly added image with a probability of $0.5$. Otherwise, it sends an older generated image and replaces the older image with the newly generated image.

This is done to reduce model oscillation.

class ReplayBuffer:
    def __init__(self, max_size: int = 50):
        self.max_size = max_size
        self.data = []

Add/retrieve an image

    def push_and_pop(self, data: torch.Tensor):
        data = data.detach()
        res = []
        for element in data:
            if len(self.data) < self.max_size:
                self.data.append(element)
                res.append(element)
            else:
                if random.uniform(0, 1) > 0.5:
                    i = random.randint(0, self.max_size - 1)
                    res.append(self.data[i].clone())
                    self.data[i] = element
                else:
                    res.append(element)
        return torch.stack(res)
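
A short usage sketch (not part of the original source): pushing a batch of generated images returns a batch of the same shape, mixing fresh and previously buffered samples once the buffer is full:

buffer = ReplayBuffer(max_size=50)
fakes = torch.randn(4, 3, 256, 256)   # a pretend batch of generated images
history = buffer.push_and_pop(fakes)  # same shape as the input batch
assert history.shape == fakes.shape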

Configurations

class Configs(BaseConfigs):

DeviceConfigs will pick a GPU if available

    device: torch.device = DeviceConfigs()

Hyper-parameters

    epochs: int = 200
    dataset_name: str = 'monet2photo'
    batch_size: int = 1

    data_loader_workers = 8

    learning_rate = 0.0002
    adam_betas = (0.5, 0.999)
    decay_start = 100

The paper suggests using a least-squares loss instead of negative log-likelihood, as it is found to be more stable.

    gan_loss = torch.nn.MSELoss()

L1 loss is used for cycle loss and identity loss

    cycle_loss = torch.nn.L1Loss()
    identity_loss = torch.nn.L1Loss()

Image dimensions

    img_height = 256
    img_width = 256
    img_channels = 3

Number of residual blocks in the generator

    n_residual_blocks = 9

Loss coefficients

    cyclic_loss_coefficient = 10.0
    identity_loss_coefficient = 5.

    sample_interval = 500

Models

    generator_xy: GeneratorResNet
    generator_yx: GeneratorResNet
    discriminator_x: Discriminator
    discriminator_y: Discriminator

Optimizers

    generator_optimizer: torch.optim.Adam
    discriminator_optimizer: torch.optim.Adam

Learning rate schedules

    generator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR
    discriminator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR

Data loaders

    dataloader: DataLoader
    valid_dataloader: DataLoader

Generate samples from the test set and show them

    def sample_images(self, n: int):
        batch = next(iter(self.valid_dataloader))
        self.generator_xy.eval()
        self.generator_yx.eval()
        with torch.no_grad():
            data_x, data_y = batch['x'].to(self.generator_xy.device), batch['y'].to(self.generator_yx.device)
            gen_y = self.generator_xy(data_x)
            gen_x = self.generator_yx(data_y)

Arrange images along x-axis

            data_x = make_grid(data_x, nrow=5, normalize=True)
            data_y = make_grid(data_y, nrow=5, normalize=True)
            gen_x = make_grid(gen_x, nrow=5, normalize=True)
            gen_y = make_grid(gen_y, nrow=5, normalize=True)

Arrange images along y-axis

            image_grid = torch.cat((data_x, gen_y, data_y, gen_x), 1)

Show samples

        plot_image(image_grid)

Initialize models and data loaders

    def initialize(self):
        input_shape = (self.img_channels, self.img_height, self.img_width)

Create the models

        self.generator_xy = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
        self.generator_yx = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
        self.discriminator_x = Discriminator(input_shape).to(self.device)
        self.discriminator_y = Discriminator(input_shape).to(self.device)

Create the optimizers

        self.generator_optimizer = torch.optim.Adam(
            itertools.chain(self.generator_xy.parameters(), self.generator_yx.parameters()),
            lr=self.learning_rate, betas=self.adam_betas)
        self.discriminator_optimizer = torch.optim.Adam(
            itertools.chain(self.discriminator_x.parameters(), self.discriminator_y.parameters()),
            lr=self.learning_rate, betas=self.adam_betas)

Create the learning rate schedules. The learning rate stays flat for the first decay_start epochs and then decays linearly to $0$ by the end of training.

        decay_epochs = self.epochs - self.decay_start
        self.generator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.generator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
        self.discriminator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.discriminator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
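
To see the shape of this schedule without running training, we can evaluate the same lambda on its own. A standalone sketch with the defaults epochs = 200 and decay_start = 100:

epochs, decay_start = 200, 100
decay_epochs = epochs - decay_start
factor = lambda e: 1.0 - max(0, e - decay_start) / decay_epochs
# Flat at 1.0 for the first 100 epochs, then a linear ramp towards 0
print([factor(e) for e in (0, 99, 100, 150, 199)])  # [1.0, 1.0, 1.0, 0.5, 0.01]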

Image transformations

        transforms_ = [
            transforms.Resize(int(self.img_height * 1.12), Image.BICUBIC),
            transforms.RandomCrop((self.img_height, self.img_width)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]

Training data loader

        self.dataloader = DataLoader(
            ImageDataset(self.dataset_name, transforms_, 'train'),
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.data_loader_workers,
        )

Validation data loader

        self.valid_dataloader = DataLoader(
            ImageDataset(self.dataset_name, transforms_, "test"),
            batch_size=5,
            shuffle=True,
            num_workers=self.data_loader_workers,
        )

Training

We aim to solve:

$$G^{*}, F^{*} = \arg \min_{G, F} \max_{D_X, D_Y} \mathcal{L}(G, F, D_X, D_Y)$$

where $G$ translates images from $X \rightarrow Y$, $F$ translates images from $Y \rightarrow X$, $D_X$ tests if images are from $X$ space, $D_Y$ tests if images are from $Y$ space, and

$$\mathcal{L}(G, F, D_X, D_Y) = \mathcal{L}_{GAN}(G, D_Y, X, Y) + \mathcal{L}_{GAN}(F, D_X, Y, X) + \lambda_1 \mathcal{L}_{cyc}(G, F) + \lambda_2 \mathcal{L}_{identity}(G, F)$$

$\mathcal{L}_{GAN}$ is the generative adversarial loss from the original GAN paper:

$$\mathcal{L}_{GAN}(G, D_Y, X, Y) = \mathbb{E}_{y \sim p_{data}(y)} \big[\log D_Y(y)\big] + \mathbb{E}_{x \sim p_{data}(x)} \big[\log \big(1 - D_Y(G(x))\big)\big]$$

$\mathcal{L}_{cyc}$ is the cyclic loss, where we try to get $F(G(x))$ to be similar to $x$, and $G(F(y))$ to be similar to $y$:

$$\mathcal{L}_{cyc}(G, F) = \mathbb{E}_{x \sim p_{data}(x)} \big[\lVert F(G(x)) - x \rVert_1\big] + \mathbb{E}_{y \sim p_{data}(y)} \big[\lVert G(F(y)) - y \rVert_1\big]$$

Basically, if the two generators (transformations) are applied in series, we should get back the original image. This is the main contribution of the paper. It trains the generators to produce an image of the other distribution that is similar to the original image. Without this loss, $G(x)$ could generate anything from the distribution of $Y$. Now it needs to generate something from the distribution of $Y$ that still has the properties of $x$, so that $F(G(x))$ can reconstruct something like $x$.

$\mathcal{L}_{identity}$ is the identity loss. It encourages the mappings to preserve color composition between the input and the output:

$$\mathcal{L}_{identity}(G, F) = \mathbb{E}_{x \sim p_{data}(x)} \big[\lVert F(x) - x \rVert_1\big] + \mathbb{E}_{y \sim p_{data}(y)} \big[\lVert G(y) - y \rVert_1\big]$$

To solve for $G^{*}, F^{*}$, the discriminators $D_X$ and $D_Y$ should ascend on the gradient,

$$\nabla_{\theta_{D_X}, \theta_{D_Y}} \frac{1}{m} \sum_{i=1}^{m} \Big[\log D_Y\big(y^{(i)}\big) + \log \Big(1 - D_Y\big(G(x^{(i)})\big)\Big) + \log D_X\big(x^{(i)}\big) + \log \Big(1 - D_X\big(F(y^{(i)})\big)\Big)\Big]$$

That is, descend on the negative log-likelihood loss.

In order to stabilize training, the negative log-likelihood objective is replaced by a least-squares loss: the squared error of the discriminator outputs, labelling real images with $1$ and generated images with $0$. So we want to descend on the gradient,

$$\nabla_{\theta_{D_X}, \theta_{D_Y}} \frac{1}{m} \sum_{i=1}^{m} \Big[\big(D_Y(y^{(i)}) - 1\big)^2 + D_Y\big(G(x^{(i)})\big)^2 + \big(D_X(x^{(i)}) - 1\big)^2 + D_X\big(F(y^{(i)})\big)^2\Big]$$

We use least-squares for the generators as well. The generators should descend on the gradient,

$$\nabla_{\theta_G, \theta_F} \frac{1}{m} \sum_{i=1}^{m} \Big[\big(D_Y(G(x^{(i)})) - 1\big)^2 + \big(D_X(F(y^{(i)})) - 1\big)^2 + \lambda_1 \big(\lVert F(G(x^{(i)})) - x^{(i)} \rVert_1 + \lVert G(F(y^{(i)})) - y^{(i)} \rVert_1\big) + \lambda_2 \big(\lVert F(x^{(i)}) - x^{(i)} \rVert_1 + \lVert G(y^{(i)}) - y^{(i)} \rVert_1\big)\Big]$$

We use generator_xy for $G$, generator_yx for $F$, discriminator_x for $D_X$, and discriminator_y for $D_Y$.

    def run(self):

Replay buffers to keep generated samples

        gen_x_buffer = ReplayBuffer()
        gen_y_buffer = ReplayBuffer()

Loop through epochs

        for epoch in monit.loop(self.epochs):

Loop through the dataset

            for i, batch in monit.enum('Train', self.dataloader):

Move images to the device

                data_x, data_y = batch['x'].to(self.device), batch['y'].to(self.device)

true labels equal to $1$

                true_labels = torch.ones(data_x.size(0), *self.discriminator_x.output_shape,
                                         device=self.device, requires_grad=False)

false labels equal to $0$

                false_labels = torch.zeros(data_x.size(0), *self.discriminator_x.output_shape,
                                           device=self.device, requires_grad=False)

Train the generators. This returns the generated images.

                gen_x, gen_y = self.optimize_generators(data_x, data_y, true_labels)

Train discriminators

                self.optimize_discriminator(data_x, data_y,
                                            gen_x_buffer.push_and_pop(gen_x), gen_y_buffer.push_and_pop(gen_y),
                                            true_labels, false_labels)

Save training statistics and increment the global step counter

                tracker.save()
                tracker.add_global_step(max(len(data_x), len(data_y)))

Save images at intervals

                batches_done = epoch * len(self.dataloader) + i
                if batches_done % self.sample_interval == 0:

Save models when sampling images

                    experiment.save_checkpoint()

Sample images

                    self.sample_images(batches_done)

Update learning rates

            self.generator_lr_scheduler.step()
            self.discriminator_lr_scheduler.step()

New line

            tracker.new_line()

Optimize the generators with identity, GAN, and cycle losses.

    def optimize_generators(self, data_x: torch.Tensor, data_y: torch.Tensor, true_labels: torch.Tensor):

Change to training mode

        self.generator_xy.train()
        self.generator_yx.train()

Identity loss

        loss_identity = (self.identity_loss(self.generator_yx(data_x), data_x) +
                         self.identity_loss(self.generator_xy(data_y), data_y))

Generate images $G(x)$ and $F(y)$

        gen_y = self.generator_xy(data_x)
        gen_x = self.generator_yx(data_y)

GAN loss

        loss_gan = (self.gan_loss(self.discriminator_y(gen_y), true_labels) +
                    self.gan_loss(self.discriminator_x(gen_x), true_labels))

Cycle loss

        loss_cycle = (self.cycle_loss(self.generator_yx(gen_y), data_x) +
                      self.cycle_loss(self.generator_xy(gen_x), data_y))

Total loss

        loss_generator = (loss_gan +
                          self.cyclic_loss_coefficient * loss_cycle +
                          self.identity_loss_coefficient * loss_identity)

Take a step in the optimizer

        self.generator_optimizer.zero_grad()
        loss_generator.backward()
        self.generator_optimizer.step()

Log losses

        tracker.add({'loss.generator': loss_generator,
                     'loss.generator.cycle': loss_cycle,
                     'loss.generator.gan': loss_gan,
                     'loss.generator.identity': loss_identity})

Return generated images

        return gen_x, gen_y

Optimize the discriminators with GAN loss.

    def optimize_discriminator(self, data_x: torch.Tensor, data_y: torch.Tensor,
                               gen_x: torch.Tensor, gen_y: torch.Tensor,
                               true_labels: torch.Tensor, false_labels: torch.Tensor):

GAN Loss

        loss_discriminator = (self.gan_loss(self.discriminator_x(data_x), true_labels) +
                              self.gan_loss(self.discriminator_x(gen_x), false_labels) +
                              self.gan_loss(self.discriminator_y(data_y), true_labels) +
                              self.gan_loss(self.discriminator_y(gen_y), false_labels))

Take a step in the optimizer

        self.discriminator_optimizer.zero_grad()
        loss_discriminator.backward()
        self.discriminator_optimizer.step()

Log losses

        tracker.add({'loss.discriminator': loss_discriminator})

Train Cycle GAN

def train():

Create configurations

    conf = Configs()

Create an experiment

    experiment.create(name='cycle_gan')

Calculate configurations. It will calculate conf.run and all other configs required by it.

    experiment.configs(conf, {'dataset_name': 'summer2winter_yosemite'})
    conf.initialize()

Register models for saving and loading. get_modules gives a dictionary of nn.Modules in conf. You can also specify a custom dictionary of models.

    experiment.add_pytorch_models(get_modules(conf))

Start and watch the experiment

    with experiment.start():

Run the training

        conf.run()

Plot an image with matplotlib

def plot_image(img: torch.Tensor):
    from matplotlib import pyplot as plt

Move tensor to CPU

    img = img.cpu()

Get min and max values of the image for normalization

    img_min, img_max = img.min(), img.max()

Scale image values to be [0…1]

    img = (img - img_min) / (img_max - img_min + 1e-5)

We have to change the order of dimensions to HWC.

    img = img.permute(1, 2, 0)

Show Image

    plt.imshow(img)

We don’t need axes

    plt.axis('off')

Display

    plt.show()

Evaluate trained Cycle GAN

def evaluate():

Set the run UUID from the training run

    trained_run_uuid = 'f73c1164184711eb9190b74249275441'

Create configs object

    conf = Configs()

Create experiment

    experiment.create(name='cycle_gan_inference')

Load hyper parameters set for training

    conf_dict = experiment.load_configs(trained_run_uuid)

Calculate configurations. We specify the generators 'generator_xy', 'generator_yx' so that it only loads those and their dependencies. Configs like device and img_channels will be calculated, since these are required by generator_xy and generator_yx.

If you want other parameters like dataset_name you should specify them here. If you specify nothing, all the configurations will be calculated, including data loaders. Calculation of configurations and their dependencies will happen when you call experiment.start

    experiment.configs(conf, conf_dict)
    conf.initialize()

Register models for saving and loading. get_modules gives a dictionary of nn.Modules in conf. You can also specify a custom dictionary of models.

    experiment.add_pytorch_models(get_modules(conf))

Specify which run to load from. Loading will actually happen when you call experiment.start

    experiment.load(trained_run_uuid)

Start the experiment

    with experiment.start():

Image transformations

        transforms_ = [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]

Load your own data. Here we try the training set. I was trying it out with Yosemite photos; they look awesome. You can use conf.dataset_name if you specified dataset_name as something you wanted to be calculated in the call to experiment.configs.

        dataset = ImageDataset(conf.dataset_name, transforms_, 'train')

Get an image from dataset

        x_image = dataset[10]['x']

Display the image

        plot_image(x_image)

Evaluation mode

        conf.generator_xy.eval()
        conf.generator_yx.eval()

We don’t need gradients

        with torch.no_grad():

Add batch dimension and move to the device we use

            data = x_image.unsqueeze(0).to(conf.device)
            generated_y = conf.generator_xy(data)

Display the generated image.

        plot_image(generated_y[0].cpu())


if __name__ == '__main__':
    train()
    # evaluate()