Cycle GAN

This is a PyTorch implementation/tutorial of the paper Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks.

I've taken pieces of code from eriklindernoren/PyTorch-GAN. It is a very good resource if you want to check out other GAN variations too.

Cycle GAN does image-to-image translation. It trains a model to translate an image from a given distribution to another, say, images of class A and B. Images of a certain distribution could be things like images of a certain style, or of nature. The models do not need paired images between A and B; a set of images of each class is enough. This works very well for changing image styles, lighting, and patterns. For example, changing summer to winter, painting style to photos, and horses to zebras.

Cycle GAN trains two generator models and two discriminator models. One generator translates images from A to B and the other from B to A. The discriminators test whether the generated images look real.

This file contains the model code as well as the training code. We also have a Google Colab notebook.


import itertools
import random
import zipfile
from typing import Tuple

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import InterpolationMode
from torchvision.utils import make_grid

from labml import lab, tracker, experiment, monit
from labml.configs import BaseConfigs
from labml.utils.download import download_file
from labml.utils.pytorch import get_modules
from labml_helpers.device import DeviceConfigs
from labml_helpers.module import Module

The generator is a residual network.

class GeneratorResNet(Module):
    def __init__(self, input_channels: int, n_residual_blocks: int):
        super().__init__()

This first block runs a 7×7 convolution and maps the image to a feature map. The output feature map has the same height and width because we have a padding of 3. Reflection padding is used because it gives better image quality at the edges.

inplace=True in ReLU saves a little bit of memory.

        out_features = 64
        layers = [
            nn.Conv2d(input_channels, out_features, kernel_size=7, padding=3, padding_mode='reflect'),
            nn.InstanceNorm2d(out_features),
            nn.ReLU(inplace=True),
        ]
        in_features = out_features

We down-sample with two 3×3 convolutions with a stride of 2.

        for _ in range(2):
            out_features *= 2
            layers += [
                nn.Conv2d(in_features, out_features, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features
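A quick way to check these shapes is the standard convolution output-size formula, floor((n + 2p - k) / s) + 1, which is what nn.Conv2d computes with dilation 1. The sketch below traces the first block and the two down-sampling convolutions, assuming the default 256×256 input; conv_out and generator_encoder_shapes are hypothetical helpers written for illustration, not part of the code above. Reflection padding adds the same number of pixels as zero padding, so the formula does not change.

```python
def conv_out(n: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output height/width of a convolution with dilation 1 (same formula as nn.Conv2d)."""
    return (n + 2 * padding - kernel) // stride + 1


def generator_encoder_shapes(size: int = 256, channels: int = 64):
    """Trace (channels, size) through the first block and the two down-sampling convs."""
    shapes = []
    # 7x7 convolution with padding 3 and stride 1: height and width are unchanged
    size = conv_out(size, kernel=7, stride=1, padding=3)
    shapes.append((channels, size))
    # Two 3x3 convolutions with stride 2 and padding 1: each halves the size, doubles channels
    for _ in range(2):
        channels *= 2
        size = conv_out(size, kernel=3, stride=2, padding=1)
        shapes.append((channels, size))
    return shapes


print(generator_encoder_shapes())  # [(64, 256), (128, 128), (256, 64)]
```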

We pass the feature map through n_residual_blocks residual blocks. The ResidualBlock module is defined below.

        for _ in range(n_residual_blocks):
            layers += [ResidualBlock(out_features)]

Then the resulting feature map is up-sampled to match the original image height and width.

        for _ in range(2):
            out_features //= 2
            layers += [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(in_features, out_features, kernel_size=3, stride=1, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features
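The up-sampling path mirrors the down-sampling one: nn.Upsample(scale_factor=2) doubles the height and width, and the 3×3 convolution with stride 1 and padding 1 keeps them. The hypothetical helpers below trace that path, starting from the 256-channel, 64×64 map that a 256×256 input produces, and check whether an image size survives the full down-then-up round trip.

```python
def decoder_shapes(channels: int = 256, size: int = 64):
    """Trace (channels, size) through the two up-sampling steps of the generator."""
    shapes = []
    for _ in range(2):
        channels //= 2                        # out_features //= 2
        size *= 2                             # nn.Upsample(scale_factor=2)
        size = (size + 2 * 1 - 3) // 1 + 1    # 3x3 conv, stride 1, padding 1: unchanged
        shapes.append((channels, size))
    return shapes


def round_trip(size: int) -> int:
    """Spatial size after the two stride-2 convs and the two up-sampling steps."""
    for _ in range(2):
        size = (size + 2 * 1 - 3) // 2 + 1    # 3x3 conv, stride 2, padding 1
    for _ in range(2):
        size *= 2                             # the stride-1 conv keeps the size
    return size


print(decoder_shapes())                   # [(128, 128), (64, 256)]
print(round_trip(256), round_trip(250))   # 256 252
```

Under these assumptions a 256×256 image comes back as 256×256, but a 250×250 image comes back as 252×252 (250 → 125 → 63 → 126 → 252). The output matches the input only when the height and width are multiples of 4, which matters when feeding un-resized images to a generator.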

Finally we map the feature map to an RGB image

        layers += [nn.Conv2d(out_features, input_channels, 7, padding=3, padding_mode='reflect'), nn.Tanh()]

Create a sequential module with the layers

        self.layers = nn.Sequential(*layers)

Initialize convolution weights to $\mathcal{N}(0, 0.02)$ (mean 0, standard deviation 0.02)

        self.apply(weights_init_normal)

    def forward(self, x):
        return self.layers(x)

This is the residual block, with two convolution layers.

class ResidualBlock(Module):
    def __init__(self, in_features: int):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflect'),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflect'),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor):
        return x + self.block(x)
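The residual blocks hold most of the generator's parameters. As a worked count, assuming the 256 channels they receive after the two down-samplings and nn.InstanceNorm2d with its default affine=False (so it has no learnable parameters), each block has two 3×3 convolutions with biases. conv2d_params is a hypothetical helper.

```python
def conv2d_params(in_ch: int, out_ch: int, kernel: int, bias: bool = True) -> int:
    """Learnable parameters of a square-kernel Conv2d: weights plus optional bias."""
    return in_ch * out_ch * kernel * kernel + (out_ch if bias else 0)


per_block = 2 * conv2d_params(256, 256, 3)   # two 3x3 convs, 256 -> 256 channels
print(per_block)       # 1180160
print(9 * per_block)   # 10621440 for the default n_residual_blocks = 9
```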

This is the discriminator.

class Discriminator(Module):
    def __init__(self, input_shape: Tuple[int, int, int]):
        super().__init__()
        channels, height, width = input_shape

The output of the discriminator is also a map of probabilities: one score per region of the image, indicating whether that region is real or generated.

        self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)

        self.layers = nn.Sequential(

Each of these blocks will shrink the height and width by a factor of 2

            DiscriminatorBlock(channels, 64, normalize=False),
            DiscriminatorBlock(64, 128),
            DiscriminatorBlock(128, 256),
            DiscriminatorBlock(256, 512),

Zero pad on the top and left to keep the output height and width the same with the 4×4 kernel

            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, kernel_size=4, padding=1)
        )

Initialize convolution weights to $\mathcal{N}(0, 0.02)$ (mean 0, standard deviation 0.02)

        self.apply(weights_init_normal)

    def forward(self, img):
        return self.layers(img)
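To see why output_shape is (1, height // 2 ** 4, width // 2 ** 4), the sketch below traces one spatial dimension through the four stride-2 blocks, the one-pixel zero pad on the top and left, and the final 4×4 convolution with padding 1. It uses the same output-size formula as nn.Conv2d; the helper names are hypothetical.

```python
def conv_out(n: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output height/width of a convolution with dilation 1."""
    return (n + 2 * padding - kernel) // stride + 1


def discriminator_output_size(size: int) -> int:
    # Four DiscriminatorBlocks: 4x4 conv, stride 2, padding 1 (each halves the size)
    for _ in range(4):
        size = conv_out(size, kernel=4, stride=2, padding=1)
    # nn.ZeroPad2d((1, 0, 1, 0)) adds one row on top and one column on the left
    size += 1
    # Final 4x4 conv with stride 1 and padding 1 shrinks the padded map by one again
    return conv_out(size, kernel=4, stride=1, padding=1)


print(discriminator_output_size(256))  # 16, i.e. 256 // 2 ** 4
```

Without the asymmetric pad, the final 4×4 convolution with padding 1 would produce a 15×15 map for a 256×256 input instead of 16×16.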

This is the discriminator block module. It does a convolution, an optional normalization, and a leaky ReLU.

It shrinks the height and width of the input feature map by half.

class DiscriminatorBlock(Module):
    def __init__(self, in_filters: int, out_filters: int, normalize: bool = True):
        super().__init__()
        layers = [nn.Conv2d(in_filters, out_filters, kernel_size=4, stride=2, padding=1)]
        if normalize:
            layers.append(nn.InstanceNorm2d(out_filters))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        self.layers = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor):
        return self.layers(x)

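Initialize convolution layer weights to $\mathcal{N}(0, 0.02)$, that is, mean 0 and standard deviation 0.02. Biases keep PyTorch's default initialization.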

def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
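weights_init_normal selects layers by class name, so any module whose class name contains "Conv" is initialized, while normalization, activation and container modules are skipped. The sketch below reproduces only that name test, using empty stand-in classes named like the torch.nn layers so it runs without PyTorch; the stand-ins and would_be_initialized are hypothetical.

```python
# Empty stand-ins named like torch.nn layers, so the name test runs without PyTorch
class Conv2d: ...
class ConvTranspose2d: ...
class InstanceNorm2d: ...
class LeakyReLU: ...
class Sequential: ...


def would_be_initialized(m) -> bool:
    """The same check weights_init_normal performs on each module it is applied to."""
    return m.__class__.__name__.find("Conv") != -1


modules = [Sequential(), Conv2d(), InstanceNorm2d(), LeakyReLU(), ConvTranspose2d()]
selected = [type(m).__name__ for m in modules if would_be_initialized(m)]
print(selected)  # ['Conv2d', 'ConvTranspose2d']
```

Containers such as nn.Sequential, ResidualBlock and DiscriminatorBlock do not match, but self.apply still reaches the convolutions inside them because nn.Module.apply visits every sub-module recursively.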

Load an image and change to RGB if in grey-scale.

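def load_image(path: str):
    image = Image.open(path)
    if image.mode != 'RGB':
        # Image.paste() works in place and returns None, so convert() is used instead;
        # it returns a new RGB copy of grey-scale, palette or RGBA images
        image = image.convert('RGB')

    return image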

Dataset to load images

class ImageDataset(Dataset):

Download dataset and extract data

    @staticmethod
    def download(dataset_name: str):

URL

        url = f'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/{dataset_name}.zip'

Download folder

        root = lab.get_data_path() / 'cycle_gan'
        if not root.exists():
            root.mkdir(parents=True)

Download destination

        archive = root / f'{dataset_name}.zip'

Download file (generally ~100MB)

        download_file(url, archive)

Extract the archive

        with zipfile.ZipFile(archive, 'r') as f:
            f.extractall(root)

Initialize the dataset

  • dataset_name is the name of the dataset
  • transforms_ is the set of image transforms
  • mode is either train or test
    def __init__(self, dataset_name: str, transforms_, mode: str):

Dataset path

        root = lab.get_data_path() / 'cycle_gan' / dataset_name

Download if missing

        if not root.exists():
            self.download(dataset_name)

Image transforms

        self.transform = transforms.Compose(transforms_)

Get image paths

        path_a = root / f'{mode}A'
        path_b = root / f'{mode}B'
        self.files_a = sorted(str(f) for f in path_a.iterdir())
        self.files_b = sorted(str(f) for f in path_b.iterdir())

    def __getitem__(self, index):

Return a pair of images. These pairs get batched together, and they do not act like pairs in training, so it is fine that the same index always yields the same pair.

        return {"x": self.transform(load_image(self.files_a[index % len(self.files_a)])),
                "y": self.transform(load_image(self.files_b[index % len(self.files_b)]))}

    def __len__(self):

Number of images in the dataset

        return max(len(self.files_a), len(self.files_b))
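Because __len__ is the size of the larger set and __getitem__ wraps each index with %, images from the smaller set repeat within an epoch. The hypothetical helper below lists which file indices each dataset index maps to, assuming 3 images in A and 5 in B.

```python
def dataset_pairs(n_a: int, n_b: int):
    """(index into files_a, index into files_b) for every dataset index."""
    length = max(n_a, n_b)                                   # what __len__ returns
    return [(i % n_a, i % n_b) for i in range(length)]       # what __getitem__ reads


print(dataset_pairs(3, 5))  # [(0, 0), (1, 1), (2, 2), (0, 3), (1, 4)]
```

With shuffle=True the DataLoader only changes the order in which these fixed pairs are visited.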

Replay Buffer

The replay buffer is used to train the discriminator. Generated images are added to the replay buffer and sampled from it.

Until the buffer is full (max_size images, 50 by default), every new image is stored and returned as is. After that, the replay buffer returns the newly added image with a probability of 0.5. Otherwise, it returns an older generated image and replaces that older image in the buffer with the newly generated one.

This is done to reduce model oscillation.

class ReplayBuffer:
    def __init__(self, max_size: int = 50):
        self.max_size = max_size
        self.data = []

Add/retrieve an image

    def push_and_pop(self, data: torch.Tensor):
        data = data.detach()
        res = []
        for element in data:
            if len(self.data) < self.max_size:
                self.data.append(element)
                res.append(element)
            else:
                if random.uniform(0, 1) > 0.5:
                    i = random.randint(0, self.max_size - 1)
                    res.append(self.data[i].clone())
                    self.data[i] = element
                else:
                    res.append(element)
        return torch.stack(res)
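The buffer logic is easier to see without tensors. ListReplayBuffer below is a hypothetical, plain-Python copy of push_and_pop that stores integers instead of image tensors and uses a seeded random.Random so its behaviour is reproducible; the integers stand in for generated images, numbered in the order they were produced.

```python
import random
from typing import List


class ListReplayBuffer:
    """Same logic as ReplayBuffer.push_and_pop, on plain Python values instead of tensors."""

    def __init__(self, max_size: int = 50, seed: int = 0):
        self.max_size = max_size
        self.data: List[int] = []
        self.rng = random.Random(seed)

    def push_and_pop(self, batch: List[int]) -> List[int]:
        res = []
        for element in batch:
            if len(self.data) < self.max_size:
                # Buffer not full yet: store the new element and return it unchanged
                self.data.append(element)
                res.append(element)
            elif self.rng.uniform(0, 1) > 0.5:
                # With probability 0.5: return a random stored element and replace it
                i = self.rng.randint(0, self.max_size - 1)
                res.append(self.data[i])
                self.data[i] = element
            else:
                # Otherwise: return the new element and leave the buffer unchanged
                res.append(element)
        return res


buf = ListReplayBuffer(max_size=3)
first = buf.push_and_pop([0, 1, 2])             # fills the buffer: returned unchanged
later = buf.push_and_pop(list(range(3, 100)))   # each item: the new one or an older one
```

Because stored items were always produced earlier, every value in later is either the new item itself or an older one. The real ReplayBuffer additionally detaches the tensors, clones the stored image it hands back, and stacks the results into a batch.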

Configurations

class Configs(BaseConfigs):

DeviceConfigs will pick a GPU if available

    device: torch.device = DeviceConfigs()

Hyper-parameters

    epochs: int = 200
    dataset_name: str = 'monet2photo'
    batch_size: int = 1

    data_loader_workers = 8

    learning_rate = 0.0002
    adam_betas = (0.5, 0.999)
    decay_start = 100

The paper suggests using a least-squares loss instead of negative log-likelihood, as it is found to be more stable.

    gan_loss = torch.nn.MSELoss()

L1 loss is used for cycle loss and identity loss

    cycle_loss = torch.nn.L1Loss()
    identity_loss = torch.nn.L1Loss()

Image dimensions

    img_height = 256
    img_width = 256
    img_channels = 3

Number of residual blocks in the generator

    n_residual_blocks = 9

Loss coefficients

    cyclic_loss_coefficient = 10.0
    identity_loss_coefficient = 5.

    sample_interval = 500

Models

    generator_xy: GeneratorResNet
    generator_yx: GeneratorResNet
    discriminator_x: Discriminator
    discriminator_y: Discriminator

Optimizers

    generator_optimizer: torch.optim.Adam
    discriminator_optimizer: torch.optim.Adam

Learning rate schedules

    generator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR
    discriminator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR

Data loaders

    dataloader: DataLoader
    valid_dataloader: DataLoader

Generate samples from test set and save them

    def sample_images(self, n: int):
        batch = next(iter(self.valid_dataloader))
        self.generator_xy.eval()
        self.generator_yx.eval()
        with torch.no_grad():
            data_x, data_y = batch['x'].to(self.generator_xy.device), batch['y'].to(self.generator_yx.device)
            gen_y = self.generator_xy(data_x)
            gen_x = self.generator_yx(data_y)

Arrange images along x-axis

            data_x = make_grid(data_x, nrow=5, normalize=True)
            data_y = make_grid(data_y, nrow=5, normalize=True)
            gen_x = make_grid(gen_x, nrow=5, normalize=True)
            gen_y = make_grid(gen_y, nrow=5, normalize=True)

Arrange images along y-axis

            image_grid = torch.cat((data_x, gen_y, data_y, gen_x), 1)

Show samples

        plot_image(image_grid)

Initialize models and data loaders

    def initialize(self):
        input_shape = (self.img_channels, self.img_height, self.img_width)

Create the models

        self.generator_xy = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
        self.generator_yx = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
        self.discriminator_x = Discriminator(input_shape).to(self.device)
        self.discriminator_y = Discriminator(input_shape).to(self.device)

Create the optimizers

        self.generator_optimizer = torch.optim.Adam(
            itertools.chain(self.generator_xy.parameters(), self.generator_yx.parameters()),
            lr=self.learning_rate, betas=self.adam_betas)
        self.discriminator_optimizer = torch.optim.Adam(
            itertools.chain(self.discriminator_x.parameters(), self.discriminator_y.parameters()),
            lr=self.learning_rate, betas=self.adam_betas)

Create the learning rate schedules. The learning rate stays flat until decay_start epochs, and then decays linearly to 0 at the end of training.

        decay_epochs = self.epochs - self.decay_start
        self.generator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.generator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
        self.discriminator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.discriminator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
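The lambda returns a multiplier on the initial learning rate: LambdaLR sets the learning rate to learning_rate times this value. The sketch below evaluates it with the default epochs = 200 and decay_start = 100; lr_factor is a hypothetical name for the lambda. Since scheduler.step() is called at the end of each epoch, epoch e trains with lr_factor(e).

```python
epochs, decay_start = 200, 100
decay_epochs = epochs - decay_start


def lr_factor(e: int) -> float:
    """Same expression as the lr_lambda passed to LambdaLR."""
    return 1.0 - max(0, e - decay_start) / decay_epochs


for e in (0, 100, 150, 199, 200):
    print(e, round(lr_factor(e), 4))
```

So the last epoch (e = 199) trains with about 0.01 × 0.0002 = 2e-6, and the factor reaches 0 only after the final step() call.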

Image transformations

        transforms_ = [
            transforms.Resize(int(self.img_height * 1.12), InterpolationMode.BICUBIC),
            transforms.RandomCrop((self.img_height, self.img_width)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
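Two pieces of arithmetic explain these transforms. Resize with a single int scales the shorter side, here to int(256 × 1.12) = 286 pixels, leaving a margin for RandomCrop to pick a 256×256 window. Normalize with mean 0.5 and standard deviation 0.5 maps the [0, 1] range produced by ToTensor to [-1, 1], the same range as the generator's final nn.Tanh. The helpers below are hypothetical scalar versions of those operations.

```python
img_height = 256
resize_to = int(img_height * 1.12)   # shorter side after transforms.Resize
print(resize_to)                     # 286; RandomCrop then takes a 256x256 window


def normalize(v: float, mean: float = 0.5, std: float = 0.5) -> float:
    """What transforms.Normalize does to each channel value."""
    return (v - mean) / std


def denormalize(v: float, mean: float = 0.5, std: float = 0.5) -> float:
    """Inverse mapping, from the [-1, 1] generator range back to [0, 1]."""
    return v * std + mean


print(normalize(0.0), normalize(0.5), normalize(1.0))   # -1.0 0.0 1.0
```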

Training data loader

        self.dataloader = DataLoader(
            ImageDataset(self.dataset_name, transforms_, 'train'),
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.data_loader_workers,
        )

Validation data loader

        self.valid_dataloader = DataLoader(
            ImageDataset(self.dataset_name, transforms_, "test"),
            batch_size=5,
            shuffle=True,
            num_workers=self.data_loader_workers,
        )

Training

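We aim to solve:

$$G^{*}, F^{*} = \arg \min_{G,F} \max_{D_X, D_Y} \mathcal{L}(G, F, D_X, D_Y)$$

where $G$ translates images from $X \rightarrow Y$, $F$ translates images from $Y \rightarrow X$, $D_X$ tests if images are from the $X$ space, $D_Y$ tests if images are from the $Y$ space, and

$$\mathcal{L}(G, F, D_X, D_Y) = \mathcal{L}_{GAN}(G, F, D_X, D_Y) + \lambda_1 \mathcal{L}_{cyc}(G, F) + \lambda_2 \mathcal{L}_{identity}(G, F)$$

with $\lambda_1$ and $\lambda_2$ set by cyclic_loss_coefficient and identity_loss_coefficient in the configurations.

$$\mathcal{L}_{GAN}(G, F, D_X, D_Y) = \mathbb{E}_{y \sim p_{data}(y)}\big[\log D_Y(y)\big] + \mathbb{E}_{x \sim p_{data}(x)}\big[\log\big(1 - D_Y(G(x))\big)\big] + \mathbb{E}_{x \sim p_{data}(x)}\big[\log D_X(x)\big] + \mathbb{E}_{y \sim p_{data}(y)}\big[\log\big(1 - D_X(F(y))\big)\big]$$

is the generative adversarial loss from the original GAN paper.

$$\mathcal{L}_{cyc}(G, F) = \mathbb{E}_{x \sim p_{data}(x)}\big[\lVert F(G(x)) - x \rVert_1\big] + \mathbb{E}_{y \sim p_{data}(y)}\big[\lVert G(F(y)) - y \rVert_1\big]$$

is the cyclic loss, where we try to get $F(G(x))$ to be similar to $x$, and $G(F(y))$ to be similar to $y$. Basically, if the two generators (transformations) are applied in series, they should give back the original image. This is the main contribution of this paper. It trains the generators to generate an image of the other distribution that is similar to the original image. Without this loss, $G(x)$ could generate anything that's from the distribution of $Y$. Now it needs to generate something from the distribution of $Y$ that still has the properties of $x$, so that $F(G(x))$ can re-generate something like $x$.

$$\mathcal{L}_{identity}(G, F) = \mathbb{E}_{x \sim p_{data}(x)}\big[\lVert F(x) - x \rVert_1\big] + \mathbb{E}_{y \sim p_{data}(y)}\big[\lVert G(y) - y \rVert_1\big]$$

is the identity loss. It was used to encourage the mapping to preserve color composition between the input and the output.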
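A toy example of the two L1 terms, using plain Python lists in place of images: here G doubles every value and F halves it, so F is an exact inverse of G. The maps, the numbers, and the l1 helper (mean absolute error, which is what nn.L1Loss computes with its default reduction) are assumptions chosen for illustration.

```python
def l1(a, b):
    """Mean absolute error between two equally long lists."""
    return sum(abs(p - q) for p, q in zip(a, b)) / len(a)


def G(v):  # toy X -> Y map
    return [2.0 * t for t in v]


def F(v):  # toy Y -> X map, the exact inverse of G
    return [0.5 * t for t in v]


x = [0.1, -0.4, 0.8]
y = [0.6, 0.2, -1.0]

cycle = l1(F(G(x)), x) + l1(G(F(y)), y)      # ||F(G(x)) - x|| + ||G(F(y)) - y||
identity = l1(F(x), x) + l1(G(y), y)         # ||F(x) - x|| + ||G(y) - y||
print(cycle, round(identity, 4))             # 0.0 0.8167
```

The cycle loss is zero because F undoes G exactly, yet the identity loss is not: G and F still change images that already belong to their target domain, which is what the identity term penalizes.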

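To solve for $G^{*}, F^{*}$, discriminators $D_X$ and $D_Y$ should ascend on the gradient,

$$\nabla_{\theta_{D_X, D_Y}} \frac{1}{m} \sum_{i=1}^m \Big[\log D_Y\big(y^{(i)}\big) + \log\Big(1 - D_Y\big(G(x^{(i)})\big)\Big) + \log D_X\big(x^{(i)}\big) + \log\Big(1 - D_X\big(F(y^{(i)})\big)\Big)\Big]$$

That is, descend on the negative log-likelihood loss.

To stabilize training, the negative log-likelihood objective was replaced by a least-squares loss: the squared error of the discriminator, labelling real images with 1 and generated images with 0. So we want to descend on the gradient,

$$\nabla_{\theta_{D_X, D_Y}} \frac{1}{m} \sum_{i=1}^m \Big[\big(D_Y(y^{(i)}) - 1\big)^2 + D_Y\big(G(x^{(i)})\big)^2 + \big(D_X(x^{(i)}) - 1\big)^2 + D_X\big(F(y^{(i)})\big)^2\Big]$$

We use least-squares for the generators as well. The generators should descend on the gradient,

$$\nabla_{\theta_{F, G}} \frac{1}{m} \sum_{i=1}^m \Big[\big(D_Y(G(x^{(i)})) - 1\big)^2 + \big(D_X(F(y^{(i)})) - 1\big)^2 + \lambda_1 \mathcal{L}_{cyc}^{(i)} + \lambda_2 \mathcal{L}_{identity}^{(i)}\Big]$$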
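The least-squares objectives reduce to mean squared errors against constant targets, which is how gan_loss = nn.MSELoss() is used with true_labels and false_labels. The sketch below computes one direction (one discriminator and its generator) on a hypothetical two-region score map; in the code, the X and Y terms are summed and the scores come from a patch map of shape (batch, 1, 16, 16).

```python
def mse(scores, target):
    """Mean squared error of every score against a constant label."""
    return sum((s - target) ** 2 for s in scores) / len(scores)


def discriminator_loss(d_real, d_fake):
    # Real regions are labelled 1, generated regions are labelled 0
    return mse(d_real, 1.0) + mse(d_fake, 0.0)


def generator_loss(d_fake):
    # The generator wants the discriminator to output 1 on generated images
    return mse(d_fake, 1.0)


d_real = [0.9, 0.8]   # discriminator scores on real images
d_fake = [0.2, 0.4]   # discriminator scores on generated images
print(round(discriminator_loss(d_real, d_fake), 4))  # 0.125
print(round(generator_loss(d_fake), 4))              # 0.5
```

A perfect discriminator (real scores 1, fake scores 0) gets zero loss while its generator gets a loss of 1; the generator's loss falls only as the discriminator's scores on generated images move towards 1.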

We use generator_xy for $G$ and generator_yx for $F$. We use discriminator_x for $D_X$ and discriminator_y for $D_Y$.

    def run(self):

Replay buffers to keep generated samples

        gen_x_buffer = ReplayBuffer()
        gen_y_buffer = ReplayBuffer()

Loop through epochs

        for epoch in monit.loop(self.epochs):

Loop through the dataset

            for i, batch in monit.enum('Train', self.dataloader):

Move images to the device

                data_x, data_y = batch['x'].to(self.device), batch['y'].to(self.device)

true labels equal to 1

                true_labels = torch.ones(data_x.size(0), *self.discriminator_x.output_shape,
                                         device=self.device, requires_grad=False)

false labels equal to 0

                false_labels = torch.zeros(data_x.size(0), *self.discriminator_x.output_shape,
                                           device=self.device, requires_grad=False)

Train the generators. This returns the generated images.

                gen_x, gen_y = self.optimize_generators(data_x, data_y, true_labels)

Train discriminators

                self.optimize_discriminator(data_x, data_y,
                                            gen_x_buffer.push_and_pop(gen_x), gen_y_buffer.push_and_pop(gen_y),
                                            true_labels, false_labels)

Save training statistics and increment the global step counter

                tracker.save()
                tracker.add_global_step(max(len(data_x), len(data_y)))

Save images at intervals

                batches_done = epoch * len(self.dataloader) + i
                if batches_done % self.sample_interval == 0:

Save models when sampling images

                    experiment.save_checkpoint()

Sample images

                    self.sample_images(batches_done)

Update learning rates

            self.generator_lr_scheduler.step()
            self.discriminator_lr_scheduler.step()

Start a new line in the console output for the next epoch

            tracker.new_line()

Optimize the generators with identity, gan and cycle losses.

    def optimize_generators(self, data_x: torch.Tensor, data_y: torch.Tensor, true_labels: torch.Tensor):

Change to training mode

        self.generator_xy.train()
        self.generator_yx.train()

Identity loss

        loss_identity = (self.identity_loss(self.generator_yx(data_x), data_x) +
                         self.identity_loss(self.generator_xy(data_y), data_y))

Generate images $G(x)$ and $F(y)$

        gen_y = self.generator_xy(data_x)
        gen_x = self.generator_yx(data_y)

GAN loss

        loss_gan = (self.gan_loss(self.discriminator_y(gen_y), true_labels) +
                    self.gan_loss(self.discriminator_x(gen_x), true_labels))

Cycle loss

        loss_cycle = (self.cycle_loss(self.generator_yx(gen_y), data_x) +
                      self.cycle_loss(self.generator_xy(gen_x), data_y))

Total loss

        loss_generator = (loss_gan +
                          self.cyclic_loss_coefficient * loss_cycle +
                          self.identity_loss_coefficient * loss_identity)

Take a step in the optimizer

        self.generator_optimizer.zero_grad()
        loss_generator.backward()
        self.generator_optimizer.step()

Log losses

        tracker.add({'loss.generator': loss_generator,
                     'loss.generator.cycle': loss_cycle,
                     'loss.generator.gan': loss_gan,
                     'loss.generator.identity': loss_identity})

Return generated images

        return gen_x, gen_y

Optimize the discriminators with gan loss.

    def optimize_discriminator(self, data_x: torch.Tensor, data_y: torch.Tensor,
                               gen_x: torch.Tensor, gen_y: torch.Tensor,
                               true_labels: torch.Tensor, false_labels: torch.Tensor):

GAN Loss

        loss_discriminator = (self.gan_loss(self.discriminator_x(data_x), true_labels) +
                              self.gan_loss(self.discriminator_x(gen_x), false_labels) +
                              self.gan_loss(self.discriminator_y(data_y), true_labels) +
                              self.gan_loss(self.discriminator_y(gen_y), false_labels))

Take a step in the optimizer

        self.discriminator_optimizer.zero_grad()
        loss_discriminator.backward()
        self.discriminator_optimizer.step()

Log losses

        tracker.add({'loss.discriminator': loss_discriminator})

Train Cycle GAN

def train():

Create configurations

    conf = Configs()

Create an experiment

    experiment.create(name='cycle_gan')

Calculate configurations. It will calculate conf.run and all other configs required by it.

    experiment.configs(conf, {'dataset_name': 'summer2winter_yosemite'})
    conf.initialize()

Register models for saving and loading. get_modules gives a dictionary of nn.Modules in conf. You can also specify a custom dictionary of models.

    experiment.add_pytorch_models(get_modules(conf))

Start and watch the experiment

    with experiment.start():

Run the training

        conf.run()

Plot an image with matplotlib

def plot_image(img: torch.Tensor):
    from matplotlib import pyplot as plt

Move tensor to CPU

    img = img.cpu()

Get min and max values of the image for normalization

    img_min, img_max = img.min(), img.max()

Scale image values to be 0...1

    img = (img - img_min) / (img_max - img_min + 1e-5)

We have to change the order of dimensions to HWC.

    img = img.permute(1, 2, 0)

Show Image

    plt.imshow(img)

We don't need axes

    plt.axis('off')

Display

    plt.show()
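plot_image rescales with min-max normalization and reorders the dimensions from PyTorch's channels-height-width layout to the height-width-channels layout that matplotlib expects. The hypothetical helpers below do the same on plain lists; the 1e-5 in the denominator avoids a division by zero for a constant image, at the cost of the maximum mapping slightly below 1.

```python
from typing import List


def min_max_scale(values: List[float], eps: float = 1e-5) -> List[float]:
    """Scale values to roughly 0...1, like (img - img_min) / (img_max - img_min + 1e-5)."""
    lo, hi = min(values), max(values)
    return [(v - lo) / (hi - lo + eps) for v in values]


def chw_to_hwc(chw: List[List[List[float]]]) -> List[List[List[float]]]:
    """Same reordering as img.permute(1, 2, 0): result[h][w][c] == chw[c][h][w]."""
    channels, height, width = len(chw), len(chw[0]), len(chw[0][0])
    return [[[chw[c][h][w] for c in range(channels)] for w in range(width)]
            for h in range(height)]


# A 3-channel, 2x2 "image"
chw = [[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]]
print(chw_to_hwc(chw)[0][0])   # [1, 5, 9]: the three channel values of the top-left pixel
```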

Evaluate trained Cycle GAN

def evaluate():

Set the run UUID from the training run

    trained_run_uuid = 'f73c1164184711eb9190b74249275441'

Create configs object

    conf = Configs()

Create experiment

    experiment.create(name='cycle_gan_inference')

Load hyper parameters set for training

    conf_dict = experiment.load_configs(trained_run_uuid)

Calculate configurations using the hyper-parameters from the training run, then create the models with conf.initialize(). Configs like device and img_channels will be calculated, since they are required by generator_xy and generator_yx.

If you want other parameters like dataset_name, you should specify them here. Note that conf.initialize() also creates the discriminators, optimizers, and data loaders, so the dataset is downloaded if it is missing.

    experiment.configs(conf, conf_dict)
    conf.initialize()

Register models for saving and loading. get_modules gives a dictionary of nn.Modules in conf. You can also specify a custom dictionary of models.

    experiment.add_pytorch_models(get_modules(conf))

Specify which run to load from. Loading will actually happen when you call experiment.start

    experiment.load(trained_run_uuid)

Start the experiment

    with experiment.start():

Image transformations

        transforms_ = [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]

Load your own data. Here we try the test set. I was trying with Yosemite photos, and they look awesome. You can use conf.dataset_name if you specified dataset_name as something you wanted to be calculated in the call to experiment.configs.

        dataset = ImageDataset(conf.dataset_name, transforms_, 'test')

Get an image from dataset

        x_image = dataset[10]['x']

Display the image

        plot_image(x_image)

Evaluation mode

        conf.generator_xy.eval()
        conf.generator_yx.eval()

We don't need gradients

        with torch.no_grad():

Add batch dimension and move to the device we use

            data = x_image.unsqueeze(0).to(conf.device)
            generated_y = conf.generator_xy(data)

Display the generated image.

        plot_image(generated_y[0].cpu())


if __name__ == '__main__':
    train()
    # evaluate()