批处理信道标准化

这是 PyTorch 实现的批处理通道标准化,来自论文《使用批处理通道标准化和权重标准化进行微批量训练》。我们还有一个带注释的重量标准化实现方案

批处理通道标准化先执行批量标准化,然后进行信道标准化(类似于组标准化)。当批次大小很小时,使用运行均值和方差进行批量标准化。

以下是训练 VGG 网络的训练代码,该网络使用权重标准化对 CIFAR-10 数据进行分类。

Open In Colab

25import torch
26from torch import nn
27
28from labml_helpers.module import Module
29from labml_nn.normalization.batch_norm import BatchNorm

批量信道规范化

这首先执行批次归一化——正态批次范数或具有估计均值和方差(多个批次的指数均值/方差)的批次范数。然后执行了信道标准化。

32class BatchChannelNorm(Module):
  • channels 是输入中的要素数
  • groups 是要素被划分到的组的数量
  • eps用于数值稳定性
  • momentum 是取指数移动平均线的动量
  • estimate 是否使用运行均值和方差作为批次范数
42    def __init__(self, channels: int, groups: int,
43                 eps: float = 1e-5, momentum: float = 0.1, estimate: bool = True):
51        super().__init__()

使用估计的批次规范或普通批次规范。

54        if estimate:
55            self.batch_norm = EstimatedBatchNorm(channels,
56                                                 eps=eps, momentum=momentum)
57        else:
58            self.batch_norm = BatchNorm(channels,
59                                        eps=eps, momentum=momentum)

信道规范化

62        self.channel_norm = ChannelNorm(channels, groups, eps)
64    def forward(self, x):
65        x = self.batch_norm(x)
66        return self.channel_norm(x)

预计批次规范化

当输入是一批图像表示时,其中是批次大小,是通道数,是高度和是宽度。

在哪里,

是运行均值和方差。是计算指数均值的动量。

69class EstimatedBatchNorm(Module):
  • channels 是输入中的要素数
  • eps用于数值稳定性
  • momentum 是取指数移动平均线的动量
  • estimate 是否使用运行均值和方差作为批次范数
90    def __init__(self, channels: int,
91                 eps: float = 1e-5, momentum: float = 0.1, affine: bool = True):
98        super().__init__()
99
100        self.eps = eps
101        self.momentum = momentum
102        self.affine = affine
103        self.channels = channels

频道变换参数

106        if self.affine:
107            self.scale = nn.Parameter(torch.ones(channels))
108            self.shift = nn.Parameter(torch.zeros(channels))

和的张量

111        self.register_buffer('exp_mean', torch.zeros(channels))
112        self.register_buffer('exp_var', torch.ones(channels))

x 是形状张量[batch_size, channels, *]* 表示任意数量(可能为 0)的维度。例如,在图像(2D)卷积中,这将是[batch_size, channels, height, width]

114    def forward(self, x: torch.Tensor):

保持旧的形状

122        x_shape = x.shape

获取批次大小

124        batch_size = x_shape[0]

进行健全性检查以确保要素数量正确

127        assert self.channels == x.shape[1]

重塑成[batch_size, channels, n]

130        x = x.view(batch_size, self.channels, -1)

更新且仅在训练模式下

133        if self.training:

没有通过和的反向传播

135            with torch.no_grad():

计算第一维和最后一个维度的平均值;

138                mean = x.mean(dim=[0, 2])

计算第一维和最后一个维度的均方值;

141                mean_x2 = (x ** 2).mean(dim=[0, 2])

每个要素的方差

144                var = mean_x2 - mean ** 2

更新指数移动平均线

152                self.exp_mean = (1 - self.momentum) * self.exp_mean + self.momentum * mean
153                self.exp_var = (1 - self.momentum) * self.exp_var + self.momentum * var

规范化

157        x_norm = (x - self.exp_mean.view(1, -1, 1)) / torch.sqrt(self.exp_var + self.eps).view(1, -1, 1)

缩放和移动

162        if self.affine:
163            x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)

重塑为原始形状然后返回

166        return x_norm.view(x_shape)

频道规范化

这与组归一化类似,但仿射变换是按组进行的。

169class ChannelNorm(Module):
  • groups 是要素被划分到的组的数量
  • channels 是输入中的要素数
  • eps用于数值稳定性
  • affine 是否缩放和移动归一化值
176    def __init__(self, channels, groups,
177                 eps: float = 1e-5, affine: bool = True):
184        super().__init__()
185        self.channels = channels
186        self.groups = groups
187        self.eps = eps
188        self.affine = affine

仿射变换的参数。

请注意,这些变换是按组进行的,这与组规范不同,它们是按通道变换的。

193        if self.affine:
194            self.scale = nn.Parameter(torch.ones(groups))
195            self.shift = nn.Parameter(torch.zeros(groups))

x 是形状张量[batch_size, channels, *]* 表示任意数量(可能为 0)的维度。例如,在图像(2D)卷积中,这将是[batch_size, channels, height, width]

197    def forward(self, x: torch.Tensor):

保持原始形状

206        x_shape = x.shape

获取批次大小

208        batch_size = x_shape[0]

进行健全性检查以确保要素数量相同

210        assert self.channels == x.shape[1]

重塑成[batch_size, groups, n]

213        x = x.view(batch_size, self.groups, -1)

计算最后一个维度的均值;即每个样本和通道组的均值

217        mean = x.mean(dim=[-1], keepdim=True)

计算最后一个维度的均方值;即每个样本和通道组的均值

220        mean_x2 = (x ** 2).mean(dim=[-1], keepdim=True)

每个样本和特征组的方差

223        var = mean_x2 - mean ** 2

规范化

228        x_norm = (x - mean) / torch.sqrt(var + self.eps)

按组缩放和移动

232        if self.affine:
233            x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)

重塑为原始形状然后返回

236        return x_norm.view(x_shape)