StyleGan 2

这是《分析和提高 StyleGan 的图像质量》一文的 PyTorch 实现,该论文介绍了 StyleGan 2。StyleGan 2 是对论文《生成对抗网络的基于样式的生成器架构》中对 StyleG an 的改进。StyleG an 基于论文《逐步生长 GaN 以提高质量、稳定性和变异性》中的渐进式 GAN。这三篇论文均出自 NVIDIA AI 的同一位作者。

我们的实现是一个简约的 StyleGan 2 模型训练代码。仅支持单个 GPU 训练,以保持实现简单。我们设法缩小了它,使其保持在不到 500 行代码中,包括训练循环。

🏃 这里是训练代码:experiment.py

Generated Images

这些是在训练了大约 80K 步之后生成的图像。

我们将首先对这三篇论文进行较高层次的介绍。

生成对抗网络

生成对抗网络有两个组成部分:生成器和鉴别器。生成器网络采用随机潜向量 () 并尝试生成逼真的图像。鉴别器网络试图将真实图像与生成的图像区分开来。当我们一起训练两个网络时,生成器开始生成与真实图像没有区别的图像。

渐进式 GAN

渐进式 GAN 生成大小为的高分辨率图像 ()。它通过逐步增加图像大小来做到这一点。首先,它训练一个网络,该网络生成图像,然后生成图像,依此类推,直至所需的图像分辨率。

在每种分辨率下,生成器网络都会在潜空间中生成一张图像,然后将其转换为具有卷积的 RGB。当我们从较低的分辨率发展到更高的分辨率(比如从)时,我们会缩放潜在图像并添加一个新块(两个卷积层)和一个用于获得 RGB 的新图层。通过在缩放的 RGB图像上添加残余连接,可以顺利完成过渡。这个剩余连接的重量会慢慢减轻,让新块接管。

鉴别器是发电机网络的镜像。鉴别器的渐进增长也是类似的。

progressive_gan.svg

表示要素地图分辨率的缩放和缩放。、... 表示生成器或鉴别器块处的特征图分辨率。每个鉴别器和生成器模块由2个卷积层组成,RelU激活泄漏。

他们使用 minibatch标准差来增加变异和均衡学习率,我们在下文的实现中对此进行了讨论。它们还使用逐像素归一化,其中特征向量在每个像素处进行归一化。它们将其应用于所有卷积层输出(RGB 除外)。

StyleGan

StyleGan 改进了 Progressive GAN 的生成器,使鉴别器架构保持不变。

映射网络

它将随机潜在向量 () 映射到另一个具有8层神经网络的潜在空间 () 中。这给出了一个中间的潜在空间,其中变化的因子更加线性(解开)。

aDaIN

然后将每个图层转换为两个矢量(样式并用于在每个图层中进行缩放和移动(偏置)运算符(归一化和缩放):

风格混合

为了防止生成器假设相邻样式是相关的,它们会随机对不同的块使用不同的样式。也就是说,他们对两个潜在向量进行采样,对某些块进行对应和使用基于样式,对某些块使用基于样式随机黑人。

随机变异

噪点可用于每个方块,这有助于生成器创建更逼真的图像。噪声按学习的权重按每个通道进行缩放。

双线性上下采样

所有向上和向下采样操作都伴随着双线性平滑。

style_gan.svg

表示线性层。表示广播和缩放操作(噪声是单个信道)。StyleGan 还使用渐进式 GAN 等渐进式增长。

StyleGan 2

StyleGan 2 同时更改了 StyleGan 的生成器和鉴别器。

权重调制和解调

他们将操作员移除,并用权重调制和解调步骤代替它。这应该改善他们所谓的液滴伪像,这些伪影存在于生成的图像中,这是由运算符中的归一化引起的。每个图层的样式向量是根据计算得出

然后按如下方式调制卷积权重。(这里指的是权重而不是中间的潜在空间,我们坚持使用与纸张相同的符号。)

然后通过归一化进行解调,其中是输入通道,是输出通道,是内核索引。

路径长度正则化

路径长度正则化鼓励采用固定大小的步进,从而在生成的图像中产生非零的固定幅度变化。

没有渐进式增长

StyleGan2在鉴别器中使用残差连接(带下采样),并通过上采样跳过生成器中的连接(添加了每个图层的RGB输出-特征图中没有残余连接)。他们表明,通过实验,在训练开始时,低分辨率图层的贡献更高,然后高分辨率图层接管。

148import math
149from typing import Tuple, Optional, List
150
151import numpy as np
152import torch
153import torch.nn.functional as F
154import torch.utils.data
155from torch import nn

映射网络

Mapping Network

这是一个包含 8 个线性层的 MLP。映射网络将潜在向量映射到中间潜空间空间将与图像空间分开,在图像空间中,变异因子变得更加线性。

158class MappingNetwork(nn.Module):
  • features和中的要素数量
  • n_layers 是制图网络中的层数。
173    def __init__(self, features: int, n_layers: int):
178        super().__init__()

创建 MLP

181        layers = []
182        for i in range(n_layers):
184            layers.append(EqualizedLinear(features, features))

Leaky Relu

186            layers.append(nn.LeakyReLU(negative_slope=0.2, inplace=True))
187
188        self.net = nn.Sequential(*layers)
190    def forward(self, z: torch.Tensor):

规范化

192        z = F.normalize(z, dim=1)

映射

194        return self.net(z)

StyleGan2 生成器

Generator

表示线性层。表示广播和缩放操作(噪声是单个信道)。toRGB 还有一种风格调制,为了简单起见,图中没有显示这种调制。

生成器以学习的常数开始。然后它有一系列方块。每个区块的要素图分辨率加倍。每个模块输出一个 RGB 图像,然后放大和求和以获得最终的 RGB 图像。

197class Generator(nn.Module):
  • log_resolution 是图像分辨率的
  • d_latent 是的维度
  • n_features 卷积层中分辨率最高的要素数(最终块)
  • max_features 任何发电机组中要素的最大数目
214    def __init__(self, log_resolution: int, d_latent: int, n_features: int = 32, max_features: int = 512):
221        super().__init__()

计算每个区块的要素数量

比如[512, 512, 256, 128, 64, 32]

226        features = [min(max_features, n_features * (2 ** i)) for i in range(log_resolution - 2, -1, -1)]

发电机组数量

228        self.n_blocks = len(features)

可训练常数

231        self.initial_constant = nn.Parameter(torch.randn((1, features[0], 4, 4)))

分辨率和图层的第一个样式块来获得 RGB

234        self.style_block = StyleBlock(d_latent, features[0], features[0])
235        self.to_rgb = ToRGB(d_latent, features[0])

发电机块

238        blocks = [GeneratorBlock(d_latent, features[i - 1], features[i]) for i in range(1, self.n_blocks)]
239        self.blocks = nn.ModuleList(blocks)

向上采样层。特征空间在每个区块向上采样

243        self.up_sample = UpSample()
  • w。为了混合样式(对不同的层使用不同的样式),我们为每个生成器模块提供了单独的样式。它有形状[n_blocks, batch_size, d_latent]
  • input_noise 是每个方块的噪声。这是一对噪声传感器的列表,因为每个模块(初始模块除外)在每个卷积层之后都有两个噪声输入(参见图表)。
245    def forward(self, w: torch.Tensor, input_noise: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]]):

获取批次大小

255        batch_size = w.shape[1]

展开学习的常量以匹配批次大小

258        x = self.initial_constant.expand(batch_size, -1, -1, -1)

第一个样式方块

261        x = self.style_block(x, w[0], input_noise[0][1])

获取第一张 rgb 图像

263        rgb = self.to_rgb(x, w[0])

评估其余的区块

266        for i in range(1, self.n_blocks):

向上采样要素地图

268            x = self.up_sample(x)

通过发电机组运行它

270            x, rgb_new = self.blocks[i - 1](x, w[i], input_noise[i])

向上采样 RGB 图像并从方块中添加到 rgb

272            rgb = self.up_sample(rgb) + rgb_new

返回最终的 RGB 图像

275        return rgb

发电机组

Generator block

表示线性层。表示广播和缩放操作(噪声是单个信道)。toRGB 还有一种风格调制,为了简单起见,图中没有显示这种调制。

生成器模块由两个样式块(带样式调制的卷积)和一个 RGB 输出组成。

278class GeneratorBlock(nn.Module):
  • d_latent 是的维度
  • in_features 是输入要素地图中的要素数
  • out_features 是输出要素地图中的要素数
294    def __init__(self, d_latent: int, in_features: int, out_features: int):
300        super().__init__()

第一个样式块将要素地图大小更改为out_features

303        self.style_block1 = StyleBlock(d_latent, in_features, out_features)

第二种样式方块

305        self.style_block2 = StyleBlock(d_latent, out_features, out_features)

torGB

308        self.to_rgb = ToRGB(d_latent, out_features)
  • x 是形状的输入要素地图[batch_size, in_features, height, width]
  • w有形状的[batch_size, d_latent]
  • noise 是由两个形状的噪声张量组成的元组[batch_size, 1, height, width]
310    def forward(self, x: torch.Tensor, w: torch.Tensor, noise: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]):

第一个带有第一个噪声张量的样式块。输出是形状的[batch_size, out_features, height, width]

318        x = self.style_block1(x, w, noise[0])

具有第二个噪声张量的第二样式块。输出是形状的[batch_size, out_features, height, width]

321        x = self.style_block2(x, w, noise[1])

获取 RGB 图像

324        rgb = self.to_rgb(x, w)

返回特征图和 rgb 图像

327        return x, rgb

样式方块

Style block

表示线性层。表示广播和缩放操作(噪声是单声道)。

样式块具有权重调制卷积层。

330class StyleBlock(nn.Module):
  • d_latent 是的维度
  • in_features 是输入要素地图中的要素数
  • out_features 是输出要素地图中的要素数
344    def __init__(self, d_latent: int, in_features: int, out_features: int):
350        super().__init__()
353        self.to_style = EqualizedLinear(d_latent, in_features, bias=1.0)

权重调制卷积层

355        self.conv = Conv2dWeightModulate(in_features, out_features, kernel_size=3)

噪音标度

357        self.scale_noise = nn.Parameter(torch.zeros(1))

偏见

359        self.bias = nn.Parameter(torch.zeros(out_features))

激活功能

362        self.activation = nn.LeakyReLU(0.2, True)
  • x 是形状的输入要素地图[batch_size, in_features, height, width]
  • w有形状的[batch_size, d_latent]
  • noise 是形状张量[batch_size, 1, height, width]
364    def forward(self, x: torch.Tensor, w: torch.Tensor, noise: Optional[torch.Tensor]):

获取样式矢量

371        s = self.to_style(w)

权重调制卷积

373        x = self.conv(x, s)

缩放和添加噪点

375        if noise is not None:
376            x = x + self.scale_noise[None, :, None, None] * noise

添加偏差并评估激活函数

378        return self.activation(x + self.bias[None, :, None, None])

到 RGB

To RGB

表示线性层。

使用卷积从要素地图生成 RGB 图像。

381class ToRGB(nn.Module):
  • d_latent 是的维度
  • features 是要素地图中的要素数量
394    def __init__(self, d_latent: int, features: int):
399        super().__init__()
402        self.to_style = EqualizedLinear(d_latent, features, bias=1.0)

没有解调的权重调制卷积层

405        self.conv = Conv2dWeightModulate(features, 3, kernel_size=1, demodulate=False)

偏见

407        self.bias = nn.Parameter(torch.zeros(3))

激活功能

409        self.activation = nn.LeakyReLU(0.2, True)
  • x 是形状的输入要素地图[batch_size, in_features, height, width]
  • w有形状的[batch_size, d_latent]
  • 411    def forward(self, x: torch.Tensor, w: torch.Tensor):

    获取样式矢量

    417        style = self.to_style(w)

    权重调制卷积

    419        x = self.conv(x, style)

    添加偏差并评估激活函数

    421        return self.activation(x + self.bias[None, :, None, None])

    带权重调制和解调的卷积

    该图层按样式向量缩放卷积权重,并通过归一化来进行解调。

    424class Conv2dWeightModulate(nn.Module):
    • in_features 是输入要素地图中的要素数
    • out_features 是输出要素地图中的要素数
    • kernel_size 是卷积内核的大小
    • demodulate 是标志是否根据权重的标准差归一化权重
    • eps用于规范化的
    431    def __init__(self, in_features: int, out_features: int, kernel_size: int,
    432                 demodulate: float = True, eps: float = 1e-8):
    440        super().__init__()

    输出要素的数量

    442        self.out_features = out_features

    是否规格化权重

    444        self.demodulate = demodulate

    填充大小

    446        self.padding = (kernel_size - 1) // 2
    449        self.weight = EqualizedWeight([out_features, in_features, kernel_size, kernel_size])

    451        self.eps = eps
    • x 是形状的输入要素地图[batch_size, in_features, height, width]
    • s 是基于样式的形状缩放张量[batch_size, in_features]
    453    def forward(self, x: torch.Tensor, s: torch.Tensor):

    获取批次大小、高度和宽度

    460        b, _, h, w = x.shape

    重塑天平

    463        s = s[:, None, :, None, None]
    465        weights = self.weight()[None, :, :, :, :]

    其中,是输入通道,是输出通道,是内核索引。

    结果有形状[batch_size, out_features, in_features, kernel_size, kernel_size]

    470        weights = weights * s

    解调

    473        if self.demodulate:

    475            sigma_inv = torch.rsqrt((weights ** 2).sum(dim=(2, 3, 4), keepdim=True) + self.eps)

    477            weights = weights * sigma_inv

    重塑x

    480        x = x.reshape(1, -1, h, w)

    重塑权重

    483        _, _, *ws = weights.shape
    484        weights = weights.reshape(b * self.out_features, *ws)

    使用分组卷积来使用样本明智的内核有效地计算卷积。也就是说,我们在批处理中的每个样本都有不同的内核(权重)

    488        x = F.conv2d(x, weights, padding=self.padding, groups=b)

    重塑x[batch_size, out_features, height, width] 然后返回

    491        return x.reshape(-1, self.out_features, h, w)

    StyleGan 2 鉴别器

    Discriminator

    鉴别器首先将图像转换为具有相同分辨率的特征图,然后通过一系列具有剩余连接的块进行运行。在每个区块处对分辨率进行下采样,同时将要素数量增加一倍。

    494class Discriminator(nn.Module):
    • log_resolution 是图像分辨率的
    • n_features 卷积层中分辨率最高的要素数(第一个块)
    • max_features 任何发电机组中要素的最大数目
    508    def __init__(self, log_resolution: int, n_features: int = 64, max_features: int = 512):
    514        super().__init__()

    用于将 RGB 图像转换为具有n_features 多个要素的要素地图的图层。

    517        self.from_rgb = nn.Sequential(
    518            EqualizedConv2d(3, n_features, 1),
    519            nn.LeakyReLU(0.2, True),
    520        )

    计算每个区块的要素数量。

    有点像[64, 128, 256, 512, 512, 512]

    525        features = [min(max_features, n_features * (2 ** i)) for i in range(log_resolution - 1)]

    区分器块的数量

    527        n_blocks = len(features) - 1

    鉴别器块

    529        blocks = [DiscriminatorBlock(features[i], features[i + 1]) for i in range(n_blocks)]
    530        self.blocks = nn.Sequential(*blocks)
    533        self.std_dev = MiniBatchStdDev()

    添加标准差地图后的要素数量

    535        final_features = features[-1] + 1

    最终卷积层

    537        self.conv = EqualizedConv2d(final_features, final_features, 3)

    获得分类的最终线性层

    539        self.final = EqualizedLinear(2 * 2 * final_features, 1)
    • x 是形状的输入图像[batch_size, 3, height, width]
    541    def forward(self, x: torch.Tensor):

    尝试规范化图像(这完全是可选的,但稍微加快了早期训练)

    547        x = x - 0.5

    从 RGB 进行转换

    549        x = self.from_rgb(x)

    通过鉴别器块

    551        x = self.blocks(x)

    计算并追加小批量标准差

    554        x = self.std_dev(x)

    卷积

    556        x = self.conv(x)

    压平

    558        x = x.reshape(x.shape[0], -1)

    返回分类分数

    560        return self.final(x)

    鉴别器块

    Discriminator block

    鉴别器模块由两个带有剩余连接的卷积组成。

    563class DiscriminatorBlock(nn.Module):
    • in_features 是输入要素地图中的要素数
    • out_features 是输出要素地图中的要素数
    574    def __init__(self, in_features, out_features):
    579        super().__init__()

    剩余连接的下采样和卷积层

    581        self.residual = nn.Sequential(DownSample(),
    582                                      EqualizedConv2d(in_features, out_features, kernel_size=1))

    两次卷积

    585        self.block = nn.Sequential(
    586            EqualizedConv2d(in_features, in_features, kernel_size=3, padding=1),
    587            nn.LeakyReLU(0.2, True),
    588            EqualizedConv2d(in_features, out_features, kernel_size=3, padding=1),
    589            nn.LeakyReLU(0.2, True),
    590        )

    向下采样层

    593        self.down_sample = DownSample()

    添加残差后的缩放系数

    596        self.scale = 1 / math.sqrt(2)
    598    def forward(self, x):

    获取剩余连接

    600        residual = self.residual(x)

    卷积

    603        x = self.block(x)

    向下采样

    605        x = self.down_sample(x)

    添加残差和比例

    608        return (x + residual) * self.scale

    小批量标准差

    小批量标准差计算要素映射中每个要素的小批次(或微型批次中的子组)的标准差。然后,它取所有标准差的平均值,并将其作为一项额外要素附加到要素地图中。

    611class MiniBatchStdDev(nn.Module):
    • group_size 是要计算标准差的样本数。
    623    def __init__(self, group_size: int = 4):
    627        super().__init__()
    628        self.group_size = group_size
    • x 是要素地图
    630    def forward(self, x: torch.Tensor):

    检查批次大小是否可以被组大小整除

    635        assert x.shape[0] % self.group_size == 0

    将样本分成几组group_size ,我们将特征图展平为单个维度,因为我们要计算每个要素的标准差。

    638        grouped = x.view(self.group_size, -1)

    计算group_size 样本中每个特征的标准差

    645        std = torch.sqrt(grouped.var(dim=0) + 1e-8)

    获取平均标准差

    647        std = std.mean().view(1, 1, 1, 1)

    展开要追加到要素地图的标准差

    649        b, _, h, w = x.shape
    650        std = std.expand(b, -1, h, w)

    将标准差追加(连接)到要素地图

    652        return torch.cat([x, std], dim=1)

    向下采样

    下采样操作使用双线性插值法平滑每个特征通道和缩放。这是基于论文《让卷积网络再次移位不变》。

    655class DownSample(nn.Module):
    667    def __init__(self):
    668        super().__init__()

    平滑层

    670        self.smooth = Smooth()
    672    def forward(self, x: torch.Tensor):

    平滑或模糊

    674        x = self.smooth(x)

    缩小规模

    676        return F.interpolate(x, (x.shape[2] // 2, x.shape[3] // 2), mode='bilinear', align_corners=False)

    向上采样

    上采样操作将图像向上缩放平滑每个特征通道。这是基于论文《让卷积网络再次移位不变》。

    679class UpSample(nn.Module):
    690    def __init__(self):
    691        super().__init__()

    向上采样层

    693        self.up_sample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

    平滑层

    695        self.smooth = Smooth()
    697    def forward(self, x: torch.Tensor):

    向上采样和平滑

    699        return self.smooth(self.up_sample(x))

    平滑层

    该图层模糊了每个通道

    702class Smooth(nn.Module):
    711    def __init__(self):
    712        super().__init__()

    模糊内核

    714        kernel = [[1, 2, 1],
    715                  [2, 4, 2],
    716                  [1, 2, 1]]

    将内核转换为 PyTorch 张量

    718        kernel = torch.tensor([[kernel]], dtype=torch.float)

    规范化内核

    720        kernel /= kernel.sum()

    将内核另存为固定参数(不更新渐变)

    722        self.kernel = nn.Parameter(kernel, requires_grad=False)

    填充层

    724        self.pad = nn.ReplicationPad2d(1)
    726    def forward(self, x: torch.Tensor):

    获取输入要素地图的形状

    728        b, c, h, w = x.shape

    重塑以实现平滑

    730        x = x.view(-1, 1, h, w)

    添加填充

    733        x = self.pad(x)

    使用内核平滑(模糊)

    736        x = F.conv2d(x, self.kernel)

    重塑并返回

    739        return x.view(b, c, h, w)

    学习速率均衡线性层

    这使用线性图层的学习速率均衡权重

    742class EqualizedLinear(nn.Module):
    • in_features 是输入要素地图中的要素数
    • out_features 是输出要素地图中的要素数
    • bias 是偏置初始化常数
    751    def __init__(self, in_features: int, out_features: int, bias: float = 0.):
    758        super().__init__()
    760        self.weight = EqualizedWeight([out_features, in_features])

    偏见

    762        self.bias = nn.Parameter(torch.ones(out_features) * bias)
    764    def forward(self, x: torch.Tensor):

    线性变换

    766        return F.linear(x, self.weight(), bias=self.bias)

    学习速率均衡的 2D 卷积层

    这使用卷积层的学习速率均衡权重

    769class EqualizedConv2d(nn.Module):
    • in_features 是输入要素地图中的要素数
    • out_features 是输出要素地图中的要素数
    • kernel_size 是卷积内核的大小
    • padding 是要在每个尺寸维度的两边添加的内边距
    778    def __init__(self, in_features: int, out_features: int,
    779                 kernel_size: int, padding: int = 0):
    786        super().__init__()

    填充大小

    788        self.padding = padding
    790        self.weight = EqualizedWeight([out_features, in_features, kernel_size, kernel_size])

    偏见

    792        self.bias = nn.Parameter(torch.ones(out_features))
    794    def forward(self, x: torch.Tensor):

    卷积

    796        return F.conv2d(x, self.weight(), bias=self.bias, padding=self.padding)

    学习速率均衡权重参数

    这是基于 Progressive GAN 论文中介绍的均衡学习率。它们不是在初始化权重,而是将权重初始化为,然后在使用时将其乘以。

    存储参数的梯度会被乘以,但这不会产生影响,因为像 Adam 这样的优化器将它们归一化为梯度的平方。

    上的优化器更新与学习速率成正比。但是有效权重会按比例更新。如果没有均衡的学习率,有效权重将按比例更新为 just

    因此,我们正在有效地缩放这些权重参数的学习速率。

    799class EqualizedWeight(nn.Module):
    • shape 是权重参数的形状
    820    def __init__(self, shape: List[int]):
    824        super().__init__()

    他初始化常量

    827        self.c = 1 / math.sqrt(np.prod(shape[1:]))

    使用初始化权重

    829        self.weight = nn.Parameter(torch.randn(shape))

    权重乘法系数

    832    def forward(self):

    将权重乘以并返回

    834        return self.weight * self.c

    梯度惩罚

    这是论文《哪种针对 GAN 的训练方法实际上会收敛?》中的正则化惩罚

    也就是说,对于真实图像,我们尝试减少鉴别器相对于图像的梯度的 L2 范数 ()。

    837class GradientPenalty(nn.Module):
    • x
  • d
  • 853    def forward(self, x: torch.Tensor, d: torch.Tensor):

    获取批次大小

    860        batch_size = x.shape[0]

    计算相对于的梯度grad_outputs 设置为,因为我们想要梯度,并且我们需要创建和保留图形,因为我们必须计算相对于此损失的权重的梯度。

    866        gradients, *_ = torch.autograd.grad(outputs=d,
    867                                            inputs=x,
    868                                            grad_outputs=d.new_ones(d.shape),
    869                                            create_graph=True)

    重塑梯度以计算范数

    872        gradients = gradients.reshape(batch_size, -1)

    计算常数

    874        norm = gradients.norm(2, dim=-1)

    退还损失

    876        return torch.mean(norm ** 2)

    路径长度惩罚

    这种正则化鼓励采用固定大小的步进,从而导致图像中的固定幅度变化。

    其中是 Jacobian从测绘网络中采样的,并且是有噪点的图像

    是训练进行时的指数移动平均线。

    计算时未使用显式计算雅可比式

    879class PathLengthPenalty(nn.Module):
    • beta用于计算指数移动平均线的常数
    903    def __init__(self, beta: float):
    907        super().__init__()

    910        self.beta = beta

    计算的步数

    912        self.steps = nn.Parameter(torch.tensor(0.), requires_grad=False)

    训练第-步,其中是它的值的指数和

    916        self.exp_sum_a = nn.Parameter(torch.tensor(0.), requires_grad=False)
    • w 是形状的批次[batch_size, d_latent]
    • x 是生成的形状图像[batch_size, 3, height, width]
    918    def forward(self, w: torch.Tensor, x: torch.Tensor):

    拿到设备

    925        device = x.device

    获取像素数

    927        image_size = x.shape[2] * x.shape[3]

    计算

    929        y = torch.randn(x.shape, device=device)

    按图像大小的平方根进行计算和归一化。这是本文中未提及的缩放,但已在实施中提及。

    933        output = (x * y).sum() / math.sqrt(image_size)

    计算梯度以获取

    936        gradients, *_ = torch.autograd.grad(outputs=output,
    937                                            inputs=w,
    938                                            grad_outputs=torch.ones(output.shape, device=device),
    939                                            create_graph=True)

    计算 L2 范数

    942        norm = (gradients ** 2).sum(dim=2).mean(dim=1).sqrt()

    第一步后正规化

    945        if self.steps > 0:

    计算

    948            a = self.exp_sum_a / (1 - self.beta ** self.steps)

    计算罚款

    952            loss = torch.mean((norm - a) ** 2)
    953        else:

    如果我们无法计算,则返回虚拟损失

    955            loss = norm.new_tensor(0)

    计算的均值

    958        mean = norm.mean().detach()

    更新指数和

    960        self.exp_sum_a.mul_(self.beta).add_(mean, alpha=1 - self.beta)

    增量

    962        self.steps.add_(1.)

    退还罚款

    965        return loss