这是 PyTorch 的 PyTorch 实现/教程,该论文使用周期一致性对抗网络进行图像间的非配对转换。
我从 eriklindernoren/pytorch-Gan 那里拿了一些代码。如果你也想查看其他 GAN 变体,这是一个非常好的资源。
Cyc@@le GAN 进行图像到图像的转换。它训练模型将图像从给定分布转换到另一个分布,比如A类和B类的图像,某个分布的图像可以是某种风格或自然的图像。模型不需要 A 和 B 之间的配对图像,每个类别的一组图像就足够了。这非常适合在图像风格、光照变化、图案变化等之间进行切换。例如,将夏天改为冬天,将绘画风格改为照片,将马改为斑马。
Cycle GAN 可训练两个发电机模型和两个鉴别器模型。一个生成器将图像从 A 转换到 B,另一个从 B 转换到 A。判别器测试生成的图像是否真实。
此文件包含模型代码和训练代码。我们还有一台谷歌 Colab 笔记本电脑。
35import itertools
36import random
37import zipfile
38from typing import Tuple
39
40import torch
41import torch.nn as nn
42import torchvision.transforms as transforms
43from PIL import Image
44from torch.utils.data import DataLoader, Dataset
45from torchvision.transforms import InterpolationMode
46from torchvision.utils import make_grid
47
48from labml import lab, tracker, experiment, monit
49from labml.configs import BaseConfigs
50from labml.utils.download import download_file
51from labml.utils.pytorch import get_modules
52from labml_helpers.device import DeviceConfigs
53from labml_helpers.module import Module
发电机是一个残余网络。
56class GeneratorResNet(Module):
61 def __init__(self, input_channels: int, n_residual_blocks: int):
62 super().__init__()
第一个块运行卷积并将图像映射到要素地图。输出要素地图的高度和宽度相同,因为我们的内边距为。使用反射填充是因为它可以在边缘处提供更好的图像质量。
inplace=True
inReLU
可以节省一点内存。
70 out_features = 64
71 layers = [
72 nn.Conv2d(input_channels, out_features, kernel_size=7, padding=3, padding_mode='reflect'),
73 nn.InstanceNorm2d(out_features),
74 nn.ReLU(inplace=True),
75 ]
76 in_features = out_features
我们使用步幅为 2 的两个卷积进行向下采样
80 for _ in range(2):
81 out_features *= 2
82 layers += [
83 nn.Conv2d(in_features, out_features, kernel_size=3, stride=2, padding=1),
84 nn.InstanceNorm2d(out_features),
85 nn.ReLU(inplace=True),
86 ]
87 in_features = out_features
我们来解决这个问题n_residual_blocks
。此模块定义如下。
91 for _ in range(n_residual_blocks):
92 layers += [ResidualBlock(out_features)]
然后对生成的要素地图进行上采样,以匹配原始图像的高度和宽度。
96 for _ in range(2):
97 out_features //= 2
98 layers += [
99 nn.Upsample(scale_factor=2),
100 nn.Conv2d(in_features, out_features, kernel_size=3, stride=1, padding=1),
101 nn.InstanceNorm2d(out_features),
102 nn.ReLU(inplace=True),
103 ]
104 in_features = out_features
最后,我们将特征图映射到 RGB 图像
107 layers += [nn.Conv2d(out_features, input_channels, 7, padding=3, padding_mode='reflect'), nn.Tanh()]
使用层创建顺序模块
110 self.layers = nn.Sequential(*layers)
将权重初始化为
113 self.apply(weights_init_normal)
115 def forward(self, x):
116 return self.layers(x)
这是残差块,有两个卷积层。
119class ResidualBlock(Module):
124 def __init__(self, in_features: int):
125 super().__init__()
126 self.block = nn.Sequential(
127 nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflect'),
128 nn.InstanceNorm2d(in_features),
129 nn.ReLU(inplace=True),
130 nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflect'),
131 nn.InstanceNorm2d(in_features),
132 nn.ReLU(inplace=True),
133 )
135 def forward(self, x: torch.Tensor):
136 return x + self.block(x)
这是鉴别器。
139class Discriminator(Module):
144 def __init__(self, input_shape: Tuple[int, int, int]):
145 super().__init__()
146 channels, height, width = input_shape
判别器的输出也是概率图,无论图像的每个区域是真实的还是生成的
150 self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)
151
152 self.layers = nn.Sequential(
这些方块中的每一个都会将高度和宽度缩小 2 倍
154 DiscriminatorBlock(channels, 64, normalize=False),
155 DiscriminatorBlock(64, 128),
156 DiscriminatorBlock(128, 256),
157 DiscriminatorBlock(256, 512),
顶部和左侧的零填充以保持输出高度和宽度与内核相同
160 nn.ZeroPad2d((1, 0, 1, 0)),
161 nn.Conv2d(512, 1, kernel_size=4, padding=1)
162 )
将权重初始化为
165 self.apply(weights_init_normal)
167 def forward(self, img):
168 return self.layers(img)
171class DiscriminatorBlock(Module):
179 def __init__(self, in_filters: int, out_filters: int, normalize: bool = True):
180 super().__init__()
181 layers = [nn.Conv2d(in_filters, out_filters, kernel_size=4, stride=2, padding=1)]
182 if normalize:
183 layers.append(nn.InstanceNorm2d(out_filters))
184 layers.append(nn.LeakyReLU(0.2, inplace=True))
185 self.layers = nn.Sequential(*layers)
187 def forward(self, x: torch.Tensor):
188 return self.layers(x)
将卷积层权重初始化为
191def weights_init_normal(m):
195 classname = m.__class__.__name__
196 if classname.find("Conv") != -1:
197 torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
加载图像并更改为 RGB(如果为灰度)。
200def load_image(path: str):
204 image = Image.open(path)
205 if image.mode != 'RGB':
206 image = Image.new("RGB", image.size).paste(image)
207
208 return image
211class ImageDataset(Dataset):
216 @staticmethod
217 def download(dataset_name: str):
网址
222 url = f'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/{dataset_name}.zip'
下载文件夹
224 root = lab.get_data_path() / 'cycle_gan'
225 if not root.exists():
226 root.mkdir(parents=True)
下载目的地
228 archive = root / f'{dataset_name}.zip'
下载文件(一般约为 100MB)
230 download_file(url, archive)
解压档案
232 with zipfile.ZipFile(archive, 'r') as f:
233 f.extractall(root)
235 def __init__(self, dataset_name: str, transforms_, mode: str):
数据集路径
244 root = lab.get_data_path() / 'cycle_gan' / dataset_name
如果缺少则下载
246 if not root.exists():
247 self.download(dataset_name)
图像变换
250 self.transform = transforms.Compose(transforms_)
获取图像路径
253 path_a = root / f'{mode}A'
254 path_b = root / f'{mode}B'
255 self.files_a = sorted(str(f) for f in path_a.iterdir())
256 self.files_b = sorted(str(f) for f in path_b.iterdir())
258 def __getitem__(self, index):
返回一对图像。这些对被分成一组,在训练中它们不像成对那样起作用。因此,我们总是继续给同样的货币对是可以的。
262 return {"x": self.transform(load_image(self.files_a[index % len(self.files_a)])),
263 "y": self.transform(load_image(self.files_b[index % len(self.files_b)]))}
265 def __len__(self):
数据集中的影像数量
267 return max(len(self.files_a), len(self.files_b))
重播缓冲区用于训练鉴别器。生成的图像被添加到重放缓冲区并从中取样。
重放缓冲区返回新添加的图像,概率为。否则,它会发送一个较旧的生成的图像,并用新生成的图像替换旧的图像。
这样做是为了减少模型振荡。
270class ReplayBuffer:
284 def __init__(self, max_size: int = 50):
285 self.max_size = max_size
286 self.data = []
添加/检索图像
288 def push_and_pop(self, data: torch.Tensor):
290 data = data.detach()
291 res = []
292 for element in data:
293 if len(self.data) < self.max_size:
294 self.data.append(element)
295 res.append(element)
296 else:
297 if random.uniform(0, 1) > 0.5:
298 i = random.randint(0, self.max_size - 1)
299 res.append(self.data[i].clone())
300 self.data[i] = element
301 else:
302 res.append(element)
303 return torch.stack(res)
306class Configs(BaseConfigs):
DeviceConfigs
如果有的话,会选择一个 GPU
310 device: torch.device = DeviceConfigs()
超参数
313 epochs: int = 200
314 dataset_name: str = 'monet2photo'
315 batch_size: int = 1
316
317 data_loader_workers = 8
318
319 learning_rate = 0.0002
320 adam_betas = (0.5, 0.999)
321 decay_start = 100
该论文建议使用最小二乘损失而不是负对数似然,因为人们发现它更稳定。
325 gan_loss = torch.nn.MSELoss()
L1 损失用于周期损失和身份丢失
328 cycle_loss = torch.nn.L1Loss()
329 identity_loss = torch.nn.L1Loss()
图像尺寸
332 img_height = 256
333 img_width = 256
334 img_channels = 3
生成器中的残余块数
337 n_residual_blocks = 9
损失系数
340 cyclic_loss_coefficient = 10.0
341 identity_loss_coefficient = 5.
342
343 sample_interval = 500
模特
346 generator_xy: GeneratorResNet
347 generator_yx: GeneratorResNet
348 discriminator_x: Discriminator
349 discriminator_y: Discriminator
优化器
352 generator_optimizer: torch.optim.Adam
353 discriminator_optimizer: torch.optim.Adam
学习速率表
356 generator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR
357 discriminator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR
数据加载器
360 dataloader: DataLoader
361 valid_dataloader: DataLoader
从测试集生成样本并保存
363 def sample_images(self, n: int):
365 batch = next(iter(self.valid_dataloader))
366 self.generator_xy.eval()
367 self.generator_yx.eval()
368 with torch.no_grad():
369 data_x, data_y = batch['x'].to(self.generator_xy.device), batch['y'].to(self.generator_yx.device)
370 gen_y = self.generator_xy(data_x)
371 gen_x = self.generator_yx(data_y)
沿 x 轴排列图像
374 data_x = make_grid(data_x, nrow=5, normalize=True)
375 data_y = make_grid(data_y, nrow=5, normalize=True)
376 gen_x = make_grid(gen_x, nrow=5, normalize=True)
377 gen_y = make_grid(gen_y, nrow=5, normalize=True)
沿 y 轴排列图像
380 image_grid = torch.cat((data_x, gen_y, data_y, gen_x), 1)
显示样本
383 plot_image(image_grid)
385 def initialize(self):
389 input_shape = (self.img_channels, self.img_height, self.img_width)
创建模型
392 self.generator_xy = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
393 self.generator_yx = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
394 self.discriminator_x = Discriminator(input_shape).to(self.device)
395 self.discriminator_y = Discriminator(input_shape).to(self.device)
创建优化器
398 self.generator_optimizer = torch.optim.Adam(
399 itertools.chain(self.generator_xy.parameters(), self.generator_yx.parameters()),
400 lr=self.learning_rate, betas=self.adam_betas)
401 self.discriminator_optimizer = torch.optim.Adam(
402 itertools.chain(self.discriminator_x.parameters(), self.discriminator_y.parameters()),
403 lr=self.learning_rate, betas=self.adam_betas)
创建学习速率表。学习率一直保持不变,直到decay_start
各个时代,然后在训练结束时线性降低。
408 decay_epochs = self.epochs - self.decay_start
409 self.generator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
410 self.generator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
411 self.discriminator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
412 self.discriminator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
图像变换
415 transforms_ = [
416 transforms.Resize(int(self.img_height * 1.12), InterpolationMode.BICUBIC),
417 transforms.RandomCrop((self.img_height, self.img_width)),
418 transforms.RandomHorizontalFlip(),
419 transforms.ToTensor(),
420 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
421 ]
训练数据加载器
424 self.dataloader = DataLoader(
425 ImageDataset(self.dataset_name, transforms_, 'train'),
426 batch_size=self.batch_size,
427 shuffle=True,
428 num_workers=self.data_loader_workers,
429 )
验证数据加载器
432 self.valid_dataloader = DataLoader(
433 ImageDataset(self.dataset_name, transforms_, "test"),
434 batch_size=5,
435 shuffle=True,
436 num_workers=self.data_loader_workers,
437 )
我们的目标是解决:
其中,翻译图像,翻译来自的图像,测试图像是否来自太空,测试图像是否来自太空,以及
是原始 GAN 论文产生的对抗损失。
是循环损失,我们试图与之相似和相似。基本上,如果两个生成器(变换)是串联应用的,它应该返回原始图像。这是本文的主要贡献。它训练生成器以生成与原始图像相似的其他分布的图像。如果没有这种损失,可能会产生任何来自分发的损失。现在它需要从的分布中生成一些东西,但仍然具有的属性,这样才能重新生成类似的东西。
是身份丢失。这被用来鼓励映射以保留输入和输出之间的颜色构成。
为了求解,鉴别器和应该在梯度上升,
这取决于负对数似然损失。
为了稳定训练,负对数似然目标被最小二乘损失所取代,即鉴别器的最小二乘误差,用1标记真实图像,将生成的图像标记为0。所以我们想在渐变上下降,
我们也使用最小二乘作为生成器。发电机应该下降到梯度上,
我们使用 fgenerator_xy
or 和 fgenerator_yx
or。我们使用 fdiscriminator_x
or 和 fdiscriminator_y
or。
439 def run(self):
重播缓冲区以保留生成的样本
541 gen_x_buffer = ReplayBuffer()
542 gen_y_buffer = ReplayBuffer()
循环穿越时代
545 for epoch in monit.loop(self.epochs):
循环浏览数据集
547 for i, batch in monit.enum('Train', self.dataloader):
将图像移动到设备
549 data_x, data_y = batch['x'].to(self.device), batch['y'].to(self.device)
真实标签等于
552 true_labels = torch.ones(data_x.size(0), *self.discriminator_x.output_shape,
553 device=self.device, requires_grad=False)
假标签等于
555 false_labels = torch.zeros(data_x.size(0), *self.discriminator_x.output_shape,
556 device=self.device, requires_grad=False)
训练发电机。这将返回生成的图像。
560 gen_x, gen_y = self.optimize_generators(data_x, data_y, true_labels)
训练鉴别器
563 self.optimize_discriminator(data_x, data_y,
564 gen_x_buffer.push_and_pop(gen_x), gen_y_buffer.push_and_pop(gen_y),
565 true_labels, false_labels)
保存训练统计数据并增加全局步数计数器
568 tracker.save()
569 tracker.add_global_step(max(len(data_x), len(data_y)))
每隔一段时间保存图像
572 batches_done = epoch * len(self.dataloader) + i
573 if batches_done % self.sample_interval == 0:
采样图像时保存模型
575 experiment.save_checkpoint()
样本图片
577 self.sample_images(batches_done)
更新学习率
580 self.generator_lr_scheduler.step()
581 self.discriminator_lr_scheduler.step()
新产品线
583 tracker.new_line()
585 def optimize_generators(self, data_x: torch.Tensor, data_y: torch.Tensor, true_labels: torch.Tensor):
改为训练模式
591 self.generator_xy.train()
592 self.generator_yx.train()
身份丢失
597 loss_identity = (self.identity_loss(self.generator_yx(data_x), data_x) +
598 self.identity_loss(self.generator_xy(data_y), data_y))
生成图像和
601 gen_y = self.generator_xy(data_x)
602 gen_x = self.generator_yx(data_y)
GAN 损失
607 loss_gan = (self.gan_loss(self.discriminator_y(gen_y), true_labels) +
608 self.gan_loss(self.discriminator_x(gen_x), true_labels))
周期损失
615 loss_cycle = (self.cycle_loss(self.generator_yx(gen_y), data_x) +
616 self.cycle_loss(self.generator_xy(gen_x), data_y))
总亏损
619 loss_generator = (loss_gan +
620 self.cyclic_loss_coefficient * loss_cycle +
621 self.identity_loss_coefficient * loss_identity)
在优化器中迈出一步
624 self.generator_optimizer.zero_grad()
625 loss_generator.backward()
626 self.generator_optimizer.step()
对数损失
629 tracker.add({'loss.generator': loss_generator,
630 'loss.generator.cycle': loss_cycle,
631 'loss.generator.gan': loss_gan,
632 'loss.generator.identity': loss_identity})
返回生成的图像
635 return gen_x, gen_y
637 def optimize_discriminator(self, data_x: torch.Tensor, data_y: torch.Tensor,
638 gen_x: torch.Tensor, gen_y: torch.Tensor,
639 true_labels: torch.Tensor, false_labels: torch.Tensor):
652 loss_discriminator = (self.gan_loss(self.discriminator_x(data_x), true_labels) +
653 self.gan_loss(self.discriminator_x(gen_x), false_labels) +
654 self.gan_loss(self.discriminator_y(data_y), true_labels) +
655 self.gan_loss(self.discriminator_y(gen_y), false_labels))
在优化器中迈出一步
658 self.discriminator_optimizer.zero_grad()
659 loss_discriminator.backward()
660 self.discriminator_optimizer.step()
对数损失
663 tracker.add({'loss.discriminator': loss_discriminator})
666def train():
创建配置
671 conf = Configs()
创建实验
673 experiment.create(name='cycle_gan')
计算配置。它将计算conf.run
和它所需的所有其他配置。
676 experiment.configs(conf, {'dataset_name': 'summer2winter_yosemite'})
677 conf.initialize()
注册模型以进行保存和加载。get_modules
给出了nn.Modules
in 的字典conf
。您还可以指定模型的自定义字典。
682 experiment.add_pytorch_models(get_modules(conf))
开始观看实验
684 with experiment.start():
运行训练
686 conf.run()
689def plot_image(img: torch.Tensor):
693 from matplotlib import pyplot as plt
将张量移到 CPU
696 img = img.cpu()
获取图像的最小值和最大值以进行归一化
698 img_min, img_max = img.min(), img.max()
我们必须将尺寸顺序更改为 HWC。
702 img = img.permute(1, 2, 0)
显示图片
704 plt.imshow(img)
我们不需要斧头
706 plt.axis('off')
显示
708 plt.show()
711def evaluate():
设置训练跑的跑步 UUID
716 trained_run_uuid = 'f73c1164184711eb9190b74249275441'
创建配置对象
718 conf = Configs()
创建实验
720 experiment.create(name='cycle_gan_inference')
加载为训练设置的超级参数
722 conf_dict = experiment.load_configs(trained_run_uuid)
计算配置。我们指定生成器,'generator_xy', 'generator_yx'
以便它只加载这些生成器及其依赖项。img_channels
将计算device
和之类的配置,因为generator_xy
和需要这些配置generator_yx
。
如果你想要其他参数,dataset_name
你应该在这里指定它们。如果未指定任何内容,则将计算所有配置,包括数据加载器。调用时将计算配置及其依赖关系experiment.start
731 experiment.configs(conf, conf_dict)
732 conf.initialize()
注册模型以进行保存和加载。get_modules
给出了nn.Modules
in 的字典conf
。您还可以指定模型的自定义字典。
737 experiment.add_pytorch_models(get_modules(conf))
指定要从哪个运行中加载。当你打电话时,加载实际上会发生experiment.start
740 experiment.load(trained_run_uuid)
开始实验
743 with experiment.start():
图像变换
745 transforms_ = [
746 transforms.ToTensor(),
747 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
748 ]
加载您自己的数据。在这里,我们尝试测试集。我在尝试优胜美地的照片,它们看起来很棒。你可以使用conf.dataset_name
,如果你指定dataset_name
为你想要在调用中计算的东西experiment.configs
754 dataset = ImageDataset(conf.dataset_name, transforms_, 'train')
从数据集中获取图像
756 x_image = dataset[10]['x']
显示图像
758 plot_image(x_image)
评估模式
761 conf.generator_xy.eval()
762 conf.generator_yx.eval()
我们不需要渐变
765 with torch.no_grad():
添加批量维度并移动到我们使用的设备
767 data = x_image.unsqueeze(0).to(conf.device)
768 generated_y = conf.generator_xy(data)
显示生成的图像。
771 plot_image(generated_y[0].cpu())
772
773
774if __name__ == '__main__':
775 train()
评估 ()