StyleGan 2 模型训练

这是 StyleGan 2 模型的训练代码。

Generated Images

这些是在训练了大约 80K 步之后生成的图像。

我们的实现是一个简约的 StyleGan 2 模型训练代码。仅支持单个 GPU 训练,以保持实现简单。我们设法缩小了它,使其保持在不到 500 行代码中,包括训练循环。

如果没有 DDP(分布式数据并行)和多 GPU 训练,将无法为大分辨率(128+)训练模型。如果你想用 fp16 和 DDP 训练代码,可以看看 l ucidrains/stylegan2-pytorch

我们在 Celeba-HQ 数据集上训练了这个。你可以在这篇关于 fast.ai 的讨论中找到下载说明。将图像保存在data/stylegan 文件夹中

31import math
32from pathlib import Path
33from typing import Iterator, Tuple
34
35import torch
36import torch.utils.data
37import torchvision
38from PIL import Image
39
40from labml import tracker, lab, monit, experiment
41from labml.configs import BaseConfigs
42from labml_helpers.device import DeviceConfigs
43from labml_helpers.train_valid import ModeState, hook_model_outputs
44from labml_nn.gan.stylegan import Discriminator, Generator, MappingNetwork, GradientPenalty, PathLengthPenalty
45from labml_nn.gan.wasserstein import DiscriminatorLoss, GeneratorLoss
46from labml_nn.utils import cycle_dataloader

数据集

这将加载训练数据集并将其调整为给定的图像大小。

49class Dataset(torch.utils.data.Dataset):
  • path 包含图像的文件夹的路径
  • image_size 图像的大小
56    def __init__(self, path: str, image_size: int):
61        super().__init__()

获取所有jpg 文件的路径

64        self.paths = [p for p in Path(path).glob(f'**/*.jpg')]

转型

67        self.transform = torchvision.transforms.Compose([

调整图像大小

69            torchvision.transforms.Resize(image_size),

转换为 pyTorch 张量

71            torchvision.transforms.ToTensor(),
72        ])

图像数量

74    def __len__(self):
76        return len(self.paths)

获取第index -th 张图片

78    def __getitem__(self, index):
80        path = self.paths[index]
81        img = Image.open(path)
82        return self.transform(img)

配置

85class Configs(BaseConfigs):

用于训练模型的设备。DeviceConfigs 选择可用的 CUDA 设备或默认为 CPU。

93    device: torch.device = DeviceConfigs()
96    discriminator: Discriminator
98    generator: Generator
100    mapping_network: MappingNetwork

鉴别器和发生器损耗函数。我们使用 Wasserstein 的损失

104    discriminator_loss: DiscriminatorLoss
105    generator_loss: GeneratorLoss

优化器

108    generator_optimizer: torch.optim.Adam
109    discriminator_optimizer: torch.optim.Adam
110    mapping_network_optimizer: torch.optim.Adam
113    gradient_penalty = GradientPenalty()

梯度惩罚系数

115    gradient_penalty_coefficient: float = 10.
118    path_length_penalty: PathLengthPenalty

数据加载器

121    loader: Iterator

批量大小

124    batch_size: int = 32

和的维度

126    d_latent: int = 512

图像的高度/宽度

128    image_size: int = 32

制图网络中的图层数

130    mapping_network_layers: int = 8

生成器和鉴别器学习速率

132    learning_rate: float = 1e-3

映射网络学习率(低于其他)

134    mapping_network_learning_rate: float = 1e-5

累积梯度的步数。使用它可以增加有效批次大小。

136    gradient_accumulate_steps: int = 1

于 Adam 优化器来说

138    adam_betas: Tuple[float, float] = (0.0, 0.99)

混合样式的概率

140    style_mixing_prob: float = 0.9

训练步数总数

143    training_steps: int = 150_000

生成器中的块数(根据图像分辨率计算)

146    n_gen_blocks: int

懒惰正则化

本@@

文没有计算正则化损失,而是提出了懒惰的正则化,即偶尔计算一次正则化项。这大大提高了训练效率。

计算梯度惩罚的间隔

154    lazy_gradient_penalty_interval: int = 4

路径长度惩罚计算间隔

156    lazy_path_penalty_interval: int = 32

在训练的初始阶段跳过计算路径长度损失

158    lazy_path_penalty_after: int = 5_000

记录生成的图像的频率

161    log_generated_interval: int = 500

保存模型检查点的频率

163    save_checkpoint_interval: int = 2_000

日志记录激活的训练模式状态

166    mode: ModeState

是否记录模型层输出

168    log_layer_outputs: bool = False

我们在 Celeba-HQ 数据集上训练了这个。你可以在这篇关于 fast.ai 的讨论中找到下载说明。将图像保存在data/stylegan 文件夹中。

175    dataset_path: str = str(lab.get_data_path() / 'stylegan2')

初始化

177    def init(self):

创建数据集

182        dataset = Dataset(self.dataset_path, self.image_size)

创建数据加载器

184        dataloader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, num_workers=8,
185                                                 shuffle=True, drop_last=True, pin_memory=True)
187        self.loader = cycle_dataloader(dataloader)

的图像分辨率

190        log_resolution = int(math.log2(self.image_size))

创建鉴别器和生成器

193        self.discriminator = Discriminator(log_resolution).to(self.device)
194        self.generator = Generator(log_resolution, self.d_latent).to(self.device)

获取用于创建样式和噪声输入的生成器模块的数量

196        self.n_gen_blocks = self.generator.n_blocks

创建测绘网络

198        self.mapping_network = MappingNetwork(self.d_latent, self.mapping_network_layers).to(self.device)

创建路径长度惩罚损失

200        self.path_length_penalty = PathLengthPenalty(0.99).to(self.device)

添加模型挂接以监视层输出

203        if self.log_layer_outputs:
204            hook_model_outputs(self.mode, self.discriminator, 'discriminator')
205            hook_model_outputs(self.mode, self.generator, 'generator')
206            hook_model_outputs(self.mode, self.mapping_network, 'mapping_network')

鉴别器和发电机损耗

209        self.discriminator_loss = DiscriminatorLoss().to(self.device)
210        self.generator_loss = GeneratorLoss().to(self.device)

创建优化器

213        self.discriminator_optimizer = torch.optim.Adam(
214            self.discriminator.parameters(),
215            lr=self.learning_rate, betas=self.adam_betas
216        )
217        self.generator_optimizer = torch.optim.Adam(
218            self.generator.parameters(),
219            lr=self.learning_rate, betas=self.adam_betas
220        )
221        self.mapping_network_optimizer = torch.optim.Adam(
222            self.mapping_network.parameters(),
223            lr=self.mapping_network_learning_rate, betas=self.adam_betas
224        )

设置跟踪器配置

227        tracker.set_image("generated", True)

样本

这是随机采样并从映射网络中获取。

有时我们还会应用样式混合,我们生成两个潜在变量并得到相应的。然后我们随机采样一个交叉点,然后应用于交叉点之前的生成器方块和之后的区块。

229    def get_w(self, batch_size: int):

混合风格

243        if torch.rand(()).item() < self.style_mixing_prob:

随机交叉点

245            cross_over_point = int(torch.rand(()).item() * self.n_gen_blocks)

样本

247            z2 = torch.randn(batch_size, self.d_latent).to(self.device)
248            z1 = torch.randn(batch_size, self.d_latent).to(self.device)

获取

250            w1 = self.mapping_network(z1)
251            w2 = self.mapping_network(z2)

开 and for 生成器块并连接

253            w1 = w1[None, :, :].expand(cross_over_point, -1, -1)
254            w2 = w2[None, :, :].expand(self.n_gen_blocks - cross_over_point, -1, -1)
255            return torch.cat((w1, w2), dim=0)

不混合

257        else:

样本

259            z = torch.randn(batch_size, self.d_latent).to(self.device)

获取

261            w = self.mapping_network(z)

为发电机组展开

263            return w[None, :, :].expand(self.n_gen_blocks, -1, -1)

产生噪音

这会为每个发电机组产生噪声

265    def get_noise(self, batch_size: int):

存储噪音的列表

272        noise = []

噪声分辨率从

274        resolution = 4

为每个发电机组生成噪声

277        for i in range(self.n_gen_blocks):

第一个方块只有一个卷积

279            if i == 0:
280                n1 = None

生成要在第一个卷积层之后添加的噪波

282            else:
283                n1 = torch.randn(batch_size, 1, resolution, resolution, device=self.device)

生成要在第二个卷积层之后添加的噪波

285            n2 = torch.randn(batch_size, 1, resolution, resolution, device=self.device)

将噪声张量添加到列表中

288            noise.append((n1, n2))

下一个区块有分辨率

291            resolution *= 2

返回噪声张量

294        return noise

生成图像

这会使用生成器生成图像

296    def generate_images(self, batch_size: int):

得到

304        w = self.get_w(batch_size)

得到噪音

306        noise = self.get_noise(batch_size)

生成图像

309        images = self.generator(w, noise)

返回图像和

312        return images, w

训练步骤

314    def step(self, idx: int):

训练鉴别器

320        with monit.section('Discriminator'):

重置渐变

322            self.discriminator_optimizer.zero_grad()

累积梯度gradient_accumulate_steps

325            for i in range(self.gradient_accumulate_steps):

更新mode 。设置是否记录激活

327                with self.mode.update(is_log_activations=(idx + 1) % self.log_generated_interval == 0):

来自生成器的样本图像

329                    generated_images, _ = self.generate_images(self.batch_size)

生成图像的鉴别器分类

331                    fake_output = self.discriminator(generated_images.detach())

从数据加载器获取真实图像

334                    real_images = next(self.loader).to(self.device)

我们需要用真实图像计算梯度以获得梯度惩罚

336                    if (idx + 1) % self.lazy_gradient_penalty_interval == 0:
337                        real_images.requires_grad_()

真实图像的鉴别器分类

339                    real_output = self.discriminator(real_images)

获得鉴别器损失

342                    real_loss, fake_loss = self.discriminator_loss(real_output, fake_output)
343                    disc_loss = real_loss + fake_loss

添加渐变惩罚

346                    if (idx + 1) % self.lazy_gradient_penalty_interval == 0:

计算并记录梯度损失

348                        gp = self.gradient_penalty(real_images, real_output)
349                        tracker.add('loss.gp', gp)

乘以系数并添加梯度惩罚

351                        disc_loss = disc_loss + 0.5 * self.gradient_penalty_coefficient * gp * self.lazy_gradient_penalty_interval

计算梯度

354                    disc_loss.backward()

日志鉴别器丢失

357                    tracker.add('loss.discriminator', disc_loss)
358
359            if (idx + 1) % self.log_generated_interval == 0:

偶尔记录鉴别器模型参数

361                tracker.add('discriminator', self.discriminator)

用于稳定的剪辑渐变

364            torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), max_norm=1.0)

采取优化器步骤

366            self.discriminator_optimizer.step()

训练发电机

369        with monit.section('Generator'):

重置渐变

371            self.generator_optimizer.zero_grad()
372            self.mapping_network_optimizer.zero_grad()

累积梯度gradient_accumulate_steps

375            for i in range(self.gradient_accumulate_steps):

来自生成器的样本图像

377                generated_images, w = self.generate_images(self.batch_size)

生成图像的鉴别器分类

379                fake_output = self.discriminator(generated_images)

获得发电机损失

382                gen_loss = self.generator_loss(fake_output)

增加路径长度惩罚

385                if idx > self.lazy_path_penalty_after and (idx + 1) % self.lazy_path_penalty_interval == 0:

计算路径长度损失

387                    plp = self.path_length_penalty(w, generated_images)

忽略如果nan

389                    if not torch.isnan(plp):
390                        tracker.add('loss.plp', plp)
391                        gen_loss = gen_loss + plp

计算梯度

394                gen_loss.backward()

日志生成器丢失

397                tracker.add('loss.generator', gen_loss)
398
399            if (idx + 1) % self.log_generated_interval == 0:

偶尔记录鉴别器模型参数

401                tracker.add('generator', self.generator)
402                tracker.add('mapping_network', self.mapping_network)

用于稳定的剪辑渐变

405            torch.nn.utils.clip_grad_norm_(self.generator.parameters(), max_norm=1.0)
406            torch.nn.utils.clip_grad_norm_(self.mapping_network.parameters(), max_norm=1.0)

采取优化器步骤

409            self.generator_optimizer.step()
410            self.mapping_network_optimizer.step()

日志生成的图像

413        if (idx + 1) % self.log_generated_interval == 0:
414            tracker.add('generated', torch.cat([generated_images[:6], real_images[:3]], dim=0))

保存模型检查点

416        if (idx + 1) % self.save_checkpoint_interval == 0:
417            experiment.save_checkpoint()

冲洗追踪器

420        tracker.save()

火车模型

422    def train(self):

循环寻回training_steps

428        for i in monit.loop(self.training_steps):

迈出训练一步

430            self.step(i)

432            if (i + 1) % self.log_generated_interval == 0:
433                tracker.new_line()

Train styleGan2

436def main():

创建实验

442    experiment.create(name='stylegan2')

创建配置对象

444    configs = Configs()

设置配置并覆盖一些

447    experiment.configs(configs, {
448        'device.cuda_device': 0,
449        'image_size': 64,
450        'log_generated_interval': 200
451    })

初始化

454    configs.init()

设置用于保存和加载的模型

456    experiment.add_pytorch_models(mapping_network=configs.mapping_network,
457                                  generator=configs.generator,
458                                  discriminator=configs.discriminator)

开始实验

461    with experiment.start():

运行训练循环

463        configs.train()

467if __name__ == '__main__':
468    main()