This is a PyTorch implementation/tutorial of the paper Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks.
I've taken pieces of code from eriklindernoren/PyTorch-GAN. It is a very good resource if you want to check out other GAN variations too.
CycleGAN does image-to-image translation. It trains models to translate images from one distribution to another, say, between images of class A and images of class B. Images of a given distribution could share a style, a subject, or other visual traits. The models do not need paired images between A and B; a set of images from each class is enough. This works very well for changing image styles, lighting, patterns, etc. For example, changing summer to winter, painting styles to photos, and horses to zebras.
CycleGAN trains two generator models and two discriminator models. One generator translates images from A to B and the other from B to A. The discriminators test whether the generated images look real.
This file contains the model code as well as the training code. We also have a Google Colab notebook.
import itertools
import random
import zipfile
from typing import Tuple

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import InterpolationMode
from torchvision.utils import make_grid

from labml import lab, tracker, experiment, monit
from labml.configs import BaseConfigs
from labml.utils.download import download_file
from labml.utils.pytorch import get_modules
from labml_nn.helpers.device import DeviceConfigs
The generator is a residual network.
class GeneratorResNet(nn.Module):
    def __init__(self, input_channels: int, n_residual_blocks: int):
        super().__init__()
This first block runs a convolution and maps the image to a feature map. The output feature map has the same height and width because we have a padding of 3, half of the 7x7 kernel size. Reflection padding is used because it gives better image quality at edges.
inplace=True in ReLU saves a little bit of memory.
        out_features = 64
        layers = [
            nn.Conv2d(input_channels, out_features, kernel_size=7, padding=3, padding_mode='reflect'),
            nn.InstanceNorm2d(out_features),
            nn.ReLU(inplace=True),
        ]
        in_features = out_features
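As a quick shape check: with kernel size $k = 7$ and padding $p = 3$, the convolution preserves the spatial size,

$$H_{out} = H + 2p - k + 1 = H + 6 - 7 + 1 = H$$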
We down-sample with two convolutions with a stride of 2.
        for _ in range(2):
            out_features *= 2
            layers += [
                nn.Conv2d(in_features, out_features, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features
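Each of these stride-2 convolutions (with $k = 3$, $p = 1$) halves the spatial size for even $H$,

$$H_{out} = \left\lfloor \frac{H + 2p - k}{2} \right\rfloor + 1 = \left\lfloor \frac{H - 1}{2} \right\rfloor + 1 = \frac{H}{2}$$

so two of them take a $256 \times 256$ feature map down to $64 \times 64$ while the channels go $64 \to 128 \to 256$.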
We take this through n_residual_blocks residual blocks. The ResidualBlock module is defined below.
        for _ in range(n_residual_blocks):
            layers += [ResidualBlock(out_features)]
Then the resulting feature map is up-sampled to match the original image height and width.
        for _ in range(2):
            out_features //= 2
            layers += [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(in_features, out_features, kernel_size=3, stride=1, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features
Finally we map the feature map to an RGB image
        layers += [nn.Conv2d(out_features, input_channels, 7, padding=3, padding_mode='reflect'), nn.Tanh()]
Create a sequential module with the layers
        self.layers = nn.Sequential(*layers)
Initialize weights to $\mathcal{N}(0, 0.02)$
        self.apply(weights_init_normal)
    def forward(self, x):
        return self.layers(x)
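For example, a minimal shape check of the generator (a sketch, assuming the classes in this file are defined):

import torch

# A 9-block generator maps a 256x256 RGB image back to a 256x256 RGB image:
# down-sample by 4x, apply the residual blocks, then up-sample by 4x
generator = GeneratorResNet(input_channels=3, n_residual_blocks=9)
image = torch.randn(1, 3, 256, 256)
assert generator(image).shape == (1, 3, 256, 256)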
This is the residual block, with two convolution layers.
class ResidualBlock(nn.Module):
    def __init__(self, in_features: int):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflect'),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflect'),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
        )
    def forward(self, x: torch.Tensor):
        return x + self.block(x)
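Since the output is x + block(x), the block preserves both the channel count and the spatial size; a quick sketch:

import torch

block = ResidualBlock(in_features=256)
x = torch.randn(1, 256, 64, 64)
# The skip connection requires the output shape to match the input shape
assert block(x).shape == x.shape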
This is the discriminator.
class Discriminator(nn.Module):
    def __init__(self, input_shape: Tuple[int, int, int]):
        super().__init__()
        channels, height, width = input_shape
The output of the discriminator is a map of probabilities, where each value indicates whether the corresponding region of the image is real or generated.
        self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)

        self.layers = nn.Sequential(
Each of these blocks will shrink the height and width by a factor of 2
            DiscriminatorBlock(channels, 64, normalize=False),
            DiscriminatorBlock(64, 128),
            DiscriminatorBlock(128, 256),
            DiscriminatorBlock(256, 512),
Zero-pad the top and left so that the output height and width stay the same after the final 4x4 convolution
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, kernel_size=4, padding=1)
        )
Initialize weights to $\mathcal{N}(0, 0.02)$
        self.apply(weights_init_normal)
    def forward(self, img):
        return self.layers(img)
This is the discriminator block module. It does a convolution, an optional normalization, and a leaky ReLU.
It shrinks the height and width of the input feature map by half.
class DiscriminatorBlock(nn.Module):
    def __init__(self, in_filters: int, out_filters: int, normalize: bool = True):
        super().__init__()
        layers = [nn.Conv2d(in_filters, out_filters, kernel_size=4, stride=2, padding=1)]
        if normalize:
            layers.append(nn.InstanceNorm2d(out_filters))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        self.layers = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor):
        return self.layers(x)
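This makes the discriminator a PatchGAN: each value in the output map classifies a patch of the input image. A quick shape sketch (assuming the full module is loaded):

import torch

discriminator = Discriminator(input_shape=(3, 256, 256))
# Four stride-2 blocks shrink 256 to 16; the asymmetric zero padding plus the
# final 4x4 convolution keep the 16x16 map
assert discriminator.output_shape == (1, 16, 16)
assert discriminator(torch.randn(1, 3, 256, 256)).shape == (1, 1, 16, 16)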
Initialize convolution layer weights to $\mathcal{N}(0, 0.02)$
def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
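nn.Module.apply calls this function on every sub-module recursively, so a single call initializes all convolution layers; for example (a sketch):

import torch.nn as nn

net = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.ReLU())
# Re-initializes the Conv2d weights from N(0, 0.02); other modules are untouched
net.apply(weights_init_normal)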
Load an image and convert it to RGB if it is grey-scale.
def load_image(path: str):
    image = Image.open(path)
    if image.mode != 'RGB':
        image = image.convert('RGB')

    return image
class ImageDataset(Dataset):
    @staticmethod
    def download(dataset_name: str):
URL
        url = f'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/{dataset_name}.zip'
Download folder
        root = lab.get_data_path() / 'cycle_gan'
        if not root.exists():
            root.mkdir(parents=True)
Download destination
        archive = root / f'{dataset_name}.zip'
Download file (generally ~100MB)
        download_file(url, archive)
Extract the archive
        with zipfile.ZipFile(archive, 'r') as f:
            f.extractall(root)
dataset_name is the name of the dataset, transforms_ is the set of image transforms, and mode is either 'train' or 'test'.
    def __init__(self, dataset_name: str, transforms_, mode: str):
Dataset path
        root = lab.get_data_path() / 'cycle_gan' / dataset_name
Download if missing
        if not root.exists():
            self.download(dataset_name)
Image transforms
        self.transform = transforms.Compose(transforms_)
Get image paths
        path_a = root / f'{mode}A'
        path_b = root / f'{mode}B'
        self.files_a = sorted(str(f) for f in path_a.iterdir())
        self.files_b = sorted(str(f) for f in path_b.iterdir())
    def __getitem__(self, index):
Return a pair of images. These pairs get batched together, but they do not act as pairs in training, so it is fine that the same images always get paired together.
        return {"x": self.transform(load_image(self.files_a[index % len(self.files_a)])),
                "y": self.transform(load_image(self.files_b[index % len(self.files_b)]))}
    def __len__(self):
Number of images in the dataset
        return max(len(self.files_a), len(self.files_b))
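A hypothetical usage sketch (this downloads the dataset on first use):

import torchvision.transforms as transforms

# 'x' comes from trainA and 'y' from trainB; the pairing is arbitrary
dataset = ImageDataset('monet2photo', [transforms.ToTensor()], 'train')
sample = dataset[0]
print(sample['x'].shape, sample['y'].shape)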
Replay buffer is used to train the discriminator. Generated images are added to the replay buffer and sampled from it.
The replay buffer returns the newly added image with a probability of 0.5. Otherwise, it returns an older generated image and replaces that older image with the newly generated one.
This is done to reduce model oscillation.
class ReplayBuffer:
    def __init__(self, max_size: int = 50):
        self.max_size = max_size
        self.data = []
Add/retrieve an image
    def push_and_pop(self, data: torch.Tensor):
        data = data.detach()
        res = []
        for element in data:
            if len(self.data) < self.max_size:
                self.data.append(element)
                res.append(element)
            else:
                if random.uniform(0, 1) > 0.5:
                    i = random.randint(0, self.max_size - 1)
                    res.append(self.data[i].clone())
                    self.data[i] = element
                else:
                    res.append(element)
        return torch.stack(res)
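A usage sketch: the returned batch always has the same shape as the input, but once the buffer is full each image is swapped with an older generated image with probability 0.5.

import torch

buffer = ReplayBuffer(max_size=50)
fake = torch.randn(1, 3, 256, 256)
# Same shape out; roughly half the returned images come from earlier iterations
mixed = buffer.push_and_pop(fake)
assert mixed.shape == fake.shape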
class Configs(BaseConfigs):
DeviceConfigs will pick a GPU if one is available.
    device: torch.device = DeviceConfigs()
Hyper-parameters
    epochs: int = 200
    dataset_name: str = 'monet2photo'
    batch_size: int = 1

    data_loader_workers = 8

    learning_rate = 0.0002
    adam_betas = (0.5, 0.999)
    decay_start = 100
The paper suggests using a least-squares loss instead of negative log-likelihood, as it is found to be more stable.
    gan_loss = torch.nn.MSELoss()
L1 loss is used for cycle loss and identity loss
    cycle_loss = torch.nn.L1Loss()
    identity_loss = torch.nn.L1Loss()
Image dimensions
    img_height = 256
    img_width = 256
    img_channels = 3
Number of residual blocks in the generator
    n_residual_blocks = 9
Loss coefficients
    cyclic_loss_coefficient = 10.0
    identity_loss_coefficient = 5.

    sample_interval = 500
Models
    generator_xy: GeneratorResNet
    generator_yx: GeneratorResNet
    discriminator_x: Discriminator
    discriminator_y: Discriminator
Optimizers
    generator_optimizer: torch.optim.Adam
    discriminator_optimizer: torch.optim.Adam
Learning rate schedules
    generator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR
    discriminator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR
Data loaders
    dataloader: DataLoader
    valid_dataloader: DataLoader
Generate samples from the test set and display them
    def sample_images(self, n: int):
        batch = next(iter(self.valid_dataloader))
        self.generator_xy.eval()
        self.generator_yx.eval()
        with torch.no_grad():
            data_x, data_y = batch['x'].to(self.device), batch['y'].to(self.device)
            gen_y = self.generator_xy(data_x)
            gen_x = self.generator_yx(data_y)
Arrange images along x-axis
            data_x = make_grid(data_x, nrow=5, normalize=True)
            data_y = make_grid(data_y, nrow=5, normalize=True)
            gen_x = make_grid(gen_x, nrow=5, normalize=True)
            gen_y = make_grid(gen_y, nrow=5, normalize=True)
Arrange images along y-axis
            image_grid = torch.cat((data_x, gen_y, data_y, gen_x), 1)
Show samples
            plot_image(image_grid)
    def initialize(self):
        input_shape = (self.img_channels, self.img_height, self.img_width)
Create the models
        self.generator_xy = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
        self.generator_yx = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
        self.discriminator_x = Discriminator(input_shape).to(self.device)
        self.discriminator_y = Discriminator(input_shape).to(self.device)
Create the optimizers
        self.generator_optimizer = torch.optim.Adam(
            itertools.chain(self.generator_xy.parameters(), self.generator_yx.parameters()),
            lr=self.learning_rate, betas=self.adam_betas)
        self.discriminator_optimizer = torch.optim.Adam(
            itertools.chain(self.discriminator_x.parameters(), self.discriminator_y.parameters()),
            lr=self.learning_rate, betas=self.adam_betas)
Create the learning rate schedules. The learning rate stays flat until decay_start epochs, and then linearly decays to 0 at the end of training.
        decay_epochs = self.epochs - self.decay_start
        self.generator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.generator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
        self.discriminator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.discriminator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
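As a worked example of the schedule (with epochs = 200 and decay_start = 100), the multiplier stays at 1.0 for the first 100 epochs and then decays linearly:

decay_start, decay_epochs = 100, 100
factor = lambda e: 1.0 - max(0, e - decay_start) / decay_epochs
assert factor(0) == 1.0    # flat phase
assert factor(100) == 1.0  # decay starts after this epoch
assert factor(150) == 0.5  # halfway through the decay
assert factor(200) == 0.0  # fully decayed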
Image transformations
        transforms_ = [
            transforms.Resize(int(self.img_height * 1.12), InterpolationMode.BICUBIC),
            transforms.RandomCrop((self.img_height, self.img_width)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
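Note that Normalize with mean 0.5 and standard deviation 0.5 maps pixel values from [0, 1] to [-1, 1], matching the Tanh output range of the generator. A quick check:

import torch
import torchvision.transforms as transforms

normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
x = torch.rand(3, 8, 8)  # pixel values in [0, 1]
y = normalize(x)         # values in [-1, 1]
assert y.min() >= -1.0 and y.max() <= 1.0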
Training data loader
        self.dataloader = DataLoader(
            ImageDataset(self.dataset_name, transforms_, 'train'),
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.data_loader_workers,
        )
Validation data loader
        self.valid_dataloader = DataLoader(
            ImageDataset(self.dataset_name, transforms_, "test"),
            batch_size=5,
            shuffle=True,
            num_workers=self.data_loader_workers,
        )
We aim to solve:

$$G^{*}, F^{*} = \arg \min_{G, F} \max_{D_X, D_Y} \mathcal{L}(G, F, D_X, D_Y)$$

where $G$ translates images from $X \rightarrow Y$, $F$ translates images from $Y \rightarrow X$, $D_X$ tests if images are from $X$ space, $D_Y$ tests if images are from $Y$ space, and

$$\mathcal{L}(G, F, D_X, D_Y) = \mathcal{L}_{GAN}(G, D_Y, X, Y) + \mathcal{L}_{GAN}(F, D_X, Y, X) + \lambda_1 \mathcal{L}_{cyc}(G, F) + \lambda_2 \mathcal{L}_{identity}(G, F)$$

$$\mathcal{L}_{GAN}(G, D_Y, X, Y) = \mathbb{E}_{y \sim p_{data}(y)} \big[\log D_Y(y)\big] + \mathbb{E}_{x \sim p_{data}(x)} \big[\log \big(1 - D_Y(G(x))\big)\big]$$

is the generative adversarial loss from the original GAN paper.

$$\mathcal{L}_{cyc}(G, F) = \mathbb{E}_{x \sim p_{data}(x)} \Big[\big\lVert F(G(x)) - x \big\rVert_1\Big] + \mathbb{E}_{y \sim p_{data}(y)} \Big[\big\lVert G(F(y)) - y \big\rVert_1\Big]$$

is the cyclic loss, where we try to get $F(G(x))$ to be similar to $x$, and $G(F(y))$ to be similar to $y$. Basically, if the two generators (transformations) are applied in series, they should give back the original image. This is the main contribution of this paper. It trains the generators to generate an image of the other distribution that is similar to the original image. Without this loss, $G(x)$ could generate anything from the distribution of $Y$. Now it needs to generate something from the distribution of $Y$ that still has the properties of $x$, so that $F(G(x))$ can re-generate something like $x$.

$$\mathcal{L}_{identity}(G, F) = \mathbb{E}_{y \sim p_{data}(y)} \Big[\big\lVert G(y) - y \big\rVert_1\Big] + \mathbb{E}_{x \sim p_{data}(x)} \Big[\big\lVert F(x) - x \big\rVert_1\Big]$$

is the identity loss. This was used to encourage the mappings to preserve the color composition between the input and the output.

To solve $G^{*}, F^{*}$, the discriminators $D_X$ and $D_Y$ should ascend on the gradient,

$$\nabla_{\theta_{D_X, D_Y}} \frac{1}{m} \sum_{i=1}^{m} \Big[\log D_Y\big(y^{(i)}\big) + \log \Big(1 - D_Y\big(G(x^{(i)})\big)\Big) + \log D_X\big(x^{(i)}\big) + \log \Big(1 - D_X\big(F(y^{(i)})\big)\Big)\Big]$$

That is, descend on the negative log-likelihood loss.

In order to stabilize the training, the negative log-likelihood objective was replaced by a least-squares loss - the least-squares error of the discriminator, labelling real images with 1 and generated images with 0. So we want to descend on the gradient,

$$\nabla_{\theta_{D_X, D_Y}} \frac{1}{m} \sum_{i=1}^{m} \Big[\big(D_Y(y^{(i)}) - 1\big)^2 + D_Y\big(G(x^{(i)})\big)^2 + \big(D_X(x^{(i)}) - 1\big)^2 + D_X\big(F(y^{(i)})\big)^2\Big]$$

We use the least-squares loss for the generators also. The generators should descend on the gradient,

$$\nabla_{\theta_{G, F}} \frac{1}{m} \sum_{i=1}^{m} \Big[\big(D_Y(G(x^{(i)})) - 1\big)^2 + \big(D_X(F(y^{(i)})) - 1\big)^2 + \lambda_1 \mathcal{L}_{cyc}(G, F) + \lambda_2 \mathcal{L}_{identity}(G, F)\Big]$$

We use generator_xy for $G$ and generator_yx for $F$. We use discriminator_x for $D_X$ and discriminator_y for $D_Y$.
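Concretely, labelling real patches with 1 and generated patches with 0, these least-squares terms are just MSELoss against constant label maps; a small sketch:

import torch

mse = torch.nn.MSELoss()
d_out = torch.rand(1, 1, 16, 16)  # a hypothetical discriminator output map

real_loss = mse(d_out, torch.ones_like(d_out))   # mean of (D(y) - 1)^2
fake_loss = mse(d_out, torch.zeros_like(d_out))  # mean of D(G(x))^2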
    def run(self):
Replay buffers to keep generated samples
        gen_x_buffer = ReplayBuffer()
        gen_y_buffer = ReplayBuffer()
Loop through epochs
        for epoch in monit.loop(self.epochs):
Loop through the dataset
            for i, batch in monit.enum('Train', self.dataloader):
Move images to the device
                data_x, data_y = batch['x'].to(self.device), batch['y'].to(self.device)
true labels equal to $1$
                true_labels = torch.ones(data_x.size(0), *self.discriminator_x.output_shape,
                                         device=self.device, requires_grad=False)
false labels equal to $0$
                false_labels = torch.zeros(data_x.size(0), *self.discriminator_x.output_shape,
                                           device=self.device, requires_grad=False)
Train the generators. This returns the generated images.
                gen_x, gen_y = self.optimize_generators(data_x, data_y, true_labels)
Train discriminators
                self.optimize_discriminator(data_x, data_y,
                                            gen_x_buffer.push_and_pop(gen_x), gen_y_buffer.push_and_pop(gen_y),
                                            true_labels, false_labels)
Save training statistics and increment the global step counter
                tracker.save()
                tracker.add_global_step(max(len(data_x), len(data_y)))
Save images at intervals
                batches_done = epoch * len(self.dataloader) + i
                if batches_done % self.sample_interval == 0:
Sample images
                    self.sample_images(batches_done)
Update learning rates
            self.generator_lr_scheduler.step()
            self.discriminator_lr_scheduler.step()
New line
            tracker.new_line()
    def optimize_generators(self, data_x: torch.Tensor, data_y: torch.Tensor, true_labels: torch.Tensor):
Change to training mode
        self.generator_xy.train()
        self.generator_yx.train()
Identity loss
        loss_identity = (self.identity_loss(self.generator_yx(data_x), data_x) +
                         self.identity_loss(self.generator_xy(data_y), data_y))
Generate images $G(x)$ and $F(y)$
        gen_y = self.generator_xy(data_x)
        gen_x = self.generator_yx(data_y)
GAN loss
        loss_gan = (self.gan_loss(self.discriminator_y(gen_y), true_labels) +
                    self.gan_loss(self.discriminator_x(gen_x), true_labels))
Cycle loss
        loss_cycle = (self.cycle_loss(self.generator_yx(gen_y), data_x) +
                      self.cycle_loss(self.generator_xy(gen_x), data_y))
Total loss
        loss_generator = (loss_gan +
                          self.cyclic_loss_coefficient * loss_cycle +
                          self.identity_loss_coefficient * loss_identity)
Take a step in the optimizer
        self.generator_optimizer.zero_grad()
        loss_generator.backward()
        self.generator_optimizer.step()
Log losses
        tracker.add({'loss.generator': loss_generator,
                     'loss.generator.cycle': loss_cycle,
                     'loss.generator.gan': loss_gan,
                     'loss.generator.identity': loss_identity})
Return generated images
        return gen_x, gen_y
    def optimize_discriminator(self, data_x: torch.Tensor, data_y: torch.Tensor,
                               gen_x: torch.Tensor, gen_y: torch.Tensor,
                               true_labels: torch.Tensor, false_labels: torch.Tensor):
        loss_discriminator = (self.gan_loss(self.discriminator_x(data_x), true_labels) +
                              self.gan_loss(self.discriminator_x(gen_x), false_labels) +
                              self.gan_loss(self.discriminator_y(data_y), true_labels) +
                              self.gan_loss(self.discriminator_y(gen_y), false_labels))
Take a step in the optimizer
        self.discriminator_optimizer.zero_grad()
        loss_discriminator.backward()
        self.discriminator_optimizer.step()
Log losses
        tracker.add({'loss.discriminator': loss_discriminator})
def train():
Create configurations
    conf = Configs()
Create an experiment
    experiment.create(name='cycle_gan')
Calculate configurations. It will calculate conf.run and all other configs required by it.
    experiment.configs(conf, {'dataset_name': 'summer2winter_yosemite'})
    conf.initialize()
Register models for saving and loading. get_modules gives a dictionary of nn.Modules in conf. You can also specify a custom dictionary of models.
    experiment.add_pytorch_models(get_modules(conf))
Start and watch the experiment
    with experiment.start():
Run the training
        conf.run()
def plot_image(img: torch.Tensor):
    from matplotlib import pyplot as plt
Move tensor to CPU
    img = img.cpu()
Get the min and max values of the image and scale it to [0, 1] for display
    img_min, img_max = img.min(), img.max()
    img = (img - img_min) / (img_max - img_min + 1e-5)
We have to change the order of dimensions to HWC.
    img = img.permute(1, 2, 0)
Show Image
    plt.imshow(img)
We don't need axes
    plt.axis('off')
Display
    plt.show()
def evaluate():
Set the run UUID from the training run
    trained_run_uuid = 'f73c1164184711eb9190b74249275441'
Create configs object
    conf = Configs()
Create experiment
    experiment.create(name='cycle_gan_inference')
Load hyper parameters set for training
    conf_dict = experiment.load_configs(trained_run_uuid)
Calculate configurations. We specify the generators 'generator_xy', 'generator_yx' so that it only loads those and their dependencies. Configs like device and img_channels will be calculated, since these are required by generator_xy and generator_yx.
If you want other parameters like dataset_name, you should specify them here. If you specify nothing, all the configurations will be calculated, including the data loaders. Calculation of configurations and their dependencies will happen when you call experiment.start.
    experiment.configs(conf, conf_dict)
    conf.initialize()
Register models for saving and loading. get_modules gives a dictionary of nn.Modules in conf. You can also specify a custom dictionary of models.
    experiment.add_pytorch_models(get_modules(conf))
Specify which run to load from. Loading will actually happen when you call experiment.start
    experiment.load(trained_run_uuid)
Start the experiment
    with experiment.start():
Image transformations
        transforms_ = [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
Load your own data here. We take images from the training set. I was trying with Yosemite photos; they look awesome. You can use conf.dataset_name, if you specified dataset_name as something you wanted to be calculated in the call to experiment.configs.
        dataset = ImageDataset(conf.dataset_name, transforms_, 'train')
Get an image from dataset
        x_image = dataset[10]['x']
Display the image
        plot_image(x_image)
Evaluation mode
        conf.generator_xy.eval()
        conf.generator_yx.eval()
We don't need gradients
        with torch.no_grad():
Add batch dimension and move to the device we use
            data = x_image.unsqueeze(0).to(conf.device)
            generated_y = conf.generator_xy(data)
Display the generated image.
            plot_image(generated_y[0].cpu())
if __name__ == '__main__':
    train()
    # To run inference instead, set trained_run_uuid in evaluate() to your training run
    # evaluate()