Cycle GAN

This is a PyTorch implementation/tutorial of the paper Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks.

I've taken pieces of code from eriklindernoren/PyTorch-GAN. It is a very good resource if you want to check out other GAN variations too.

Cycle GAN does image-to-image translation. It trains models to translate images from one distribution to another, say, from class A to class B. A distribution here could be images of a certain style, or of a certain kind of scene. The models do not need paired images between A and B; a set of images from each class is enough. This works very well for changing image styles, lighting, patterns, etc. For example: summer to winter, painting styles to photos, and horses to zebras.

Cycle GAN trains two generator models and two discriminator models. One generator translates images from A to B and the other from B to A. The discriminators test whether the generated images look real.

This file contains the model code as well as the training code. We also have a Google Colab notebook.


import itertools
import random
import zipfile
from typing import Tuple

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import InterpolationMode
from torchvision.utils import make_grid

from labml import lab, tracker, experiment, monit
from labml.configs import BaseConfigs
from labml.utils.download import download_file
from labml.utils.pytorch import get_modules
from labml_nn.helpers.device import DeviceConfigs

The generator is a residual network.

class GeneratorResNet(nn.Module):
    def __init__(self, input_channels: int, n_residual_blocks: int):
        super().__init__()

This first block runs a convolution and maps the image to a feature map. The output feature map has the same height and width because, with a 7×7 kernel, we use a padding of 3. Reflection padding is used because it gives better image quality at edges.

inplace=True in ReLU saves a little bit of memory.

        out_features = 64
        layers = [
            nn.Conv2d(input_channels, out_features, kernel_size=7, padding=3, padding_mode='reflect'),
            nn.InstanceNorm2d(out_features),
            nn.ReLU(inplace=True),
        ]
        in_features = out_features

We down-sample with two convolutions with a stride of 2.

        for _ in range(2):
            out_features *= 2
            layers += [
                nn.Conv2d(in_features, out_features, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

We take this through n_residual_blocks residual blocks. This module is defined below.

        for _ in range(n_residual_blocks):
            layers += [ResidualBlock(out_features)]

Then the resulting feature map is up-sampled to match the original image height and width.

        for _ in range(2):
            out_features //= 2
            layers += [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(in_features, out_features, kernel_size=3, stride=1, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

Finally we map the feature map to an RGB image.

        layers += [nn.Conv2d(out_features, input_channels, 7, padding=3, padding_mode='reflect'), nn.Tanh()]

Create a sequential module with the layers.

        self.layers = nn.Sequential(*layers)

Initialize weights to $\mathcal{N}(0, 0.02)$.

        self.apply(weights_init_normal)

    def forward(self, x):
        return self.layers(x)
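As a quick sanity check (a minimal sketch, not part of the original code): the generator should return a tensor of the same shape as its input, since the two down-sampling steps are undone by the two up-sampling steps.

gen = GeneratorResNet(input_channels=3, n_residual_blocks=9)
dummy = torch.randn(1, 3, 256, 256)  # a random RGB image batch
out = gen(dummy)
assert out.shape == dummy.shape  # (1, 3, 256, 256); values are in [-1, 1] because of the final Tanh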

This is the residual block, with two convolution layers.

class ResidualBlock(nn.Module):
    def __init__(self, in_features: int):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflect'),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflect'),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor):
        return x + self.block(x)

This is the discriminator.

class Discriminator(nn.Module):
    def __init__(self, input_shape: Tuple[int, int, int]):
        super().__init__()
        channels, height, width = input_shape

The output of the discriminator is a map of classifications, indicating whether each region of the image is real or generated.

        self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)

        self.layers = nn.Sequential(

Each of these blocks will shrink the height and width by a factor of 2.

            DiscriminatorBlock(channels, 64, normalize=False),
            DiscriminatorBlock(64, 128),
            DiscriminatorBlock(128, 256),
            DiscriminatorBlock(256, 512),

Zero-pad on the top and left so that, with a 4×4 kernel and a padding of 1, the final convolution keeps the output height and width the same.

            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, kernel_size=4, padding=1)
        )

Initialize weights to $\mathcal{N}(0, 0.02)$.

        self.apply(weights_init_normal)

    def forward(self, img):
        return self.layers(img)
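A quick shape check (again a sketch, not from the original): for a 256×256 input, the discriminator produces a 16×16 map of scores, one per image patch.

disc = Discriminator(input_shape=(3, 256, 256))
scores = disc(torch.randn(1, 3, 256, 256))
assert scores.shape == (1, *disc.output_shape)  # (1, 1, 16, 16), one score per region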

This is the discriminator block module. It does a convolution, an optional normalization, and a leaky ReLU.

It shrinks the height and width of the input feature map by half.

class DiscriminatorBlock(nn.Module):
    def __init__(self, in_filters: int, out_filters: int, normalize: bool = True):
        super().__init__()
        layers = [nn.Conv2d(in_filters, out_filters, kernel_size=4, stride=2, padding=1)]
        if normalize:
            layers.append(nn.InstanceNorm2d(out_filters))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        self.layers = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor):
        return self.layers(x)

Initialize convolution layer weights to $\mathcal{N}(0, 0.02)$.

def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
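Since apply recurses through all sub-modules, every convolution in a model gets re-initialized. A minimal sketch to check the resulting statistics:

conv = nn.Conv2d(3, 64, kernel_size=7)
weights_init_normal(conv)
# mean should be close to 0 and standard deviation close to 0.02
print(conv.weight.data.mean().item(), conv.weight.data.std().item())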

Load an image and change to RGB if in grey-scale.

def load_image(path: str):
    image = Image.open(path)
    if image.mode != 'RGB':
        image = image.convert('RGB')

    return image

Dataset to load images

class ImageDataset(Dataset):

Download dataset and extract data

    @staticmethod
    def download(dataset_name: str):

URL

        url = f'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/{dataset_name}.zip'

Download folder

        root = lab.get_data_path() / 'cycle_gan'
        if not root.exists():
            root.mkdir(parents=True)

Download destination

        archive = root / f'{dataset_name}.zip'

Download file (generally ~100MB)

        download_file(url, archive)

Extract the archive

        with zipfile.ZipFile(archive, 'r') as f:
            f.extractall(root)

Initialize the dataset

  • dataset_name is the name of the dataset
  • transforms_ is the set of image transforms
  • mode is either train or test
    def __init__(self, dataset_name: str, transforms_, mode: str):

Dataset path

        root = lab.get_data_path() / 'cycle_gan' / dataset_name

Download if missing

        if not root.exists():
            self.download(dataset_name)

Image transforms

        self.transform = transforms.Compose(transforms_)

Get image paths

        path_a = root / f'{mode}A'
        path_b = root / f'{mode}B'
        self.files_a = sorted(str(f) for f in path_a.iterdir())
        self.files_b = sorted(str(f) for f in path_b.iterdir())

    def __getitem__(self, index):

Return a pair of images. These pairs get batched together, but they do not act as pairs in training, so it is fine that the same images keep getting paired up.

        return {"x": self.transform(load_image(self.files_a[index % len(self.files_a)])),
                "y": self.transform(load_image(self.files_b[index % len(self.files_b)]))}

    def __len__(self):

Number of images in the dataset

        return max(len(self.files_a), len(self.files_b))
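A minimal usage sketch (the dataset is downloaded on first use): the x and y images of an item are independent samples from the two image sets.

transforms_ = [transforms.Resize((256, 256)), transforms.ToTensor()]
dataset = ImageDataset('monet2photo', transforms_, 'train')  # downloads on first use
loader = DataLoader(dataset, batch_size=4, shuffle=True)
batch = next(iter(loader))
print(batch['x'].shape, batch['y'].shape)  # torch.Size([4, 3, 256, 256]) each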

Replay Buffer

Replay buffer is used to train the discriminator. Generated images are added to the replay buffer and sampled from it.

The replay buffer returns the newly added image with a probability of 0.5. Otherwise, it returns an older generated image and replaces that older image with the newly generated one.

This is done to reduce model oscillation.

class ReplayBuffer:
    def __init__(self, max_size: int = 50):
        self.max_size = max_size
        self.data = []

Add/retrieve an image

    def push_and_pop(self, data: torch.Tensor):
        data = data.detach()
        res = []
        for element in data:
            if len(self.data) < self.max_size:
                self.data.append(element)
                res.append(element)
            else:
                if random.uniform(0, 1) > 0.5:
                    i = random.randint(0, self.max_size - 1)
                    res.append(self.data[i].clone())
                    self.data[i] = element
                else:
                    res.append(element)
        return torch.stack(res)
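A small sketch showing the buffer in action; with a tiny max_size you can see older generated images being returned once the buffer is full:

buffer = ReplayBuffer(max_size=2)
for step in range(6):
    fake = torch.full((1, 1, 2, 2), float(step))  # stand-in for a batch of generated images
    out = buffer.push_and_pop(fake)
    # once the buffer fills up, `out` is sometimes an older image
    print(step, out.flatten()[0].item())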

Configurations

class Configs(BaseConfigs):

DeviceConfigs will pick a GPU if available

    device: torch.device = DeviceConfigs()

Hyper-parameters

    epochs: int = 200
    dataset_name: str = 'monet2photo'
    batch_size: int = 1

    data_loader_workers = 8

    learning_rate = 0.0002
    adam_betas = (0.5, 0.999)
    decay_start = 100

The paper suggests using a least-squares loss instead of negative log-likelihood, as it was found to be more stable.

    gan_loss = torch.nn.MSELoss()

L1 loss is used for cycle loss and identity loss

    cycle_loss = torch.nn.L1Loss()
    identity_loss = torch.nn.L1Loss()

Image dimensions

    img_height = 256
    img_width = 256
    img_channels = 3

Number of residual blocks in the generator

    n_residual_blocks = 9

Loss coefficients

    cyclic_loss_coefficient = 10.0
    identity_loss_coefficient = 5.

    sample_interval = 500

Models

    generator_xy: GeneratorResNet
    generator_yx: GeneratorResNet
    discriminator_x: Discriminator
    discriminator_y: Discriminator

Optimizers

    generator_optimizer: torch.optim.Adam
    discriminator_optimizer: torch.optim.Adam

Learning rate schedules

    generator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR
    discriminator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR

Data loaders

    dataloader: DataLoader
    valid_dataloader: DataLoader

Generate samples from test set and save them

    def sample_images(self, n: int):
        batch = next(iter(self.valid_dataloader))
        self.generator_xy.eval()
        self.generator_yx.eval()
        with torch.no_grad():
            data_x, data_y = batch['x'].to(self.device), batch['y'].to(self.device)
            gen_y = self.generator_xy(data_x)
            gen_x = self.generator_yx(data_y)

Arrange images along x-axis

            data_x = make_grid(data_x, nrow=5, normalize=True)
            data_y = make_grid(data_y, nrow=5, normalize=True)
            gen_x = make_grid(gen_x, nrow=5, normalize=True)
            gen_y = make_grid(gen_y, nrow=5, normalize=True)

Arrange images along y-axis

            image_grid = torch.cat((data_x, gen_y, data_y, gen_x), 1)

Show samples

        plot_image(image_grid)

Initialize models and data loaders

    def initialize(self):
        input_shape = (self.img_channels, self.img_height, self.img_width)

Create the models

        self.generator_xy = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
        self.generator_yx = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
        self.discriminator_x = Discriminator(input_shape).to(self.device)
        self.discriminator_y = Discriminator(input_shape).to(self.device)

Create the optimizers

        self.generator_optimizer = torch.optim.Adam(
            itertools.chain(self.generator_xy.parameters(), self.generator_yx.parameters()),
            lr=self.learning_rate, betas=self.adam_betas)
        self.discriminator_optimizer = torch.optim.Adam(
            itertools.chain(self.discriminator_x.parameters(), self.discriminator_y.parameters()),
            lr=self.learning_rate, betas=self.adam_betas)

Create the learning rate schedules. The learning rate stays flat until decay_start epochs, and then decays linearly to 0 at the end of training.

        decay_epochs = self.epochs - self.decay_start
        self.generator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.generator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
        self.discriminator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.discriminator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
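For example, with the defaults epochs = 200 and decay_start = 100, the multiplier computed by the lambda behaves like this (a worked sketch):

epochs, decay_start = 200, 100
decay_epochs = epochs - decay_start
lr_lambda = lambda e: 1.0 - max(0, e - decay_start) / decay_epochs
print([lr_lambda(e) for e in (0, 50, 100, 150, 200)])  # [1.0, 1.0, 1.0, 0.5, 0.0]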

Image transformations

        transforms_ = [
            transforms.Resize(int(self.img_height * 1.12), InterpolationMode.BICUBIC),
            transforms.RandomCrop((self.img_height, self.img_width)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]

Training data loader

        self.dataloader = DataLoader(
            ImageDataset(self.dataset_name, transforms_, 'train'),
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.data_loader_workers,
        )

Validation data loader

        self.valid_dataloader = DataLoader(
            ImageDataset(self.dataset_name, transforms_, "test"),
            batch_size=5,
            shuffle=True,
            num_workers=self.data_loader_workers,
        )

Training

We aim to solve:

$$G^*, F^* = \arg \min_{G, F} \max_{D_X, D_Y} \mathcal{L}(G, F, D_X, D_Y)$$

where $G$ translates images from $X$ to $Y$, $F$ translates images from $Y$ to $X$, $D_X$ tests if images are from $X$ space, $D_Y$ tests if images are from $Y$ space, and

$$\mathcal{L}(G, F, D_X, D_Y) = \mathcal{L}_{GAN}(G, D_Y, X, Y) + \mathcal{L}_{GAN}(F, D_X, Y, X) + \lambda_{cyc} \mathcal{L}_{cyc}(G, F) + \lambda_{id} \mathcal{L}_{id}(G, F)$$

$\mathcal{L}_{GAN}$ is the generative adversarial loss from the original GAN paper:

$$\mathcal{L}_{GAN}(G, D_Y, X, Y) = \mathbb{E}_{y \sim p_{data}(y)} \big[\log D_Y(y)\big] + \mathbb{E}_{x \sim p_{data}(x)} \big[\log \big(1 - D_Y(G(x))\big)\big]$$

$\mathcal{L}_{cyc}$ is the cyclic loss, where we try to get $F(G(x))$ to be similar to $x$, and $G(F(y))$ to be similar to $y$:

$$\mathcal{L}_{cyc}(G, F) = \mathbb{E}_{x \sim p_{data}(x)} \Big[\lVert F(G(x)) - x \rVert_1\Big] + \mathbb{E}_{y \sim p_{data}(y)} \Big[\lVert G(F(y)) - y \rVert_1\Big]$$

Basically, if the two generators (transformations) are applied in series, we should get back the original image. This is the main contribution of this paper. It trains the generators to generate an image of the other distribution that is similar to the original image. Without this loss, $G(x)$ could generate anything from the distribution of $Y$. Now it needs to generate something from the distribution of $Y$ that still has properties of $x$, so that $F(G(x))$ can re-generate something like $x$.

$\mathcal{L}_{id}$ is the identity loss. It was used to encourage the mapping to preserve color composition between the input and the output:

$$\mathcal{L}_{id}(G, F) = \mathbb{E}_{y \sim p_{data}(y)} \Big[\lVert G(y) - y \rVert_1\Big] + \mathbb{E}_{x \sim p_{data}(x)} \Big[\lVert F(x) - x \rVert_1\Big]$$

To solve this min-max problem, the discriminators $D_X$ and $D_Y$ should ascend on the gradient,

$$\nabla_{\theta_{D_X, D_Y}} \frac{1}{m} \sum_{i=1}^m \Big[\log D_Y(y^{(i)}) + \log \big(1 - D_Y(G(x^{(i)}))\big) + \log D_X(x^{(i)}) + \log \big(1 - D_X(F(y^{(i)}))\big)\Big]$$

That is, descend on the negative log-likelihood loss.

In order to stabilize the training, the negative log-likelihood objective was replaced by a least-squared loss - the least-squared error of the discriminator, labelling real images with 1 and generated images with 0. So we want to descend on the gradient,

$$\nabla_{\theta_{D_X, D_Y}} \frac{1}{m} \sum_{i=1}^m \Big[\big(D_Y(y^{(i)}) - 1\big)^2 + D_Y\big(G(x^{(i)})\big)^2 + \big(D_X(x^{(i)}) - 1\big)^2 + D_X\big(F(y^{(i)})\big)^2\Big]$$

We use least-squares for the generators also. The generators should descend on the gradient,

$$\nabla_{\theta_{G, F}} \frac{1}{m} \sum_{i=1}^m \Big[\big(D_Y(G(x^{(i)})) - 1\big)^2 + \big(D_X(F(y^{(i)})) - 1\big)^2 + \lambda_{cyc} \big(\lVert F(G(x^{(i)})) - x^{(i)} \rVert_1 + \lVert G(F(y^{(i)})) - y^{(i)} \rVert_1\big) + \lambda_{id} \big(\lVert G(y^{(i)}) - y^{(i)} \rVert_1 + \lVert F(x^{(i)}) - x^{(i)} \rVert_1\big)\Big]$$

We use generator_xy for $G$ and generator_yx for $F$. We use discriminator_x for $D_X$ and discriminator_y for $D_Y$.
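Note that MSELoss against a tensor of ones (or zeros) computes exactly these least-squared errors, which is why the code below only needs label tensors. A small numeric sketch:

mse = torch.nn.MSELoss()
d_real = torch.tensor([0.9, 0.8])  # discriminator outputs on real images
d_fake = torch.tensor([0.2, 0.1])  # discriminator outputs on generated images
loss_d = mse(d_real, torch.ones(2)) + mse(d_fake, torch.zeros(2))  # the discriminator descends on this
loss_g = mse(d_fake, torch.ones(2))  # the generator wants D(generated) to move towards 1
print(loss_d.item(), loss_g.item())  # 0.05, 0.725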

    def run(self):

Replay buffers to keep generated samples

        gen_x_buffer = ReplayBuffer()
        gen_y_buffer = ReplayBuffer()

Loop through epochs

        for epoch in monit.loop(self.epochs):

Loop through the dataset

            for i, batch in monit.enum('Train', self.dataloader):

Move images to the device

                data_x, data_y = batch['x'].to(self.device), batch['y'].to(self.device)

True labels are equal to 1

                true_labels = torch.ones(data_x.size(0), *self.discriminator_x.output_shape,
                                         device=self.device, requires_grad=False)

False labels are equal to 0

                false_labels = torch.zeros(data_x.size(0), *self.discriminator_x.output_shape,
                                           device=self.device, requires_grad=False)

Train the generators. This returns the generated images.

                gen_x, gen_y = self.optimize_generators(data_x, data_y, true_labels)

Train discriminators

                self.optimize_discriminator(data_x, data_y,
                                            gen_x_buffer.push_and_pop(gen_x), gen_y_buffer.push_and_pop(gen_y),
                                            true_labels, false_labels)

Save training statistics and increment the global step counter

                tracker.save()
                tracker.add_global_step(max(len(data_x), len(data_y)))

Save images at intervals

                batches_done = epoch * len(self.dataloader) + i
                if batches_done % self.sample_interval == 0:

Sample images

                    self.sample_images(batches_done)

Update learning rates

            self.generator_lr_scheduler.step()
            self.discriminator_lr_scheduler.step()

Start a new line in the tracker output

            tracker.new_line()

Optimize the generators with identity, GAN, and cycle losses.

    def optimize_generators(self, data_x: torch.Tensor, data_y: torch.Tensor, true_labels: torch.Tensor):

Change to training mode

        self.generator_xy.train()
        self.generator_yx.train()

Identity loss: $\lVert F(x) - x \rVert_1$ and $\lVert G(y) - y \rVert_1$

        loss_identity = (self.identity_loss(self.generator_yx(data_x), data_x) +
                         self.identity_loss(self.generator_xy(data_y), data_y))

Generate images $G(x)$ and $F(y)$

        gen_y = self.generator_xy(data_x)
        gen_x = self.generator_yx(data_y)

GAN loss: $\big(D_Y(G(x)) - 1\big)^2$ and $\big(D_X(F(y)) - 1\big)^2$

        loss_gan = (self.gan_loss(self.discriminator_y(gen_y), true_labels) +
                    self.gan_loss(self.discriminator_x(gen_x), true_labels))

Cycle loss: $\lVert F(G(x)) - x \rVert_1$ and $\lVert G(F(y)) - y \rVert_1$

        loss_cycle = (self.cycle_loss(self.generator_yx(gen_y), data_x) +
                      self.cycle_loss(self.generator_xy(gen_x), data_y))

Total loss

        loss_generator = (loss_gan +
                          self.cyclic_loss_coefficient * loss_cycle +
                          self.identity_loss_coefficient * loss_identity)

Take a step in the optimizer

        self.generator_optimizer.zero_grad()
        loss_generator.backward()
        self.generator_optimizer.step()

Log losses

        tracker.add({'loss.generator': loss_generator,
                     'loss.generator.cycle': loss_cycle,
                     'loss.generator.gan': loss_gan,
                     'loss.generator.identity': loss_identity})

Return generated images

        return gen_x, gen_y

Optimize the discriminators with GAN loss.

    def optimize_discriminator(self, data_x: torch.Tensor, data_y: torch.Tensor,
                               gen_x: torch.Tensor, gen_y: torch.Tensor,
                               true_labels: torch.Tensor, false_labels: torch.Tensor):

GAN loss: $\big(D_X(x) - 1\big)^2 + D_X\big(F(y)\big)^2 + \big(D_Y(y) - 1\big)^2 + D_Y\big(G(x)\big)^2$

        loss_discriminator = (self.gan_loss(self.discriminator_x(data_x), true_labels) +
                              self.gan_loss(self.discriminator_x(gen_x), false_labels) +
                              self.gan_loss(self.discriminator_y(data_y), true_labels) +
                              self.gan_loss(self.discriminator_y(gen_y), false_labels))

Take a step in the optimizer

        self.discriminator_optimizer.zero_grad()
        loss_discriminator.backward()
        self.discriminator_optimizer.step()

Log losses

        tracker.add({'loss.discriminator': loss_discriminator})

Train Cycle GAN

def train():

Create configurations

    conf = Configs()

Create an experiment

    experiment.create(name='cycle_gan')

Calculate configurations. It will calculate conf.run and all other configs required by it.

    experiment.configs(conf, {'dataset_name': 'summer2winter_yosemite'})
    conf.initialize()

Register models for saving and loading. get_modules gives a dictionary of nn.Modules in conf. You can also specify a custom dictionary of models.

    experiment.add_pytorch_models(get_modules(conf))

Start and watch the experiment

    with experiment.start():

Run the training

        conf.run()

Plot an image with matplotlib

def plot_image(img: torch.Tensor):
    from matplotlib import pyplot as plt

Move tensor to CPU

    img = img.cpu()

Get min and max values of the image for normalization

    img_min, img_max = img.min(), img.max()

Scale image values to be 0...1

    img = (img - img_min) / (img_max - img_min + 1e-5)

We have to change the order of dimensions to HWC.

    img = img.permute(1, 2, 0)

Show Image

    plt.imshow(img)

We don't need axes

    plt.axis('off')

Display

    plt.show()

Evaluate trained Cycle GAN

def evaluate():

Set the run UUID from the training run

    trained_run_uuid = 'f73c1164184711eb9190b74249275441'

Create configs object

    conf = Configs()

Create experiment

    experiment.create(name='cycle_gan_inference')

Load the hyper-parameters set for training

    conf_dict = experiment.load_configs(trained_run_uuid)

Calculate configurations. We specify the generators 'generator_xy', 'generator_yx' so that it only loads those and their dependencies. Configs like device and img_channels will be calculated, since these are required by generator_xy and generator_yx.

If you want other parameters like dataset_name you should specify them here. If you specify nothing, all the configurations will be calculated, including data loaders. Calculation of configurations and their dependencies will happen when you call experiment.start.

    experiment.configs(conf, conf_dict)
    conf.initialize()

Register models for saving and loading. get_modules gives a dictionary of nn.Modules in conf. You can also specify a custom dictionary of models.

    experiment.add_pytorch_models(get_modules(conf))

Specify which run to load from. Loading will actually happen when you call experiment.start.

    experiment.load(trained_run_uuid)

Start the experiment

    with experiment.start():

Image transformations

        transforms_ = [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]

Load your own data. Here we try the training set. I was trying with Yosemite photos; they look awesome. You can use conf.dataset_name, if you specified dataset_name as something you wanted to be calculated in the call to experiment.configs.

        dataset = ImageDataset(conf.dataset_name, transforms_, 'train')

Get an image from dataset

        x_image = dataset[10]['x']

Display the image

        plot_image(x_image)

Evaluation mode

        conf.generator_xy.eval()
        conf.generator_yx.eval()

We don't need gradients

        with torch.no_grad():

Add batch dimension and move to the device we use

            data = x_image.unsqueeze(0).to(conf.device)
            generated_y = conf.generator_xy(data)

Display the generated image.

        plot_image(generated_y[0].cpu())


if __name__ == '__main__':
    train()
    evaluate()