这是 PyTorch 实现的批处理通道标准化,来自论文《使用批处理通道标准化和权重标准化进行微批量训练》。我们还有一个带注释的重量标准化实现方案。
批处理通道标准化先执行批量标准化,然后进行信道标准化(类似于组标准化)。当批次大小很小时,使用运行均值和方差进行批量标准化。
以下是训练 VGG 网络的训练代码,该网络使用权重标准化对 CIFAR-10 数据进行分类。
25import torch
26from torch import nn
27
28from labml_helpers.module import Module
29from labml_nn.normalization.batch_norm import BatchNorm
32class BatchChannelNorm(Module):
channels
是输入中的要素数groups
是要素被划分到的组的数量eps
是,用于数值稳定性momentum
是取指数移动平均线的动量estimate
是否使用运行均值和方差作为批次范数42 def __init__(self, channels: int, groups: int,
43 eps: float = 1e-5, momentum: float = 0.1, estimate: bool = True):
51 super().__init__()
使用估计的批次规范或普通批次规范。
54 if estimate:
55 self.batch_norm = EstimatedBatchNorm(channels,
56 eps=eps, momentum=momentum)
57 else:
58 self.batch_norm = BatchNorm(channels,
59 eps=eps, momentum=momentum)
信道规范化
62 self.channel_norm = ChannelNorm(channels, groups, eps)
64 def forward(self, x):
65 x = self.batch_norm(x)
66 return self.channel_norm(x)
69class EstimatedBatchNorm(Module):
channels
是输入中的要素数eps
是,用于数值稳定性momentum
是取指数移动平均线的动量estimate
是否使用运行均值和方差作为批次范数90 def __init__(self, channels: int,
91 eps: float = 1e-5, momentum: float = 0.1, affine: bool = True):
98 super().__init__()
99
100 self.eps = eps
101 self.momentum = momentum
102 self.affine = affine
103 self.channels = channels
频道变换参数
106 if self.affine:
107 self.scale = nn.Parameter(torch.ones(channels))
108 self.shift = nn.Parameter(torch.zeros(channels))
和的张量
111 self.register_buffer('exp_mean', torch.zeros(channels))
112 self.register_buffer('exp_var', torch.ones(channels))
x
是形状张量[batch_size, channels, *]
。*
表示任意数量(可能为 0)的维度。例如,在图像(2D)卷积中,这将是[batch_size, channels, height, width]
114 def forward(self, x: torch.Tensor):
保持旧的形状
122 x_shape = x.shape
获取批次大小
124 batch_size = x_shape[0]
进行健全性检查以确保要素数量正确
127 assert self.channels == x.shape[1]
重塑成[batch_size, channels, n]
130 x = x.view(batch_size, self.channels, -1)
更新且仅在训练模式下
133 if self.training:
没有通过和的反向传播
135 with torch.no_grad():
计算第一维和最后一个维度的平均值;
138 mean = x.mean(dim=[0, 2])
计算第一维和最后一个维度的均方值;
141 mean_x2 = (x ** 2).mean(dim=[0, 2])
每个要素的方差
144 var = mean_x2 - mean ** 2
152 self.exp_mean = (1 - self.momentum) * self.exp_mean + self.momentum * mean
153 self.exp_var = (1 - self.momentum) * self.exp_var + self.momentum * var
规范化
157 x_norm = (x - self.exp_mean.view(1, -1, 1)) / torch.sqrt(self.exp_var + self.eps).view(1, -1, 1)
缩放和移动
162 if self.affine:
163 x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)
重塑为原始形状然后返回
166 return x_norm.view(x_shape)
groups
是要素被划分到的组的数量channels
是输入中的要素数eps
是,用于数值稳定性affine
是否缩放和移动归一化值176 def __init__(self, channels, groups,
177 eps: float = 1e-5, affine: bool = True):
184 super().__init__()
185 self.channels = channels
186 self.groups = groups
187 self.eps = eps
188 self.affine = affine
193 if self.affine:
194 self.scale = nn.Parameter(torch.ones(groups))
195 self.shift = nn.Parameter(torch.zeros(groups))
x
是形状张量[batch_size, channels, *]
。*
表示任意数量(可能为 0)的维度。例如,在图像(2D)卷积中,这将是[batch_size, channels, height, width]
197 def forward(self, x: torch.Tensor):
保持原始形状
206 x_shape = x.shape
获取批次大小
208 batch_size = x_shape[0]
进行健全性检查以确保要素数量相同
210 assert self.channels == x.shape[1]
重塑成[batch_size, groups, n]
213 x = x.view(batch_size, self.groups, -1)
计算最后一个维度的均值;即每个样本和通道组的均值
217 mean = x.mean(dim=[-1], keepdim=True)
计算最后一个维度的均方值;即每个样本和通道组的均值
220 mean_x2 = (x ** 2).mean(dim=[-1], keepdim=True)
每个样本和特征组的方差
223 var = mean_x2 - mean ** 2
规范化
228 x_norm = (x - mean) / torch.sqrt(var + self.eps)
按组缩放和移动
232 if self.affine:
233 x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)
重塑为原始形状然后返回
236 return x_norm.view(x_shape)