这是 P yTorch 实现实例规范化:快速风格化的缺失成分。
引入了实例规范化以改进样式传输。它基于这样的观察,即风格化不应依赖于内容图像的对比度。“对比度标准化” 是
其中,是一批具有尺寸图像索引、特征通道和空间位置的图像。
由于卷积网络很难学习 “对比度归一化”,本文介绍了实例规范化来做到这一点。
以下是使用实例规范化的 CIFAR 10 分类模型。
29import torch
30from torch import nn
31
32from labml_helpers.module import Module
35class InstanceNorm(Module):
channels
是输入中的要素数eps
是,用于数值稳定性affine
是否缩放和移动归一化值51 def __init__(self, channels: int, *,
52 eps: float = 1e-5, affine: bool = True):
58 super().__init__()
59
60 self.channels = channels
61
62 self.eps = eps
63 self.affine = affine
为缩放和移位创建参数
65 if self.affine:
66 self.scale = nn.Parameter(torch.ones(channels))
67 self.shift = nn.Parameter(torch.zeros(channels))
x
是形状张量[batch_size, channels, *]
。*
表示任意数量(可能为 0)的维度。例如,在图像(2D)卷积中,这将是[batch_size, channels, height, width]
69 def forward(self, x: torch.Tensor):
保持原始形状
77 x_shape = x.shape
获取批次大小
79 batch_size = x_shape[0]
进行健全性检查以确保要素数量相同
81 assert self.channels == x.shape[1]
重塑成[batch_size, channels, n]
84 x = x.view(batch_size, self.channels, -1)
计算最后一个维度的平均值,即每个要素的均值
88 mean = x.mean(dim=[-1], keepdim=True)
计算第一维和最后一个维度的均方值;即每个要素的均值
91 mean_x2 = (x ** 2).mean(dim=[-1], keepdim=True)
每个要素的方差
93 var = mean_x2 - mean ** 2
规范化
96 x_norm = (x - mean) / torch.sqrt(var + self.eps)
97 x_norm = x_norm.view(batch_size, self.channels, -1)
缩放和移动
100 if self.affine:
101 x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)
重塑为原始形状然后返回
104 return x_norm.view(x_shape)
简单测试
107def _test():
111 from labml.logger import inspect
112
113 x = torch.zeros([2, 6, 2, 4])
114 inspect(x.shape)
115 bn = InstanceNorm(6)
116
117 x = bn(x)
118 inspect(x.shape)
122if __name__ == '__main__':
123 _test()