This is a PyTorch implementation/tutorial of the paper Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks.
I've taken pieces of code from eriklindernoren/PyTorch-GAN. It is a very good resource if you want to check out other GAN variations too.
Cycle GAN does image-to-image translation. It trains a model to translate an image from a given distribution to another; say, from images of class A to images of class B. A distribution here could be something like a particular artistic style or a kind of scenery. The models do not need paired images between A and B; a set of images from each class is enough. This works very well for changing image styles, lighting, patterns, etc. For example: summer to winter, painting styles to photos, and horses to zebras.
Cycle GAN trains two generator models and two discriminator models. One generator translates images from A to B and the other from B to A. The discriminators test whether the generated images look real.
This file contains the model code as well as the training code. We also have a Google Colab notebook.
import itertools
import random
import zipfile
from typing import Tuple

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import InterpolationMode
from torchvision.utils import make_grid

from labml import lab, tracker, experiment, monit
from labml.configs import BaseConfigs
from labml.utils.download import download_file
from labml.utils.pytorch import get_modules
from labml_helpers.device import DeviceConfigs
from labml_helpers.module import Module
The generator is a residual network.
class GeneratorResNet(Module):
    def __init__(self, input_channels: int, n_residual_blocks: int):
        super().__init__()
This first block runs a convolution and maps the image to a feature map. The output feature map has the same height and width because we have a padding of $3$ with a $7 \times 7$ kernel. Reflection padding is used because it gives better image quality at edges. inplace=True in ReLU saves a little bit of memory.
        out_features = 64
        layers = [
            nn.Conv2d(input_channels, out_features, kernel_size=7, padding=3, padding_mode='reflect'),
            nn.InstanceNorm2d(out_features),
            nn.ReLU(inplace=True),
        ]
        in_features = out_features
We down-sample with two convolutions with a stride of 2
        for _ in range(2):
            out_features *= 2
            layers += [
                nn.Conv2d(in_features, out_features, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features
We take this through n_residual_blocks residual blocks. The residual block module is defined below.
        for _ in range(n_residual_blocks):
            layers += [ResidualBlock(out_features)]
Then the resulting feature map is up-sampled to match the original image height and width.
        for _ in range(2):
            out_features //= 2
            layers += [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(in_features, out_features, kernel_size=3, stride=1, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features
Finally we map the feature map to an RGB image
        layers += [nn.Conv2d(out_features, input_channels, 7, padding=3, padding_mode='reflect'), nn.Tanh()]
Create a sequential module with the layers
        self.layers = nn.Sequential(*layers)
Initialize weights to $\mathcal{N}(0, 0.02)$
        self.apply(weights_init_normal)
    def forward(self, x):
        return self.layers(x)
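As a quick sanity check (not part of the original code), we can verify the claim that the generator preserves the input's height and width; this minimal sketch assumes the GeneratorResNet class defined above:

# Illustrative only: the two down-sampling stages and the two up-sampling
# stages cancel out, so the output shape matches the input shape.
gen = GeneratorResNet(input_channels=3, n_residual_blocks=9)
x = torch.randn(1, 3, 256, 256)
assert gen(x).shape == x.shape  # torch.Size([1, 3, 256, 256])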
This is the residual block, with two convolution layers.
class ResidualBlock(Module):
    def __init__(self, in_features: int):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflect'),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflect'),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor):
        return x + self.block(x)
This is the discriminator.
class Discriminator(Module):
    def __init__(self, input_shape: Tuple[int, int, int]):
        super().__init__()
        channels, height, width = input_shape
The output of the discriminator is a map of probabilities, saying whether each region (patch) of the image is real or generated. Since the four blocks below each halve the height and width, the map is $2^4 = 16$ times smaller than the input.
        self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)

        self.layers = nn.Sequential(
Each of these blocks will shrink the height and width by a factor of 2
            DiscriminatorBlock(channels, 64, normalize=False),
            DiscriminatorBlock(64, 128),
            DiscriminatorBlock(128, 256),
            DiscriminatorBlock(256, 512),
Zero pad on top and left to keep the output height and width the same with the $4 \times 4$ kernel
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, kernel_size=4, padding=1)
        )
Initialize weights to $\mathcal{N}(0, 0.02)$
        self.apply(weights_init_normal)
    def forward(self, img):
        return self.layers(img)
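To make the patch-wise output concrete, here is a small check (my addition, not from the original tutorial) that the discriminator's output matches output_shape, with one logit per region of the input:

# Illustrative only: four stride-2 blocks shrink 256x256 down to 16x16.
disc = Discriminator((3, 256, 256))
out = disc(torch.randn(1, 3, 256, 256))
assert out.shape[1:] == disc.output_shape  # (1, 16, 16)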
This is the discriminator block module. It does a convolution, an optional normalization, and a leaky ReLU.
It shrinks the height and width of the input feature map by half.
class DiscriminatorBlock(Module):
    def __init__(self, in_filters: int, out_filters: int, normalize: bool = True):
        super().__init__()
        layers = [nn.Conv2d(in_filters, out_filters, kernel_size=4, stride=2, padding=1)]
        if normalize:
            layers.append(nn.InstanceNorm2d(out_filters))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        self.layers = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor):
        return self.layers(x)
Initialize convolution layer weights to $\mathcal{N}(0, 0.02)$
def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
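A short sketch (my addition) of what this initializer does when applied to a module tree: apply recurses through sub-modules, and only layers whose class name contains "Conv" get re-initialized:

# Illustrative only: InstanceNorm2d is skipped, since its class name has no
# 'Conv' in it (and it has no affine weights by default anyway).
net = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.InstanceNorm2d(8))
net.apply(weights_init_normal)
print(float(net[0].weight.std()))  # roughly 0.02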
Load an image and convert to RGB if in grey-scale.
def load_image(path: str):
    image = Image.open(path)
    if image.mode != 'RGB':
        # paste() modifies in place and returns None, so paste into a new RGB image
        rgb_image = Image.new("RGB", image.size)
        rgb_image.paste(image)
        image = rgb_image

    return image
class ImageDataset(Dataset):
    @staticmethod
    def download(dataset_name: str):
URL
        url = f'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/{dataset_name}.zip'
Download folder
        root = lab.get_data_path() / 'cycle_gan'
        if not root.exists():
            root.mkdir(parents=True)
Download destination
        archive = root / f'{dataset_name}.zip'
Download file (generally ~100MB)
        download_file(url, archive)
Extract the archive
        with zipfile.ZipFile(archive, 'r') as f:
            f.extractall(root)
dataset_name is the name of the dataset.
transforms_ is the set of image transforms.
mode is either 'train' or 'test'.
    def __init__(self, dataset_name: str, transforms_, mode: str):
Dataset path
        root = lab.get_data_path() / 'cycle_gan' / dataset_name
Download if missing
        if not root.exists():
            self.download(dataset_name)
Image transforms
        self.transform = transforms.Compose(transforms_)
Get image paths
        path_a = root / f'{mode}A'
        path_b = root / f'{mode}B'
        self.files_a = sorted(str(f) for f in path_a.iterdir())
        self.files_b = sorted(str(f) for f in path_b.iterdir())
    def __getitem__(self, index):
Return a pair of images. These pairs get batched together, but they do not act as pairs during training, so it is fine that the same pairings keep recurring.
        return {"x": self.transform(load_image(self.files_a[index % len(self.files_a)])),
                "y": self.transform(load_image(self.files_b[index % len(self.files_b)]))}
    def __len__(self):
Number of images in the dataset. This is the size of the larger set; the smaller set wraps around via the modulo indexing above.
        return max(len(self.files_a), len(self.files_b))
Replay buffer is used to train the discriminator. Generated images are added to the replay buffer and sampled from it.
The replay buffer returns the newly added image with a probability of $0.5$. Otherwise, it sends an older generated image and replaces the older image with the newly generated image.
This is done to reduce model oscillation.
class ReplayBuffer:
    def __init__(self, max_size: int = 50):
        self.max_size = max_size
        self.data = []

Add/retrieve an image
    def push_and_pop(self, data: torch.Tensor):
        data = data.detach()
        res = []
        for element in data:
            if len(self.data) < self.max_size:
                self.data.append(element)
                res.append(element)
            else:
                if random.uniform(0, 1) > 0.5:
                    i = random.randint(0, self.max_size - 1)
                    res.append(self.data[i].clone())
                    self.data[i] = element
                else:
                    res.append(element)
        return torch.stack(res)
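A tiny demonstration (not from the original code) of the buffer behaviour: once the buffer is full, each returned sample is either the freshly generated one or a stored older one, with equal probability:

# Illustrative only: fill a 2-slot buffer with zeros, then push ones.
buffer = ReplayBuffer(max_size=2)
buffer.push_and_pop(torch.zeros(2, 1))       # buffer not full: the new samples come back
out = buffer.push_and_pop(torch.ones(2, 1))  # buffer full: new or older samples, p = 0.5
print(out.flatten())  # a random mix of 0s (older samples) and 1s (new samples)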
class Configs(BaseConfigs):
DeviceConfigs will pick a GPU if available
    device: torch.device = DeviceConfigs()
Hyper-parameters
    epochs: int = 200
    dataset_name: str = 'monet2photo'
    batch_size: int = 1

    data_loader_workers = 8

    learning_rate = 0.0002
    adam_betas = (0.5, 0.999)
    decay_start = 100
The paper suggests using a least-squares loss instead of negative log-likelihood, as it is found to be more stable.
    gan_loss = torch.nn.MSELoss()
L1 loss is used for cycle loss and identity loss
    cycle_loss = torch.nn.L1Loss()
    identity_loss = torch.nn.L1Loss()
Image dimensions
    img_height = 256
    img_width = 256
    img_channels = 3
Number of residual blocks in the generator
    n_residual_blocks = 9
Loss coefficients
    cyclic_loss_coefficient = 10.0
    identity_loss_coefficient = 5.0

    sample_interval = 500
Models
    generator_xy: GeneratorResNet
    generator_yx: GeneratorResNet
    discriminator_x: Discriminator
    discriminator_y: Discriminator
Optimizers
    generator_optimizer: torch.optim.Adam
    discriminator_optimizer: torch.optim.Adam
Learning rate schedules
    generator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR
    discriminator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR
Data loaders
    dataloader: DataLoader
    valid_dataloader: DataLoader
Generate samples from the test set and show them
    def sample_images(self, n: int):
        batch = next(iter(self.valid_dataloader))
        self.generator_xy.eval()
        self.generator_yx.eval()
        with torch.no_grad():
            data_x, data_y = batch['x'].to(self.generator_xy.device), batch['y'].to(self.generator_yx.device)
            gen_y = self.generator_xy(data_x)
            gen_x = self.generator_yx(data_y)
Arrange images along x-axis
            data_x = make_grid(data_x, nrow=5, normalize=True)
            data_y = make_grid(data_y, nrow=5, normalize=True)
            gen_x = make_grid(gen_x, nrow=5, normalize=True)
            gen_y = make_grid(gen_y, nrow=5, normalize=True)
Arrange images along y-axis
            image_grid = torch.cat((data_x, gen_y, data_y, gen_x), 1)
Show samples
        plot_image(image_grid)
    def initialize(self):
        input_shape = (self.img_channels, self.img_height, self.img_width)
Create the models
        self.generator_xy = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
        self.generator_yx = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
        self.discriminator_x = Discriminator(input_shape).to(self.device)
        self.discriminator_y = Discriminator(input_shape).to(self.device)
Create the optimizers
        self.generator_optimizer = torch.optim.Adam(
            itertools.chain(self.generator_xy.parameters(), self.generator_yx.parameters()),
            lr=self.learning_rate, betas=self.adam_betas)
        self.discriminator_optimizer = torch.optim.Adam(
            itertools.chain(self.discriminator_x.parameters(), self.discriminator_y.parameters()),
            lr=self.learning_rate, betas=self.adam_betas)
Create the learning rate schedules. The learning rate stays flat until decay_start epochs, and then decays linearly to $0$ at the end of training.
        decay_epochs = self.epochs - self.decay_start
        self.generator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.generator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
        self.discriminator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.discriminator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
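To see the schedule concretely, here is a small sketch (my addition) of the lr_lambda factor with the default settings:

# Illustrative only: the factor stays at 1.0 for the first `decay_start`
# epochs, then decays linearly to 0 by the final epoch.
epochs, decay_start = 200, 100
decay_epochs = epochs - decay_start
factor = lambda e: 1.0 - max(0, e - decay_start) / decay_epochs
print([factor(e) for e in (0, 100, 150, 200)])  # [1.0, 1.0, 0.5, 0.0]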
Image transformations
        transforms_ = [
            transforms.Resize(int(self.img_height * 1.12), InterpolationMode.BICUBIC),
            transforms.RandomCrop((self.img_height, self.img_width)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
Training data loader
        self.dataloader = DataLoader(
            ImageDataset(self.dataset_name, transforms_, 'train'),
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.data_loader_workers,
        )
Validation data loader
        self.valid_dataloader = DataLoader(
            ImageDataset(self.dataset_name, transforms_, "test"),
            batch_size=5,
            shuffle=True,
            num_workers=self.data_loader_workers,
        )
We aim to solve:

$$G^*, F^* = \arg \min_{G, F} \max_{D_X, D_Y} \mathcal{L}(G, F, D_X, D_Y)$$

where $G$ translates images from $X \rightarrow Y$, $F$ translates images from $Y \rightarrow X$, $D_X$ tests if images are from $X$ space, $D_Y$ tests if images are from $Y$ space, and

$$\mathcal{L}(G, F, D_X, D_Y) = \mathcal{L}_{GAN}(G, D_Y, X, Y) + \mathcal{L}_{GAN}(F, D_X, Y, X) + \lambda_1 \mathcal{L}_{cyc}(G, F) + \lambda_2 \mathcal{L}_{identity}(G, F)$$

$\mathcal{L}_{GAN}$ is the generative adversarial loss from the original GAN paper:

$$\mathcal{L}_{GAN}(G, D_Y, X, Y) = \mathbb{E}_{y \sim p_{data}(y)} \big[\log D_Y(y)\big] + \mathbb{E}_{x \sim p_{data}(x)} \big[\log \big(1 - D_Y(G(x))\big)\big]$$

$\mathcal{L}_{cyc}$ is the cyclic loss, where we try to get $F(G(x))$ to be similar to $x$, and $G(F(y))$ to be similar to $y$:

$$\mathcal{L}_{cyc}(G, F) = \mathbb{E}_{x \sim p_{data}(x)} \big[\lVert F(G(x)) - x \rVert_1\big] + \mathbb{E}_{y \sim p_{data}(y)} \big[\lVert G(F(y)) - y \rVert_1\big]$$

Basically, if the two generators (transformations) are applied in series, we should get back the original image. This is the main contribution of this paper. It trains the generators to generate an image of the other distribution that is similar to the original image. Without this loss, $G(x)$ could be anything from the distribution of $Y$. Now it needs to generate something from the distribution of $Y$ that still has the properties of $x$, so that $F$ can re-generate something like $x$.

$\mathcal{L}_{identity}$ is the identity loss. It was used to encourage the mappings to preserve color composition between the input and the output:

$$\mathcal{L}_{identity}(G, F) = \mathbb{E}_{y \sim p_{data}(y)} \big[\lVert G(y) - y \rVert_1\big] + \mathbb{E}_{x \sim p_{data}(x)} \big[\lVert F(x) - x \rVert_1\big]$$

To solve $G^*, F^*$, the discriminators $D_X$ and $D_Y$ should ascend on the gradient,

$$\nabla_{\theta_{D_X, D_Y}} \frac{1}{m} \sum_{i=1}^{m} \Big[\log D_Y\big(y^{(i)}\big) + \log \Big(1 - D_Y\big(G(x^{(i)})\big)\Big) + \log D_X\big(x^{(i)}\big) + \log \Big(1 - D_X\big(F(y^{(i)})\big)\Big)\Big]$$

That is, descend on the negative log-likelihood loss. In order to stabilize the training, the negative log-likelihood objective was replaced by a least-squares loss: the least-squared error of the discriminator, labelling real images with $1$ and generated images with $0$. So we want to descend on the gradient,

$$\nabla_{\theta_{D_X, D_Y}} \frac{1}{m} \sum_{i=1}^{m} \Big[\big(D_Y(y^{(i)}) - 1\big)^2 + D_Y\big(G(x^{(i)})\big)^2 + \big(D_X(x^{(i)}) - 1\big)^2 + D_X\big(F(y^{(i)})\big)^2\Big]$$

We use least-squares for the generators also. The generators should descend on the gradient,

$$\nabla_{\theta_{G, F}} \frac{1}{m} \sum_{i=1}^{m} \Big[\big(D_Y(G(x^{(i)})) - 1\big)^2 + \big(D_X(F(y^{(i)})) - 1\big)^2 + \lambda_1 \big(\lVert F(G(x^{(i)})) - x^{(i)} \rVert_1 + \lVert G(F(y^{(i)})) - y^{(i)} \rVert_1\big) + \lambda_2 \big(\lVert G(y^{(i)}) - y^{(i)} \rVert_1 + \lVert F(x^{(i)}) - x^{(i)} \rVert_1\big)\Big]$$

We use generator_xy for $G$, generator_yx for $F$, discriminator_x for $D_X$ and discriminator_y for $D_Y$.
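As a minimal sketch (not from the original code) of how the least-squares objective maps onto the MSELoss calls used below, labelling real patches with $1$ and generated patches with $0$:

# Illustrative only: stand-in discriminator outputs on real and generated images.
gan_loss = torch.nn.MSELoss()
d_real = torch.rand(1, 1, 16, 16)
d_fake = torch.rand(1, 1, 16, 16)
discriminator_loss = (gan_loss(d_real, torch.ones_like(d_real)) +   # mean of (D(y) - 1)^2
                      gan_loss(d_fake, torch.zeros_like(d_fake)))   # mean of D(G(x))^2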
    def run(self):
Replay buffers to keep generated samples
        gen_x_buffer = ReplayBuffer()
        gen_y_buffer = ReplayBuffer()
Loop through epochs
        for epoch in monit.loop(self.epochs):
Loop through the dataset
            for i, batch in monit.enum('Train', self.dataloader):
Move images to the device
                data_x, data_y = batch['x'].to(self.device), batch['y'].to(self.device)
True labels equal to $1$
                true_labels = torch.ones(data_x.size(0), *self.discriminator_x.output_shape,
                                         device=self.device, requires_grad=False)
False labels equal to $0$
                false_labels = torch.zeros(data_x.size(0), *self.discriminator_x.output_shape,
                                           device=self.device, requires_grad=False)
Train the generators. This returns the generated images.
                gen_x, gen_y = self.optimize_generators(data_x, data_y, true_labels)
Train discriminators
                self.optimize_discriminator(data_x, data_y,
                                            gen_x_buffer.push_and_pop(gen_x), gen_y_buffer.push_and_pop(gen_y),
                                            true_labels, false_labels)
Save training statistics and increment the global step counter
                tracker.save()
                tracker.add_global_step(max(len(data_x), len(data_y)))
Save images at intervals
                batches_done = epoch * len(self.dataloader) + i
                if batches_done % self.sample_interval == 0:
Save models when sampling images
                    experiment.save_checkpoint()
Sample images
                    self.sample_images(batches_done)
Update learning rates
            self.generator_lr_scheduler.step()
            self.discriminator_lr_scheduler.step()
New line
            tracker.new_line()
    def optimize_generators(self, data_x: torch.Tensor, data_y: torch.Tensor, true_labels: torch.Tensor):
Change to training mode
        self.generator_xy.train()
        self.generator_yx.train()
Identity loss
        loss_identity = (self.identity_loss(self.generator_yx(data_x), data_x) +
                         self.identity_loss(self.generator_xy(data_y), data_y))
Generate images $G(x)$ and $F(y)$
        gen_y = self.generator_xy(data_x)
        gen_x = self.generator_yx(data_y)
GAN loss
        loss_gan = (self.gan_loss(self.discriminator_y(gen_y), true_labels) +
                    self.gan_loss(self.discriminator_x(gen_x), true_labels))
Cycle loss
        loss_cycle = (self.cycle_loss(self.generator_yx(gen_y), data_x) +
                      self.cycle_loss(self.generator_xy(gen_x), data_y))
Total loss
        loss_generator = (loss_gan +
                          self.cyclic_loss_coefficient * loss_cycle +
                          self.identity_loss_coefficient * loss_identity)
Take a step in the optimizer
        self.generator_optimizer.zero_grad()
        loss_generator.backward()
        self.generator_optimizer.step()
Log losses
        tracker.add({'loss.generator': loss_generator,
                     'loss.generator.cycle': loss_cycle,
                     'loss.generator.gan': loss_gan,
                     'loss.generator.identity': loss_identity})
Return generated images
        return gen_x, gen_y
    def optimize_discriminator(self, data_x: torch.Tensor, data_y: torch.Tensor,
                               gen_x: torch.Tensor, gen_y: torch.Tensor,
                               true_labels: torch.Tensor, false_labels: torch.Tensor):
        loss_discriminator = (self.gan_loss(self.discriminator_x(data_x), true_labels) +
                              self.gan_loss(self.discriminator_x(gen_x), false_labels) +
                              self.gan_loss(self.discriminator_y(data_y), true_labels) +
                              self.gan_loss(self.discriminator_y(gen_y), false_labels))
Take a step in the optimizer
        self.discriminator_optimizer.zero_grad()
        loss_discriminator.backward()
        self.discriminator_optimizer.step()
Log losses
        tracker.add({'loss.discriminator': loss_discriminator})
def train():
Create configurations
    conf = Configs()
Create an experiment
    experiment.create(name='cycle_gan')
Calculate configurations. It will calculate conf.run and all other configs required by it.
    experiment.configs(conf, {'dataset_name': 'summer2winter_yosemite'})
    conf.initialize()
Register models for saving and loading. get_modules gives a dictionary of nn.Modules in conf. You can also specify a custom dictionary of models.
    experiment.add_pytorch_models(get_modules(conf))
Start and watch the experiment
    with experiment.start():
Run the training
        conf.run()
def plot_image(img: torch.Tensor):
    from matplotlib import pyplot as plt

Move tensor to CPU
    img = img.cpu()
Get min and max values of the image for normalization
    img_min, img_max = img.min(), img.max()
Scale image values to be in [0, 1]
    img = (img - img_min) / (img_max - img_min + 1e-5)
We have to change the order of dimensions to HWC.
    img = img.permute(1, 2, 0)
Show Image
    plt.imshow(img)
We don't need axes
    plt.axis('off')
Display
    plt.show()
def evaluate():
Set the run UUID from the training run
    trained_run_uuid = 'f73c1164184711eb9190b74249275441'
Create configs object
    conf = Configs()
Create experiment
    experiment.create(name='cycle_gan_inference')
Load hyper-parameters set for training
    conf_dict = experiment.load_configs(trained_run_uuid)
Calculate configurations. We specify the generators 'generator_xy', 'generator_yx' so that it only loads those and their dependencies. Configs like device and img_channels will be calculated, since these are required by generator_xy and generator_yx. If you want other parameters like dataset_name, you should specify them here. If you specify nothing, all the configurations will be calculated, including the data loaders. Calculation of configurations and their dependencies will happen when you call experiment.start.
    experiment.configs(conf, conf_dict)
    conf.initialize()
Register models for saving and loading. get_modules gives a dictionary of nn.Modules in conf. You can also specify a custom dictionary of models.
    experiment.add_pytorch_models(get_modules(conf))
Specify which run to load from. Loading will actually happen when you call experiment.start.
    experiment.load(trained_run_uuid)
Start the experiment
    with experiment.start():
Image transformations
        transforms_ = [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
Load your own data. Here we try the training set. I was trying with Yosemite photos; they look awesome. You can use conf.dataset_name, if you specified dataset_name as something you want to be calculated in the call to experiment.configs.
        dataset = ImageDataset(conf.dataset_name, transforms_, 'train')
Get an image from dataset
        x_image = dataset[10]['x']
Display the image
        plot_image(x_image)
Evaluation mode
        conf.generator_xy.eval()
        conf.generator_yx.eval()
We don't need gradients
        with torch.no_grad():
Add batch dimension and move to the device we use
            data = x_image.unsqueeze(0).to(conf.device)
            generated_y = conf.generator_xy(data)
Display the generated image.
            plot_image(generated_y[0].cpu())
if __name__ == '__main__':
    train()
    # evaluate()