这是 PyTorch 从纸质批量规范化中实现批量归一化:通过减少内部协变量偏移加速深度网络训练。
本文将内部协变量移位定义为训练期间由于网络参数的变化而导致的网络激活分布的变化。例如,假设有两层和。在培训开始时,可以分发输出(输入)。然后,经过一些训练步骤后,它可能会移至。这是内部协变量移位。
内部协变量偏移将对训练速度产生不利影响,因为后面的图层(在上面的例子中)必须适应这种偏移分布。
通过稳定分布,批量归一化可以最大限度地减少内部协变量偏移。
众所周知,美白可以提高训练速度和收敛性。美白是将输入进行线性变换,使其均值为零、单位方差且不相关。
使用预先计算(分离)均值和方差在梯度计算之外进行归一化不起作用。例如。(忽略方差),让 wher e an d 是一个训练过的偏差,是外部梯度计算(预先计算的常量)。
请注意,这对。因此,将在每次训练更新中增加或减少,并且会无限期地增长。该报指出,类似的爆炸会发生差异。
美白在计算上很昂贵,因为你需要去关联,而且梯度必须通过完整的美白计算。
本文介绍了一个简化版本,他们称之为批量规范化。首先简化的是,它将每个要素独立归一化,使其均值和单位方差为零:其中是维度输入。
第二种简化方法是使用来自微型批次的均值和方差的估计值进行归一化;而不是计算整个数据集的均值和方差。
将每个要素归一化为零均值和单位方差可能会影响图层可以表示的内容。作为示例论文说明,如果sigmoid的输入被归一化,则大部分将在sigmoid为线性的范围内。为了克服这个问题,每个特征都通过两个经过训练的参数进行缩放和移动。其中是批量归一化层的输出。
请注意,在线性变换之后应用批量归一化时,比如偏置参数会因归一化而被取消。因此,你可以而且应该在批量归一化之前省略线性变换中的偏置参数。
批量归一化还使反向传播与权重的比例保持不变,从经验上讲,它改善了泛化,因此它也具有正则化效果。
我们需要知道 an d 才能执行规范化。因此,在推理过程中,您要么需要遍历整个(或部分)数据集并找到均值和方差,要么可以使用训练期间计算的估计值。通常的做法是在训练阶段计算均值和方差的指数移动平均线,然后将其用于推断。
以下是训练代码和用于训练 CNN 分类器的笔记本,该分类器使用 MNIST 数据集的批量归一化。
97import torch
98from torch import nn
99
100from labml_helpers.module import Module
批量归一化层将输入归一化,如下所示:
当输入是一批图像表示时,其中是批次大小,是通道数,是高度和是宽度。和。
当输入是一批嵌入时,其中是批次大小,是要素的数量。和。
当输入是一批序列嵌入时,其中是批次大小,是要素的数量,是顺序。和。
103class BatchNorm(Module):
channels
是输入中的要素数eps
是,用于数值稳定性momentum
是取指数移动平均线的动量affine
是否缩放和移动归一化值track_running_stats
是计算移动平均线还是均值和方差我们已经尝试使用与 PyTorchBatchNorm
实现相同的参数名称。
131 def __init__(self, channels: int, *,
132 eps: float = 1e-5, momentum: float = 0.1,
133 affine: bool = True, track_running_stats: bool = True):
143 super().__init__()
144
145 self.channels = channels
146
147 self.eps = eps
148 self.momentum = momentum
149 self.affine = affine
150 self.track_running_stats = track_running_stats
为缩放和移位创建参数
152 if self.affine:
153 self.scale = nn.Parameter(torch.ones(channels))
154 self.shift = nn.Parameter(torch.zeros(channels))
创建缓冲区以存储均值和方差的指数移动平均线
157 if self.track_running_stats:
158 self.register_buffer('exp_mean', torch.zeros(channels))
159 self.register_buffer('exp_var', torch.ones(channels))
x
是形状张量[batch_size, channels, *]
。*
表示任意数量(可能为 0)的维度。例如,在图像(2D)卷积中,这将是[batch_size, channels, height, width]
161 def forward(self, x: torch.Tensor):
保持原始形状
169 x_shape = x.shape
获取批次大小
171 batch_size = x_shape[0]
进行健全性检查以确保要素数量相同
173 assert self.channels == x.shape[1]
重塑成[batch_size, channels, n]
176 x = x.view(batch_size, self.channels, -1)
如果我们处于训练模式或者没有跟踪指数移动平均线,我们将计算小批次均值和方差
180 if self.training or not self.track_running_stats:
计算第一维和最后一个维度的平均值;即每个要素的均值
183 mean = x.mean(dim=[0, 2])
计算第一维和最后一个维度的均方值;即每个要素的均值
186 mean_x2 = (x ** 2).mean(dim=[0, 2])
每个要素的方差
188 var = mean_x2 - mean ** 2
更新指数移动平均线
191 if self.training and self.track_running_stats:
192 self.exp_mean = (1 - self.momentum) * self.exp_mean + self.momentum * mean
193 self.exp_var = (1 - self.momentum) * self.exp_var + self.momentum * var
使用指数移动平均线作为估计值
195 else:
196 mean = self.exp_mean
197 var = self.exp_var
规范化
200 x_norm = (x - mean.view(1, -1, 1)) / torch.sqrt(var + self.eps).view(1, -1, 1)
缩放和移动
202 if self.affine:
203 x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)
重塑为原始形状然后返回
206 return x_norm.view(x_shape)
简单测试
209def _test():
213 from labml.logger import inspect
214
215 x = torch.zeros([2, 3, 2, 4])
216 inspect(x.shape)
217 bn = BatchNorm(3)
218
219 x = bn(x)
220 inspect(x.shape)
221 inspect(bn.exp_var.shape)
225if __name__ == '__main__':
226 _test()