实例规范化

这是 P yTorch 实现实例规范化:快速风格化的缺失成分

引入了实例规范化以改进样式传输。它基于这样的观察,即风格化不应依赖于内容图像的对比度。“对比度标准化” 是

其中,是一批具有尺寸图像索引、特征通道和空间位置的图像

由于卷积网络很难学习 “对比度归一化”,本文介绍了实例规范化来做到这一点。

以下是使用实例规范化的 CIFAR 10 分类模型

29import torch
30from torch import nn
31
32from labml_helpers.module import Module

实例规范化层

实例归一化层将输入归一化,如下所示:

当输入是一批图像表示时,其中是批次大小,是通道数,是高度和是宽度。。带和的仿射变换是可选的。

35class InstanceNorm(Module):
  • channels 是输入中的要素数
  • eps用于数值稳定性
  • affine 是否缩放和移动归一化值
51    def __init__(self, channels: int, *,
52                 eps: float = 1e-5, affine: bool = True):
58        super().__init__()
59
60        self.channels = channels
61
62        self.eps = eps
63        self.affine = affine

为缩放和移位创建参数

65        if self.affine:
66            self.scale = nn.Parameter(torch.ones(channels))
67            self.shift = nn.Parameter(torch.zeros(channels))

x 是形状张量[batch_size, channels, *]* 表示任意数量(可能为 0)的维度。例如,在图像(2D)卷积中,这将是[batch_size, channels, height, width]

69    def forward(self, x: torch.Tensor):

保持原始形状

77        x_shape = x.shape

获取批次大小

79        batch_size = x_shape[0]

进行健全性检查以确保要素数量相同

81        assert self.channels == x.shape[1]

重塑成[batch_size, channels, n]

84        x = x.view(batch_size, self.channels, -1)

计算最后一个维度的平均值,即每个要素的均值

88        mean = x.mean(dim=[-1], keepdim=True)

计算第一维和最后一个维度的均方值;即每个要素的均值

91        mean_x2 = (x ** 2).mean(dim=[-1], keepdim=True)

每个要素的方差

93        var = mean_x2 - mean ** 2

规范化

96        x_norm = (x - mean) / torch.sqrt(var + self.eps)
97        x_norm = x_norm.view(batch_size, self.channels, -1)

缩放和移动

100        if self.affine:
101            x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)

重塑为原始形状然后返回

104        return x_norm.view(x_shape)

简单测试

107def _test():
111    from labml.logger import inspect
112
113    x = torch.zeros([2, 6, 2, 4])
114    inspect(x.shape)
115    bn = InstanceNorm(6)
116
117    x = bn(x)
118    inspect(x.shape)

122if __name__ == '__main__':
123    _test()