循环 GAN

这是 PyTorch 的 PyTorch 实现/教程,该论文使用周期一致性对抗网络进行图像间的非配对转换

我从 eriklindernoren/pytorch-Gan 那里拿了一些代码。如果你也想查看其他 GAN 变体,这是一个非常好的资源。

Cyc@@

le GAN 进行图像到图像的转换。它训练模型将图像从给定分布转换到另一个分布,比如A类和B类的图像,某个分布的图像可以是某种风格或自然的图像。模型不需要 A 和 B 之间的配对图像,每个类别的一组图像就足够了。这非常适合在图像风格、光照变化、图案变化等之间进行切换。例如,将夏天改为冬天,将绘画风格改为照片,将马改为斑马。

Cycle GAN 可训练两个发电机模型和两个鉴别器模型。一个生成器将图像从 A 转换到 B,另一个从 B 转换到 A。判别器测试生成的图像是否真实。

此文件包含模型代码和训练代码。我们还有一台谷歌 Colab 笔记本电脑。

Open In Colab

35import itertools
36import random
37import zipfile
38from typing import Tuple
39
40import torch
41import torch.nn as nn
42import torchvision.transforms as transforms
43from PIL import Image
44from torch.utils.data import DataLoader, Dataset
45from torchvision.transforms import InterpolationMode
46from torchvision.utils import make_grid
47
48from labml import lab, tracker, experiment, monit
49from labml.configs import BaseConfigs
50from labml.utils.download import download_file
51from labml.utils.pytorch import get_modules
52from labml_helpers.device import DeviceConfigs
53from labml_helpers.module import Module

发电机是一个残余网络。

56class GeneratorResNet(Module):
61    def __init__(self, input_channels: int, n_residual_blocks: int):
62        super().__init__()

第一个块运行卷积并将图像映射到要素地图。输出要素地图的高度和宽度相同,因为我们的内边距为。使用反射填充是因为它可以在边缘处提供更好的图像质量。

inplace=True inReLU 可以节省一点内存。

70        out_features = 64
71        layers = [
72            nn.Conv2d(input_channels, out_features, kernel_size=7, padding=3, padding_mode='reflect'),
73            nn.InstanceNorm2d(out_features),
74            nn.ReLU(inplace=True),
75        ]
76        in_features = out_features

我们使用步幅为 2 的两个卷积进行向下采样

80        for _ in range(2):
81            out_features *= 2
82            layers += [
83                nn.Conv2d(in_features, out_features, kernel_size=3, stride=2, padding=1),
84                nn.InstanceNorm2d(out_features),
85                nn.ReLU(inplace=True),
86            ]
87            in_features = out_features

我们来解决这个问题n_residual_blocks 。此模块定义如下。

91        for _ in range(n_residual_blocks):
92            layers += [ResidualBlock(out_features)]

然后对生成的要素地图进行上采样,以匹配原始图像的高度和宽度。

96        for _ in range(2):
97            out_features //= 2
98            layers += [
99                nn.Upsample(scale_factor=2),
100                nn.Conv2d(in_features, out_features, kernel_size=3, stride=1, padding=1),
101                nn.InstanceNorm2d(out_features),
102                nn.ReLU(inplace=True),
103            ]
104            in_features = out_features

最后,我们将特征图映射到 RGB 图像

107        layers += [nn.Conv2d(out_features, input_channels, 7, padding=3, padding_mode='reflect'), nn.Tanh()]

使用层创建顺序模块

110        self.layers = nn.Sequential(*layers)

将权重初始化为

113        self.apply(weights_init_normal)
115    def forward(self, x):
116        return self.layers(x)

这是残差块,有两个卷积层。

119class ResidualBlock(Module):
124    def __init__(self, in_features: int):
125        super().__init__()
126        self.block = nn.Sequential(
127            nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflect'),
128            nn.InstanceNorm2d(in_features),
129            nn.ReLU(inplace=True),
130            nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflect'),
131            nn.InstanceNorm2d(in_features),
132            nn.ReLU(inplace=True),
133        )
135    def forward(self, x: torch.Tensor):
136        return x + self.block(x)

这是鉴别器。

139class Discriminator(Module):
144    def __init__(self, input_shape: Tuple[int, int, int]):
145        super().__init__()
146        channels, height, width = input_shape

判别器的输出也是概率图,无论图像的每个区域是真实的还是生成的

150        self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)
151
152        self.layers = nn.Sequential(

这些方块中的每一个都会将高度和宽度缩小 2 倍

154            DiscriminatorBlock(channels, 64, normalize=False),
155            DiscriminatorBlock(64, 128),
156            DiscriminatorBlock(128, 256),
157            DiscriminatorBlock(256, 512),

顶部和左侧的零填充以保持输出高度和宽度与内核相同

160            nn.ZeroPad2d((1, 0, 1, 0)),
161            nn.Conv2d(512, 1, kernel_size=4, padding=1)
162        )

将权重初始化为

165        self.apply(weights_init_normal)
167    def forward(self, img):
168        return self.layers(img)

这是鉴别器块模块。它执行卷积、可选归一化和泄漏的 RelU。

它将输入要素地图的高度和宽度缩小一半。

171class DiscriminatorBlock(Module):
179    def __init__(self, in_filters: int, out_filters: int, normalize: bool = True):
180        super().__init__()
181        layers = [nn.Conv2d(in_filters, out_filters, kernel_size=4, stride=2, padding=1)]
182        if normalize:
183            layers.append(nn.InstanceNorm2d(out_filters))
184        layers.append(nn.LeakyReLU(0.2, inplace=True))
185        self.layers = nn.Sequential(*layers)
187    def forward(self, x: torch.Tensor):
188        return self.layers(x)

将卷积层权重初始化为

191def weights_init_normal(m):
195    classname = m.__class__.__name__
196    if classname.find("Conv") != -1:
197        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)

加载图像并更改为 RGB(如果为灰度)。

200def load_image(path: str):
204    image = Image.open(path)
205    if image.mode != 'RGB':
206        image = Image.new("RGB", image.size).paste(image)
207
208    return image

用于加载图像的数据集

211class ImageDataset(Dataset):

下载数据集并提取数据

216    @staticmethod
217    def download(dataset_name: str):

网址

222        url = f'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/{dataset_name}.zip'

下载文件夹

224        root = lab.get_data_path() / 'cycle_gan'
225        if not root.exists():
226            root.mkdir(parents=True)

下载目的地

228        archive = root / f'{dataset_name}.zip'

下载文件(一般约为 100MB)

230        download_file(url, archive)

解压档案

232        with zipfile.ZipFile(archive, 'r') as f:
233            f.extractall(root)

初始化数据集

  • dataset_name 是数据集的名称
  • transforms_ 是图像变换的集合
  • modetraintest
235    def __init__(self, dataset_name: str, transforms_, mode: str):

数据集路径

244        root = lab.get_data_path() / 'cycle_gan' / dataset_name

如果缺少则下载

246        if not root.exists():
247            self.download(dataset_name)

图像变换

250        self.transform = transforms.Compose(transforms_)

获取图像路径

253        path_a = root / f'{mode}A'
254        path_b = root / f'{mode}B'
255        self.files_a = sorted(str(f) for f in path_a.iterdir())
256        self.files_b = sorted(str(f) for f in path_b.iterdir())
258    def __getitem__(self, index):

返回一对图像。这些对被分成一组,在训练中它们不像成对那样起作用。因此,我们总是继续给同样的货币对是可以的。

262        return {"x": self.transform(load_image(self.files_a[index % len(self.files_a)])),
263                "y": self.transform(load_image(self.files_b[index % len(self.files_b)]))}
265    def __len__(self):

数据集中的影像数量

267        return max(len(self.files_a), len(self.files_b))

重播缓冲区

重播缓冲区用于训练鉴别器。生成的图像被添加到重放缓冲区并从中取样。

重放缓冲区返回新添加的图像,概率为。否则,它会发送一个较旧的生成的图像,并用新生成的图像替换旧的图像。

这样做是为了减少模型振荡。

270class ReplayBuffer:
284    def __init__(self, max_size: int = 50):
285        self.max_size = max_size
286        self.data = []

添加/检索图像

288    def push_and_pop(self, data: torch.Tensor):
290        data = data.detach()
291        res = []
292        for element in data:
293            if len(self.data) < self.max_size:
294                self.data.append(element)
295                res.append(element)
296            else:
297                if random.uniform(0, 1) > 0.5:
298                    i = random.randint(0, self.max_size - 1)
299                    res.append(self.data[i].clone())
300                    self.data[i] = element
301                else:
302                    res.append(element)
303        return torch.stack(res)

配置

306class Configs(BaseConfigs):

DeviceConfigs 如果有的话,会选择一个 GPU

310    device: torch.device = DeviceConfigs()

超参数

313    epochs: int = 200
314    dataset_name: str = 'monet2photo'
315    batch_size: int = 1
316
317    data_loader_workers = 8
318
319    learning_rate = 0.0002
320    adam_betas = (0.5, 0.999)
321    decay_start = 100

该论文建议使用最小二乘损失而不是负对数似然,因为人们发现它更稳定。

325    gan_loss = torch.nn.MSELoss()

L1 损失用于周期损失和身份丢失

328    cycle_loss = torch.nn.L1Loss()
329    identity_loss = torch.nn.L1Loss()

图像尺寸

332    img_height = 256
333    img_width = 256
334    img_channels = 3

生成器中的残余块数

337    n_residual_blocks = 9

损失系数

340    cyclic_loss_coefficient = 10.0
341    identity_loss_coefficient = 5.
342
343    sample_interval = 500

模特

346    generator_xy: GeneratorResNet
347    generator_yx: GeneratorResNet
348    discriminator_x: Discriminator
349    discriminator_y: Discriminator

优化器

352    generator_optimizer: torch.optim.Adam
353    discriminator_optimizer: torch.optim.Adam

学习速率表

356    generator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR
357    discriminator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR

数据加载器

360    dataloader: DataLoader
361    valid_dataloader: DataLoader

从测试集生成样本并保存

363    def sample_images(self, n: int):
365        batch = next(iter(self.valid_dataloader))
366        self.generator_xy.eval()
367        self.generator_yx.eval()
368        with torch.no_grad():
369            data_x, data_y = batch['x'].to(self.generator_xy.device), batch['y'].to(self.generator_yx.device)
370            gen_y = self.generator_xy(data_x)
371            gen_x = self.generator_yx(data_y)

沿 x 轴排列图像

374            data_x = make_grid(data_x, nrow=5, normalize=True)
375            data_y = make_grid(data_y, nrow=5, normalize=True)
376            gen_x = make_grid(gen_x, nrow=5, normalize=True)
377            gen_y = make_grid(gen_y, nrow=5, normalize=True)

沿 y 轴排列图像

380            image_grid = torch.cat((data_x, gen_y, data_y, gen_x), 1)

显示样本

383        plot_image(image_grid)

初始化模型和数据加载器

385    def initialize(self):
389        input_shape = (self.img_channels, self.img_height, self.img_width)

创建模型

392        self.generator_xy = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
393        self.generator_yx = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
394        self.discriminator_x = Discriminator(input_shape).to(self.device)
395        self.discriminator_y = Discriminator(input_shape).to(self.device)

创建优化器

398        self.generator_optimizer = torch.optim.Adam(
399            itertools.chain(self.generator_xy.parameters(), self.generator_yx.parameters()),
400            lr=self.learning_rate, betas=self.adam_betas)
401        self.discriminator_optimizer = torch.optim.Adam(
402            itertools.chain(self.discriminator_x.parameters(), self.discriminator_y.parameters()),
403            lr=self.learning_rate, betas=self.adam_betas)

创建学习速率表。学习率一直保持不变,直到decay_start 各个时代,然后在训练结束时线性降低。

408        decay_epochs = self.epochs - self.decay_start
409        self.generator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
410            self.generator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
411        self.discriminator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
412            self.discriminator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)

图像变换

415        transforms_ = [
416            transforms.Resize(int(self.img_height * 1.12), InterpolationMode.BICUBIC),
417            transforms.RandomCrop((self.img_height, self.img_width)),
418            transforms.RandomHorizontalFlip(),
419            transforms.ToTensor(),
420            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
421        ]

训练数据加载器

424        self.dataloader = DataLoader(
425            ImageDataset(self.dataset_name, transforms_, 'train'),
426            batch_size=self.batch_size,
427            shuffle=True,
428            num_workers=self.data_loader_workers,
429        )

验证数据加载器

432        self.valid_dataloader = DataLoader(
433            ImageDataset(self.dataset_name, transforms_, "test"),
434            batch_size=5,
435            shuffle=True,
436            num_workers=self.data_loader_workers,
437        )

训练

我们的目标是解决:

其中,翻译图像翻译来自的图像测试图像是否来自太空,测试图像是否来自太空,以及

是原始 GAN 论文产生的对抗损失。

是循环损失,我们试图与之相似和相似。基本上,如果两个生成器(变换)是串联应用的,它应该返回原始图像。这是本文的主要贡献。它训练生成器以生成与原始图像相似的其他分布的图像。如果没有这种损失,可能会产生任何来自分发的损失。现在它需要从的分布中生成一些东西,但仍然具有的属性,这样才能重新生成类似的东西

是身份丢失。这被用来鼓励映射以保留输入和输出之间的颜色构成。

为了求解,鉴别器和应该在梯度上

这取决于对数似然损失。

为了稳定训练,负对数似然目标被最小二乘损失所取代,即鉴别器的最小二乘误差,用1标记真实图像,将生成的图像标记为0。所以我们想在渐变上下降,

我们也使用最小二乘作为生成器。发电机应该下降到梯度上,

我们使用 fgenerator_xy or 和 fgenerator_yx or。我们使用 fdiscriminator_x or 和 fdiscriminator_y or

439    def run(self):

重播缓冲区以保留生成的样本

541        gen_x_buffer = ReplayBuffer()
542        gen_y_buffer = ReplayBuffer()

循环穿越时代

545        for epoch in monit.loop(self.epochs):

循环浏览数据集

547            for i, batch in monit.enum('Train', self.dataloader):

将图像移动到设备

549                data_x, data_y = batch['x'].to(self.device), batch['y'].to(self.device)

真实标签等于

552                true_labels = torch.ones(data_x.size(0), *self.discriminator_x.output_shape,
553                                         device=self.device, requires_grad=False)

假标签等于

555                false_labels = torch.zeros(data_x.size(0), *self.discriminator_x.output_shape,
556                                           device=self.device, requires_grad=False)

训练发电机。这将返回生成的图像。

560                gen_x, gen_y = self.optimize_generators(data_x, data_y, true_labels)

训练鉴别器

563                self.optimize_discriminator(data_x, data_y,
564                                            gen_x_buffer.push_and_pop(gen_x), gen_y_buffer.push_and_pop(gen_y),
565                                            true_labels, false_labels)

保存训练统计数据并增加全局步数计数器

568                tracker.save()
569                tracker.add_global_step(max(len(data_x), len(data_y)))

每隔一段时间保存图像

572                batches_done = epoch * len(self.dataloader) + i
573                if batches_done % self.sample_interval == 0:

采样图像时保存模型

575                    experiment.save_checkpoint()

样本图片

577                    self.sample_images(batches_done)

更新学习率

580            self.generator_lr_scheduler.step()
581            self.discriminator_lr_scheduler.step()

新产品线

583            tracker.new_line()
利用@@

标识、增益和循环损耗来优化发电机。

585    def optimize_generators(self, data_x: torch.Tensor, data_y: torch.Tensor, true_labels: torch.Tensor):

改为训练模式

591        self.generator_xy.train()
592        self.generator_yx.train()

身份丢失

597        loss_identity = (self.identity_loss(self.generator_yx(data_x), data_x) +
598                         self.identity_loss(self.generator_xy(data_y), data_y))

生成图像

601        gen_y = self.generator_xy(data_x)
602        gen_x = self.generator_yx(data_y)

GAN 损失

607        loss_gan = (self.gan_loss(self.discriminator_y(gen_y), true_labels) +
608                    self.gan_loss(self.discriminator_x(gen_x), true_labels))

周期损失

615        loss_cycle = (self.cycle_loss(self.generator_yx(gen_y), data_x) +
616                      self.cycle_loss(self.generator_xy(gen_x), data_y))

总亏损

619        loss_generator = (loss_gan +
620                          self.cyclic_loss_coefficient * loss_cycle +
621                          self.identity_loss_coefficient * loss_identity)

在优化器中迈出一步

624        self.generator_optimizer.zero_grad()
625        loss_generator.backward()
626        self.generator_optimizer.step()

对数损失

629        tracker.add({'loss.generator': loss_generator,
630                     'loss.generator.cycle': loss_cycle,
631                     'loss.generator.gan': loss_gan,
632                     'loss.generator.identity': loss_identity})

返回生成的图像

635        return gen_x, gen_y

利用 gan 损耗优化鉴别器。

637    def optimize_discriminator(self, data_x: torch.Tensor, data_y: torch.Tensor,
638                               gen_x: torch.Tensor, gen_y: torch.Tensor,
639                               true_labels: torch.Tensor, false_labels: torch.Tensor):

GAN 损失

652        loss_discriminator = (self.gan_loss(self.discriminator_x(data_x), true_labels) +
653                              self.gan_loss(self.discriminator_x(gen_x), false_labels) +
654                              self.gan_loss(self.discriminator_y(data_y), true_labels) +
655                              self.gan_loss(self.discriminator_y(gen_y), false_labels))

在优化器中迈出一步

658        self.discriminator_optimizer.zero_grad()
659        loss_discriminator.backward()
660        self.discriminator_optimizer.step()

对数损失

663        tracker.add({'loss.discriminator': loss_discriminator})

火车周期 GAN

666def train():

创建配置

671    conf = Configs()

创建实验

673    experiment.create(name='cycle_gan')

计算配置。它将计算conf.run 和它所需的所有其他配置。

676    experiment.configs(conf, {'dataset_name': 'summer2winter_yosemite'})
677    conf.initialize()

注册模型以进行保存和加载。get_modules 给出了nn.Modules in 的字典conf 。您还可以指定模型的自定义字典。

682    experiment.add_pytorch_models(get_modules(conf))

开始观看实验

684    with experiment.start():

运行训练

686        conf.run()

使用 matplotlib 绘制图像

689def plot_image(img: torch.Tensor):
693    from matplotlib import pyplot as plt

将张量移到 CPU

696    img = img.cpu()

获取图像的最小值和最大值以进行归一化

698    img_min, img_max = img.min(), img.max()

将图像值缩放为 0... 1

700    img = (img - img_min) / (img_max - img_min + 1e-5)

我们必须将尺寸顺序更改为 HWC。

702    img = img.permute(1, 2, 0)

显示图片

704    plt.imshow(img)

我们不需要斧头

706    plt.axis('off')

显示

708    plt.show()

评估训练过的循环 GAN

711def evaluate():

设置训练跑的跑步 UUID

716    trained_run_uuid = 'f73c1164184711eb9190b74249275441'

创建配置对象

718    conf = Configs()

创建实验

720    experiment.create(name='cycle_gan_inference')

加载为训练设置的超级参数

722    conf_dict = experiment.load_configs(trained_run_uuid)

计算配置。我们指定生成器,'generator_xy', 'generator_yx' 以便它只加载这些生成器及其依赖项。img_channels 将计算device 和之类的配置,因为generator_xy 和需要这些配置generator_yx

如果你想要其他参数,dataset_name 你应该在这里指定它们。如果未指定任何内容,则将计算所有配置,包括数据加载器。调用时将计算配置及其依赖关系experiment.start

731    experiment.configs(conf, conf_dict)
732    conf.initialize()

注册模型以进行保存和加载。get_modules 给出了nn.Modules in 的字典conf 。您还可以指定模型的自定义字典。

737    experiment.add_pytorch_models(get_modules(conf))

指定要从哪个运行中加载。当你打电话时,加载实际上会发生experiment.start

740    experiment.load(trained_run_uuid)

开始实验

743    with experiment.start():

图像变换

745        transforms_ = [
746            transforms.ToTensor(),
747            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
748        ]

加载您自己的数据。在这里,我们尝试测试集。我在尝试优胜美地的照片,它们看起来很棒。你可以使用conf.dataset_name ,如果你指定dataset_name 为你想要在调用中计算的东西experiment.configs

754        dataset = ImageDataset(conf.dataset_name, transforms_, 'train')

从数据集中获取图像

756        x_image = dataset[10]['x']

显示图像

758        plot_image(x_image)

评估模式

761        conf.generator_xy.eval()
762        conf.generator_yx.eval()

我们不需要渐变

765        with torch.no_grad():

添加批量维度并移动到我们使用的设备

767            data = x_image.unsqueeze(0).to(conf.device)
768            generated_y = conf.generator_xy(data)

显示生成的图像。

771        plot_image(generated_y[0].cpu())
772
773
774if __name__ == '__main__':
775    train()

评估 ()