这是 StyleGan 2 模型的训练代码。
这些是在训练了大约 80K 步之后生成的图像。
我们的实现是一个简约的 StyleGan 2 模型训练代码。仅支持单个 GPU 训练,以保持实现简单。我们设法缩小了它,使其保持在不到 500 行代码中,包括训练循环。
如果没有 DDP(分布式数据并行)和多 GPU 训练,将无法为大分辨率(128+)训练模型。如果你想用 fp16 和 DDP 训练代码,可以看看 l ucidrains/stylegan2-pytorch。
我们在 Celeba-HQ 数据集上训练了这个。你可以在这篇关于 fast.ai 的讨论中找到下载说明。将图像保存在data/stylegan
文件夹中。
31import math
32from pathlib import Path
33from typing import Iterator, Tuple
34
35import torch
36import torch.utils.data
37import torchvision
38from PIL import Image
39
40from labml import tracker, lab, monit, experiment
41from labml.configs import BaseConfigs
42from labml_helpers.device import DeviceConfigs
43from labml_helpers.train_valid import ModeState, hook_model_outputs
44from labml_nn.gan.stylegan import Discriminator, Generator, MappingNetwork, GradientPenalty, PathLengthPenalty
45from labml_nn.gan.wasserstein import DiscriminatorLoss, GeneratorLoss
46from labml_nn.utils import cycle_dataloader
49class Dataset(torch.utils.data.Dataset):
path
包含图像的文件夹的路径image_size
图像的大小56 def __init__(self, path: str, image_size: int):
61 super().__init__()
获取所有jpg
文件的路径
64 self.paths = [p for p in Path(path).glob(f'**/*.jpg')]
转型
67 self.transform = torchvision.transforms.Compose([
调整图像大小
69 torchvision.transforms.Resize(image_size),
转换为 pyTorch 张量
71 torchvision.transforms.ToTensor(),
72 ])
图像数量
74 def __len__(self):
76 return len(self.paths)
获取第index
-th 张图片
78 def __getitem__(self, index):
80 path = self.paths[index]
81 img = Image.open(path)
82 return self.transform(img)
85class Configs(BaseConfigs):
用于训练模型的设备。DeviceConfigs
选择可用的 CUDA 设备或默认为 CPU。
93 device: torch.device = DeviceConfigs()
96 discriminator: Discriminator
98 generator: Generator
鉴别器和发生器损耗函数。我们使用 Wasserstein 的损失
104 discriminator_loss: DiscriminatorLoss
105 generator_loss: GeneratorLoss
优化器
108 generator_optimizer: torch.optim.Adam
109 discriminator_optimizer: torch.optim.Adam
110 mapping_network_optimizer: torch.optim.Adam
梯度惩罚系数
115 gradient_penalty_coefficient: float = 10.
数据加载器
121 loader: Iterator
批量大小
124 batch_size: int = 32
和的维度
126 d_latent: int = 512
图像的高度/宽度
128 image_size: int = 32
制图网络中的图层数
130 mapping_network_layers: int = 8
生成器和鉴别器学习速率
132 learning_rate: float = 1e-3
映射网络学习率(低于其他)
134 mapping_network_learning_rate: float = 1e-5
累积梯度的步数。使用它可以增加有效批次大小。
136 gradient_accumulate_steps: int = 1
对于 Adam 优化器来说
138 adam_betas: Tuple[float, float] = (0.0, 0.99)
混合样式的概率
140 style_mixing_prob: float = 0.9
训练步数总数
143 training_steps: int = 150_000
生成器中的块数(根据图像分辨率计算)
146 n_gen_blocks: int
计算梯度惩罚的间隔
154 lazy_gradient_penalty_interval: int = 4
路径长度惩罚计算间隔
156 lazy_path_penalty_interval: int = 32
在训练的初始阶段跳过计算路径长度损失
158 lazy_path_penalty_after: int = 5_000
记录生成的图像的频率
161 log_generated_interval: int = 500
保存模型检查点的频率
163 save_checkpoint_interval: int = 2_000
日志记录激活的训练模式状态
166 mode: ModeState
是否记录模型层输出
168 log_layer_outputs: bool = False
我们在 Celeba-HQ 数据集上训练了这个。你可以在这篇关于 fast.ai 的讨论中找到下载说明。将图像保存在data/stylegan
文件夹中。
175 dataset_path: str = str(lab.get_data_path() / 'stylegan2')
177 def init(self):
创建数据集
182 dataset = Dataset(self.dataset_path, self.image_size)
创建数据加载器
184 dataloader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, num_workers=8,
185 shuffle=True, drop_last=True, pin_memory=True)
的图像分辨率
190 log_resolution = int(math.log2(self.image_size))
创建鉴别器和生成器
193 self.discriminator = Discriminator(log_resolution).to(self.device)
194 self.generator = Generator(log_resolution, self.d_latent).to(self.device)
获取用于创建样式和噪声输入的生成器模块的数量
196 self.n_gen_blocks = self.generator.n_blocks
创建测绘网络
198 self.mapping_network = MappingNetwork(self.d_latent, self.mapping_network_layers).to(self.device)
创建路径长度惩罚损失
200 self.path_length_penalty = PathLengthPenalty(0.99).to(self.device)
添加模型挂接以监视层输出
203 if self.log_layer_outputs:
204 hook_model_outputs(self.mode, self.discriminator, 'discriminator')
205 hook_model_outputs(self.mode, self.generator, 'generator')
206 hook_model_outputs(self.mode, self.mapping_network, 'mapping_network')
鉴别器和发电机损耗
209 self.discriminator_loss = DiscriminatorLoss().to(self.device)
210 self.generator_loss = GeneratorLoss().to(self.device)
创建优化器
213 self.discriminator_optimizer = torch.optim.Adam(
214 self.discriminator.parameters(),
215 lr=self.learning_rate, betas=self.adam_betas
216 )
217 self.generator_optimizer = torch.optim.Adam(
218 self.generator.parameters(),
219 lr=self.learning_rate, betas=self.adam_betas
220 )
221 self.mapping_network_optimizer = torch.optim.Adam(
222 self.mapping_network.parameters(),
223 lr=self.mapping_network_learning_rate, betas=self.adam_betas
224 )
设置跟踪器配置
227 tracker.set_image("generated", True)
229 def get_w(self, batch_size: int):
混合风格
243 if torch.rand(()).item() < self.style_mixing_prob:
随机交叉点
245 cross_over_point = int(torch.rand(()).item() * self.n_gen_blocks)
样本和
247 z2 = torch.randn(batch_size, self.d_latent).to(self.device)
248 z1 = torch.randn(batch_size, self.d_latent).to(self.device)
获取和
250 w1 = self.mapping_network(z1)
251 w2 = self.mapping_network(z2)
展开 and for 生成器块并连接
253 w1 = w1[None, :, :].expand(cross_over_point, -1, -1)
254 w2 = w2[None, :, :].expand(self.n_gen_blocks - cross_over_point, -1, -1)
255 return torch.cat((w1, w2), dim=0)
不混合
257 else:
样本和
259 z = torch.randn(batch_size, self.d_latent).to(self.device)
获取和
261 w = self.mapping_network(z)
为发电机组展开
263 return w[None, :, :].expand(self.n_gen_blocks, -1, -1)
存储噪音的列表
272 noise = []
噪声分辨率从
274 resolution = 4
为每个发电机组生成噪声
277 for i in range(self.n_gen_blocks):
第一个方块只有一个卷积
279 if i == 0:
280 n1 = None
生成要在第一个卷积层之后添加的噪波
282 else:
283 n1 = torch.randn(batch_size, 1, resolution, resolution, device=self.device)
生成要在第二个卷积层之后添加的噪波
285 n2 = torch.randn(batch_size, 1, resolution, resolution, device=self.device)
将噪声张量添加到列表中
288 noise.append((n1, n2))
下一个区块有分辨率
291 resolution *= 2
返回噪声张量
294 return noise
296 def generate_images(self, batch_size: int):
得到
304 w = self.get_w(batch_size)
得到噪音
306 noise = self.get_noise(batch_size)
生成图像
309 images = self.generator(w, noise)
返回图像和
312 return images, w
314 def step(self, idx: int):
训练鉴别器
320 with monit.section('Discriminator'):
重置渐变
322 self.discriminator_optimizer.zero_grad()
累积梯度gradient_accumulate_steps
325 for i in range(self.gradient_accumulate_steps):
更新mode
。设置是否记录激活
327 with self.mode.update(is_log_activations=(idx + 1) % self.log_generated_interval == 0):
来自生成器的样本图像
329 generated_images, _ = self.generate_images(self.batch_size)
生成图像的鉴别器分类
331 fake_output = self.discriminator(generated_images.detach())
从数据加载器获取真实图像
334 real_images = next(self.loader).to(self.device)
我们需要用真实图像计算梯度以获得梯度惩罚
336 if (idx + 1) % self.lazy_gradient_penalty_interval == 0:
337 real_images.requires_grad_()
真实图像的鉴别器分类
339 real_output = self.discriminator(real_images)
获得鉴别器损失
342 real_loss, fake_loss = self.discriminator_loss(real_output, fake_output)
343 disc_loss = real_loss + fake_loss
添加渐变惩罚
346 if (idx + 1) % self.lazy_gradient_penalty_interval == 0:
计算并记录梯度损失
348 gp = self.gradient_penalty(real_images, real_output)
349 tracker.add('loss.gp', gp)
乘以系数并添加梯度惩罚
351 disc_loss = disc_loss + 0.5 * self.gradient_penalty_coefficient * gp * self.lazy_gradient_penalty_interval
计算梯度
354 disc_loss.backward()
日志鉴别器丢失
357 tracker.add('loss.discriminator', disc_loss)
358
359 if (idx + 1) % self.log_generated_interval == 0:
偶尔记录鉴别器模型参数
361 tracker.add('discriminator', self.discriminator)
用于稳定的剪辑渐变
364 torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), max_norm=1.0)
采取优化器步骤
366 self.discriminator_optimizer.step()
训练发电机
369 with monit.section('Generator'):
重置渐变
371 self.generator_optimizer.zero_grad()
372 self.mapping_network_optimizer.zero_grad()
累积梯度gradient_accumulate_steps
375 for i in range(self.gradient_accumulate_steps):
来自生成器的样本图像
377 generated_images, w = self.generate_images(self.batch_size)
生成图像的鉴别器分类
379 fake_output = self.discriminator(generated_images)
获得发电机损失
382 gen_loss = self.generator_loss(fake_output)
增加路径长度惩罚
385 if idx > self.lazy_path_penalty_after and (idx + 1) % self.lazy_path_penalty_interval == 0:
计算路径长度损失
387 plp = self.path_length_penalty(w, generated_images)
忽略如果nan
389 if not torch.isnan(plp):
390 tracker.add('loss.plp', plp)
391 gen_loss = gen_loss + plp
计算梯度
394 gen_loss.backward()
日志生成器丢失
397 tracker.add('loss.generator', gen_loss)
398
399 if (idx + 1) % self.log_generated_interval == 0:
偶尔记录鉴别器模型参数
401 tracker.add('generator', self.generator)
402 tracker.add('mapping_network', self.mapping_network)
用于稳定的剪辑渐变
405 torch.nn.utils.clip_grad_norm_(self.generator.parameters(), max_norm=1.0)
406 torch.nn.utils.clip_grad_norm_(self.mapping_network.parameters(), max_norm=1.0)
采取优化器步骤
409 self.generator_optimizer.step()
410 self.mapping_network_optimizer.step()
日志生成的图像
413 if (idx + 1) % self.log_generated_interval == 0:
414 tracker.add('generated', torch.cat([generated_images[:6], real_images[:3]], dim=0))
保存模型检查点
416 if (idx + 1) % self.save_checkpoint_interval == 0:
417 experiment.save_checkpoint()
冲洗追踪器
420 tracker.save()
422 def train(self):
循环寻回training_steps
428 for i in monit.loop(self.training_steps):
迈出训练一步
430 self.step(i)
432 if (i + 1) % self.log_generated_interval == 0:
433 tracker.new_line()
436def main():
创建实验
442 experiment.create(name='stylegan2')
创建配置对象
444 configs = Configs()
设置配置并覆盖一些
447 experiment.configs(configs, {
448 'device.cuda_device': 0,
449 'image_size': 64,
450 'log_generated_interval': 200
451 })
初始化
454 configs.init()
设置用于保存和加载的模型
456 experiment.add_pytorch_models(mapping_network=configs.mapping_network,
457 generator=configs.generator,
458 discriminator=configs.discriminator)
开始实验
461 with experiment.start():
运行训练循环
463 configs.train()
467if __name__ == '__main__':
468 main()