这是《分析和提高 StyleGan 的图像质量》一文的 PyTorch 实现,该论文介绍了 StyleGan 2。StyleGan 2 是对论文《生成对抗网络的基于样式的生成器架构》中对 StyleG an 的改进。StyleG an 基于论文《逐步生长 GaN 以提高质量、稳定性和变异性》中的渐进式 GAN。这三篇论文均出自 NVIDIA AI 的同一位作者。
我们的实现是一个简约的 StyleGan 2 模型训练代码。仅支持单个 GPU 训练,以保持实现简单。我们设法缩小了它,使其保持在不到 500 行代码中,包括训练循环。
🏃 这里是训练代码:experiment.py
。
这些是在训练了大约 80K 步之后生成的图像。
我们将首先对这三篇论文进行较高层次的介绍。
生成对抗网络有两个组成部分:生成器和鉴别器。生成器网络采用随机潜向量 () 并尝试生成逼真的图像。鉴别器网络试图将真实图像与生成的图像区分开来。当我们一起训练两个网络时,生成器开始生成与真实图像没有区别的图像。
渐进式 GAN 生成大小为的高分辨率图像 ()。它通过逐步增加图像大小来做到这一点。首先,它训练一个网络,该网络生成图像,然后生成图像,依此类推,直至所需的图像分辨率。
在每种分辨率下,生成器网络都会在潜空间中生成一张图像,然后将其转换为具有卷积的 RGB。当我们从较低的分辨率发展到更高的分辨率(比如从到)时,我们会缩放潜在图像并添加一个新块(两个卷积层)和一个用于获得 RGB 的新图层。通过在缩放的 RGB图像上添加残余连接,可以顺利完成过渡。这个剩余连接的重量会慢慢减轻,让新块接管。
鉴别器是发电机网络的镜像。鉴别器的渐进增长也是类似的。
和表示要素地图分辨率的缩放和缩放。、、... 表示生成器或鉴别器块处的特征图分辨率。每个鉴别器和生成器模块由2个卷积层组成,RelU激活泄漏。
他们使用 minibatch标准差来增加变异和均衡学习率,我们在下文的实现中对此进行了讨论。它们还使用逐像素归一化,其中特征向量在每个像素处进行归一化。它们将其应用于所有卷积层输出(RGB 除外)。
StyleGan 改进了 Progressive GAN 的生成器,使鉴别器架构保持不变。
它将随机潜在向量 () 映射到另一个具有8层神经网络的潜在空间 () 中。这给出了一个中间的潜在空间,其中变化的因子更加线性(解开)。
然后将每个图层转换为两个矢量(样式),并用于在每个图层中进行缩放和移动(偏置)运算符(归一化和缩放):
为了防止生成器假设相邻样式是相关的,它们会随机对不同的块使用不同的样式。也就是说,他们对两个潜在向量进行采样,对某些块进行对应和使用基于样式,对某些块使用基于样式随机黑人。
噪点可用于每个方块,这有助于生成器创建更逼真的图像。噪声按学习的权重按每个通道进行缩放。
所有向上和向下采样操作都伴随着双线性平滑。
表示线性层。表示广播和缩放操作(噪声是单个信道)。StyleGan 还使用渐进式 GAN 等渐进式增长。
StyleGan 2 同时更改了 StyleGan 的生成器和鉴别器。
他们将操作员移除,并用权重调制和解调步骤代替它。这应该改善他们所谓的液滴伪像,这些伪影存在于生成的图像中,这是由运算符中的归一化引起的。每个图层的样式向量是根据计算得出的。
然后按如下方式调制卷积权重。(这里指的是权重而不是中间的潜在空间,我们坚持使用与纸张相同的符号。)
然后通过归一化进行解调,其中是输入通道,是输出通道,是内核索引。
路径长度正则化鼓励采用固定大小的步进,从而在生成的图像中产生非零的固定幅度变化。
StyleGan2在鉴别器中使用残差连接(带下采样),并通过上采样跳过生成器中的连接(添加了每个图层的RGB输出-特征图中没有残余连接)。他们表明,通过实验,在训练开始时,低分辨率图层的贡献更高,然后高分辨率图层接管。
148import math
149from typing import Tuple, Optional, List
150
151import numpy as np
152import torch
153import torch.nn.functional as F
154import torch.utils.data
155from torch import nn
158class MappingNetwork(nn.Module):
features
是和中的要素数量n_layers
是制图网络中的层数。173 def __init__(self, features: int, n_layers: int):
178 super().__init__()
创建 MLP
181 layers = []
182 for i in range(n_layers):
Leaky Relu
186 layers.append(nn.LeakyReLU(negative_slope=0.2, inplace=True))
187
188 self.net = nn.Sequential(*layers)
190 def forward(self, z: torch.Tensor):
规范化
192 z = F.normalize(z, dim=1)
映射到
194 return self.net(z)
表示线性层。表示广播和缩放操作(噪声是单个信道)。toRGB
还有一种风格调制,为了简单起见,图中没有显示这种调制。
生成器以学习的常数开始。然后它有一系列方块。每个区块的要素图分辨率加倍。每个模块输出一个 RGB 图像,然后放大和求和以获得最终的 RGB 图像。
197class Generator(nn.Module):
log_resolution
是图像分辨率的d_latent
是的维度n_features
卷积层中分辨率最高的要素数(最终块)max_features
任何发电机组中要素的最大数目214 def __init__(self, log_resolution: int, d_latent: int, n_features: int = 32, max_features: int = 512):
221 super().__init__()
226 features = [min(max_features, n_features * (2 ** i)) for i in range(log_resolution - 2, -1, -1)]
发电机组数量
228 self.n_blocks = len(features)
可训练常数
231 self.initial_constant = nn.Parameter(torch.randn((1, features[0], 4, 4)))
分辨率和图层的第一个样式块来获得 RGB
234 self.style_block = StyleBlock(d_latent, features[0], features[0])
235 self.to_rgb = ToRGB(d_latent, features[0])
发电机块
238 blocks = [GeneratorBlock(d_latent, features[i - 1], features[i]) for i in range(1, self.n_blocks)]
239 self.blocks = nn.ModuleList(blocks)
向上采样层。特征空间在每个区块向上采样
243 self.up_sample = UpSample()
w
是。为了混合样式(对不同的层使用不同的样式),我们为每个生成器模块提供了单独的样式。它有形状[n_blocks, batch_size, d_latent]
。input_noise
是每个方块的噪声。这是一对噪声传感器的列表,因为每个模块(初始模块除外)在每个卷积层之后都有两个噪声输入(参见图表)。245 def forward(self, w: torch.Tensor, input_noise: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]]):
获取批次大小
255 batch_size = w.shape[1]
展开学习的常量以匹配批次大小
258 x = self.initial_constant.expand(batch_size, -1, -1, -1)
第一个样式方块
261 x = self.style_block(x, w[0], input_noise[0][1])
获取第一张 rgb 图像
263 rgb = self.to_rgb(x, w[0])
评估其余的区块
266 for i in range(1, self.n_blocks):
向上采样要素地图
268 x = self.up_sample(x)
向上采样 RGB 图像并从方块中添加到 rgb
272 rgb = self.up_sample(rgb) + rgb_new
返回最终的 RGB 图像
275 return rgb
278class GeneratorBlock(nn.Module):
d_latent
是的维度in_features
是输入要素地图中的要素数out_features
是输出要素地图中的要素数294 def __init__(self, d_latent: int, in_features: int, out_features: int):
300 super().__init__()
303 self.style_block1 = StyleBlock(d_latent, in_features, out_features)
torGB 层
308 self.to_rgb = ToRGB(d_latent, out_features)
x
是形状的输入要素地图[batch_size, in_features, height, width]
w
是有形状的[batch_size, d_latent]
noise
是由两个形状的噪声张量组成的元组[batch_size, 1, height, width]
310 def forward(self, x: torch.Tensor, w: torch.Tensor, noise: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]):
第一个带有第一个噪声张量的样式块。输出是形状的[batch_size, out_features, height, width]
318 x = self.style_block1(x, w, noise[0])
具有第二个噪声张量的第二样式块。输出是形状的[batch_size, out_features, height, width]
321 x = self.style_block2(x, w, noise[1])
获取 RGB 图像
324 rgb = self.to_rgb(x, w)
返回特征图和 rgb 图像
327 return x, rgb
330class StyleBlock(nn.Module):
d_latent
是的维度in_features
是输入要素地图中的要素数out_features
是输出要素地图中的要素数344 def __init__(self, d_latent: int, in_features: int, out_features: int):
350 super().__init__()
353 self.to_style = EqualizedLinear(d_latent, in_features, bias=1.0)
权重调制卷积层
355 self.conv = Conv2dWeightModulate(in_features, out_features, kernel_size=3)
噪音标度
357 self.scale_noise = nn.Parameter(torch.zeros(1))
偏见
359 self.bias = nn.Parameter(torch.zeros(out_features))
激活功能
362 self.activation = nn.LeakyReLU(0.2, True)
x
是形状的输入要素地图[batch_size, in_features, height, width]
w
是有形状的[batch_size, d_latent]
noise
是形状张量[batch_size, 1, height, width]
364 def forward(self, x: torch.Tensor, w: torch.Tensor, noise: Optional[torch.Tensor]):
获取样式矢量
371 s = self.to_style(w)
权重调制卷积
373 x = self.conv(x, s)
缩放和添加噪点
375 if noise is not None:
376 x = x + self.scale_noise[None, :, None, None] * noise
添加偏差并评估激活函数
378 return self.activation(x + self.bias[None, :, None, None])
381class ToRGB(nn.Module):
d_latent
是的维度features
是要素地图中的要素数量394 def __init__(self, d_latent: int, features: int):
399 super().__init__()
402 self.to_style = EqualizedLinear(d_latent, features, bias=1.0)
没有解调的权重调制卷积层
405 self.conv = Conv2dWeightModulate(features, 3, kernel_size=1, demodulate=False)
偏见
407 self.bias = nn.Parameter(torch.zeros(3))
激活功能
409 self.activation = nn.LeakyReLU(0.2, True)
411 def forward(self, x: torch.Tensor, w: torch.Tensor):
获取样式矢量
417 style = self.to_style(w)
权重调制卷积
419 x = self.conv(x, style)
添加偏差并评估激活函数
421 return self.activation(x + self.bias[None, :, None, None])
424class Conv2dWeightModulate(nn.Module):
in_features
是输入要素地图中的要素数out_features
是输出要素地图中的要素数kernel_size
是卷积内核的大小demodulate
是标志是否根据权重的标准差归一化权重eps
是用于规范化的431 def __init__(self, in_features: int, out_features: int, kernel_size: int,
432 demodulate: float = True, eps: float = 1e-8):
440 super().__init__()
输出要素的数量
442 self.out_features = out_features
是否规格化权重
444 self.demodulate = demodulate
填充大小
446 self.padding = (kernel_size - 1) // 2
449 self.weight = EqualizedWeight([out_features, in_features, kernel_size, kernel_size])
451 self.eps = eps
x
是形状的输入要素地图[batch_size, in_features, height, width]
s
是基于样式的形状缩放张量[batch_size, in_features]
453 def forward(self, x: torch.Tensor, s: torch.Tensor):
获取批次大小、高度和宽度
460 b, _, h, w = x.shape
重塑天平
463 s = s[:, None, :, None, None]
470 weights = weights * s
解调
473 if self.demodulate:
475 sigma_inv = torch.rsqrt((weights ** 2).sum(dim=(2, 3, 4), keepdim=True) + self.eps)
477 weights = weights * sigma_inv
重塑x
480 x = x.reshape(1, -1, h, w)
重塑权重
483 _, _, *ws = weights.shape
484 weights = weights.reshape(b * self.out_features, *ws)
使用分组卷积来使用样本明智的内核有效地计算卷积。也就是说,我们在批处理中的每个样本都有不同的内核(权重)
488 x = F.conv2d(x, weights, padding=self.padding, groups=b)
重塑x
为[batch_size, out_features, height, width]
然后返回
491 return x.reshape(-1, self.out_features, h, w)
494class Discriminator(nn.Module):
log_resolution
是图像分辨率的n_features
卷积层中分辨率最高的要素数(第一个块)max_features
任何发电机组中要素的最大数目508 def __init__(self, log_resolution: int, n_features: int = 64, max_features: int = 512):
514 super().__init__()
用于将 RGB 图像转换为具有n_features
多个要素的要素地图的图层。
517 self.from_rgb = nn.Sequential(
518 EqualizedConv2d(3, n_features, 1),
519 nn.LeakyReLU(0.2, True),
520 )
525 features = [min(max_features, n_features * (2 ** i)) for i in range(log_resolution - 1)]
鉴别器块
529 blocks = [DiscriminatorBlock(features[i], features[i + 1]) for i in range(n_blocks)]
530 self.blocks = nn.Sequential(*blocks)
添加标准差地图后的要素数量
535 final_features = features[-1] + 1
最终卷积层
537 self.conv = EqualizedConv2d(final_features, final_features, 3)
获得分类的最终线性层
539 self.final = EqualizedLinear(2 * 2 * final_features, 1)
x
是形状的输入图像[batch_size, 3, height, width]
541 def forward(self, x: torch.Tensor):
尝试规范化图像(这完全是可选的,但稍微加快了早期训练)
547 x = x - 0.5
从 RGB 进行转换
549 x = self.from_rgb(x)
卷积
556 x = self.conv(x)
压平
558 x = x.reshape(x.shape[0], -1)
返回分类分数
560 return self.final(x)
563class DiscriminatorBlock(nn.Module):
in_features
是输入要素地图中的要素数out_features
是输出要素地图中的要素数574 def __init__(self, in_features, out_features):
579 super().__init__()
剩余连接的下采样和卷积层
581 self.residual = nn.Sequential(DownSample(),
582 EqualizedConv2d(in_features, out_features, kernel_size=1))
两次卷积
585 self.block = nn.Sequential(
586 EqualizedConv2d(in_features, in_features, kernel_size=3, padding=1),
587 nn.LeakyReLU(0.2, True),
588 EqualizedConv2d(in_features, out_features, kernel_size=3, padding=1),
589 nn.LeakyReLU(0.2, True),
590 )
向下采样层
593 self.down_sample = DownSample()
添加残差后的缩放系数
596 self.scale = 1 / math.sqrt(2)
598 def forward(self, x):
获取剩余连接
600 residual = self.residual(x)
卷积
603 x = self.block(x)
向下采样
605 x = self.down_sample(x)
添加残差和比例
608 return (x + residual) * self.scale
611class MiniBatchStdDev(nn.Module):
group_size
是要计算标准差的样本数。623 def __init__(self, group_size: int = 4):
627 super().__init__()
628 self.group_size = group_size
x
是要素地图630 def forward(self, x: torch.Tensor):
检查批次大小是否可以被组大小整除
635 assert x.shape[0] % self.group_size == 0
将样本分成几组group_size
,我们将特征图展平为单个维度,因为我们要计算每个要素的标准差。
638 grouped = x.view(self.group_size, -1)
645 std = torch.sqrt(grouped.var(dim=0) + 1e-8)
获取平均标准差
647 std = std.mean().view(1, 1, 1, 1)
展开要追加到要素地图的标准差
649 b, _, h, w = x.shape
650 std = std.expand(b, -1, h, w)
将标准差追加(连接)到要素地图
652 return torch.cat([x, std], dim=1)
655class DownSample(nn.Module):
667 def __init__(self):
668 super().__init__()
平滑层
670 self.smooth = Smooth()
672 def forward(self, x: torch.Tensor):
平滑或模糊
674 x = self.smooth(x)
缩小规模
676 return F.interpolate(x, (x.shape[2] // 2, x.shape[3] // 2), mode='bilinear', align_corners=False)
679class UpSample(nn.Module):
690 def __init__(self):
691 super().__init__()
向上采样层
693 self.up_sample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
平滑层
695 self.smooth = Smooth()
697 def forward(self, x: torch.Tensor):
向上采样和平滑
699 return self.smooth(self.up_sample(x))
702class Smooth(nn.Module):
711 def __init__(self):
712 super().__init__()
模糊内核
714 kernel = [[1, 2, 1],
715 [2, 4, 2],
716 [1, 2, 1]]
将内核转换为 PyTorch 张量
718 kernel = torch.tensor([[kernel]], dtype=torch.float)
规范化内核
720 kernel /= kernel.sum()
将内核另存为固定参数(不更新渐变)
722 self.kernel = nn.Parameter(kernel, requires_grad=False)
填充层
724 self.pad = nn.ReplicationPad2d(1)
726 def forward(self, x: torch.Tensor):
获取输入要素地图的形状
728 b, c, h, w = x.shape
重塑以实现平滑
730 x = x.view(-1, 1, h, w)
添加填充
733 x = self.pad(x)
使用内核平滑(模糊)
736 x = F.conv2d(x, self.kernel)
重塑并返回
739 return x.view(b, c, h, w)
in_features
是输入要素地图中的要素数out_features
是输出要素地图中的要素数bias
是偏置初始化常数751 def __init__(self, in_features: int, out_features: int, bias: float = 0.):
758 super().__init__()
偏见
762 self.bias = nn.Parameter(torch.ones(out_features) * bias)
764 def forward(self, x: torch.Tensor):
线性变换
766 return F.linear(x, self.weight(), bias=self.bias)
in_features
是输入要素地图中的要素数out_features
是输出要素地图中的要素数kernel_size
是卷积内核的大小padding
是要在每个尺寸维度的两边添加的内边距778 def __init__(self, in_features: int, out_features: int,
779 kernel_size: int, padding: int = 0):
786 super().__init__()
填充大小
788 self.padding = padding
偏见
792 self.bias = nn.Parameter(torch.ones(out_features))
794 def forward(self, x: torch.Tensor):
卷积
796 return F.conv2d(x, self.weight(), bias=self.bias, padding=self.padding)
这是基于 Progressive GAN 论文中介绍的均衡学习率。它们不是在初始化权重,而是将权重初始化为,然后在使用时将其乘以。
存储参数的梯度会被乘以,但这不会产生影响,因为像 Adam 这样的优化器将它们归一化为梯度的平方。
上的优化器更新与学习速率成正比。但是有效权重会按比例更新。如果没有均衡的学习率,有效权重将按比例更新为 just。
因此,我们正在有效地缩放这些权重参数的学习速率。
799class EqualizedWeight(nn.Module):
shape
是权重参数的形状820 def __init__(self, shape: List[int]):
824 super().__init__()
他初始化常量
827 self.c = 1 / math.sqrt(np.prod(shape[1:]))
使用初始化权重
829 self.weight = nn.Parameter(torch.randn(shape))
权重乘法系数
832 def forward(self):
将权重乘以并返回
834 return self.weight * self.c
837class GradientPenalty(nn.Module):
853 def forward(self, x: torch.Tensor, d: torch.Tensor):
获取批次大小
860 batch_size = x.shape[0]
计算相对于的梯度。grad_outputs
设置为,因为我们想要梯度,并且我们需要创建和保留图形,因为我们必须计算相对于此损失的权重的梯度。
866 gradients, *_ = torch.autograd.grad(outputs=d,
867 inputs=x,
868 grad_outputs=d.new_ones(d.shape),
869 create_graph=True)
重塑梯度以计算范数
872 gradients = gradients.reshape(batch_size, -1)
计算常数
874 norm = gradients.norm(2, dim=-1)
退还损失
876 return torch.mean(norm ** 2)
这种正则化鼓励采用固定大小的步进,从而导致图像中的固定幅度变化。
其中是 Jacobian,是从测绘网络中采样的,并且是有噪点的图像。
是训练进行时的指数移动平均线。
计算时未使用显式计算雅可比式
879class PathLengthPenalty(nn.Module):
beta
是用于计算指数移动平均线的常数903 def __init__(self, beta: float):
907 super().__init__()
910 self.beta = beta
计算的步数
912 self.steps = nn.Parameter(torch.tensor(0.), requires_grad=False)
训练第-步,其中是它的值的指数和
916 self.exp_sum_a = nn.Parameter(torch.tensor(0.), requires_grad=False)
w
是形状的批次[batch_size, d_latent]
x
是生成的形状图像[batch_size, 3, height, width]
918 def forward(self, w: torch.Tensor, x: torch.Tensor):
拿到设备
925 device = x.device
获取像素数
927 image_size = x.shape[2] * x.shape[3]
计算
929 y = torch.randn(x.shape, device=device)
计算梯度以获取
936 gradients, *_ = torch.autograd.grad(outputs=output,
937 inputs=w,
938 grad_outputs=torch.ones(output.shape, device=device),
939 create_graph=True)
计算 L2 范数
942 norm = (gradients ** 2).sum(dim=2).mean(dim=1).sqrt()
第一步后正规化
945 if self.steps > 0:
计算
948 a = self.exp_sum_a / (1 - self.beta ** self.steps)
计算罚款
952 loss = torch.mean((norm - a) ** 2)
953 else:
如果我们无法计算,则返回虚拟损失
955 loss = norm.new_tensor(0)
计算的均值
958 mean = norm.mean().detach()
更新指数和
960 self.exp_sum_a.mul_(self.beta).add_(mean, alpha=1 - self.beta)
增量
962 self.steps.add_(1.)
退还罚款
965 return loss