批量标准化适用于足够大的批量大小,但对于小批量来说却不太好,因为它会对批次进行标准化。由于设备的内存容量,无法训练批量较大的大型模型。
本文介绍了群组归一化,它将一组特征归一化为一个组。这是基于这样的观察,即诸如 SIFT 和 HO G之类的经典特征是按组划分的特征。该论文建议将特征信道分成组,然后分别对每个组内的所有信道进行标准化。
所有标准化层都可以通过以下计算来定义。
其中,是代表批次的张量,是单个值的索引。例如,当它是 2D 图像时,它是一个四维矢量,用于在批处理、特征通道、垂直坐标和水平坐标内对图像进行索引。并且是均值和标准差。
是一组指数,通过该指数计算指数的均值和标准差。是集合的大小,对所有人来说都是一样的。
共享相同功能通道的值将一起标准化。
批次中同一样本的值一起标准化。
来自相同样本和相同特征通道的值被归一化在一起。
其中,是群组数量,是频道数。
分组归一化将相同样本和同一组通道的值一起归一化。
这是使用实例标准化的 CIFAR 10 分类模型。
84import torch
85from torch import nn
86
87from labml_helpers.module import Module
90class GroupNorm(Module):
groups
是要素被划分到的组的数量channels
是输入中的要素数eps
是,用于数值稳定性affine
是否缩放和移动归一化值95 def __init__(self, groups: int, channels: int, *,
96 eps: float = 1e-5, affine: bool = True):
103 super().__init__()
104
105 assert channels % groups == 0, "Number of channels should be evenly divisible by the number of groups"
106 self.groups = groups
107 self.channels = channels
108
109 self.eps = eps
110 self.affine = affine
为缩放和移位创建参数
112 if self.affine:
113 self.scale = nn.Parameter(torch.ones(channels))
114 self.shift = nn.Parameter(torch.zeros(channels))
x
是形状张量[batch_size, channels, *]
。*
表示任意数量(可能为 0)的维度。例如,在图像(2D)卷积中,这将是[batch_size, channels, height, width]
116 def forward(self, x: torch.Tensor):
保持原始形状
124 x_shape = x.shape
获取批次大小
126 batch_size = x_shape[0]
进行健全性检查以确保要素数量相同
128 assert self.channels == x.shape[1]
重塑成[batch_size, groups, n]
131 x = x.view(batch_size, self.groups, -1)
计算最后一个维度的均值;即每个样本和通道组的均值
135 mean = x.mean(dim=[-1], keepdim=True)
计算最后一个维度的均方值;即每个样本和通道组的均值
138 mean_x2 = (x ** 2).mean(dim=[-1], keepdim=True)
每个样本和特征组的方差
141 var = mean_x2 - mean ** 2
规范化
146 x_norm = (x - mean) / torch.sqrt(var + self.eps)
按通道缩放和移动
150 if self.affine:
151 x_norm = x_norm.view(batch_size, self.channels, -1)
152 x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)
重塑为原始形状然后返回
155 return x_norm.view(x_shape)
简单测试
158def _test():
162 from labml.logger import inspect
163
164 x = torch.zeros([2, 6, 2, 4])
165 inspect(x.shape)
166 bn = GroupNorm(2, 6)
167
168 x = bn(x)
169 inspect(x.shape)
173if __name__ == '__main__':
174 _test()