群组标准化

这是 PyTorch群组标准化论文的实现。

批量标准化适用于足够大的批量大小,但对于小批量来说却不太好,因为它会对批次进行标准化。由于设备的内存容量,无法训练批量较大的大型模型。

本文介绍了群组归一化,它将一组特征归一化为一个组。这是基于这样的观察,即诸如 SIFTHO G之类的经典特征是按组划分的特征。该论文建议将特征信道分成组,然后分别对每个组内的所有信道进行标准化。

配方

所有标准化层都可以通过以下计算来定义。

其中,是代表批次的张量,是单个值的索引。例如,当它是 2D 图像时,它是一个四维矢量,用于在批处理、特征通道、垂直坐标和水平坐标内对图像进行索引。并且是均值和标准差。

是一组指数,通过该指数计算指数的均值和标准差是集合的大小,对所有人来说都是一样的

的定义对于批量标准化图层标准化和实例标准化是不同的。

批量标准化

共享相同功能通道的值将一起标准化。

图层标准化

批次中同一样本的值一起标准化。

实例标准化

来自相同样本和相同特征通道的值被归一化在一起。

群组标准化

其中,是群组数量,是频道数。

分组归一化将相同样本和同一组通道的值一起归一化。

这是使用实例标准化的 CIFAR 10 分类模型

Open In Colab

84import torch
85from torch import nn
86
87from labml_helpers.module import Module

组归一化层

90class GroupNorm(Module):
  • groups 是要素被划分到的组的数量
  • channels 是输入中的要素数
  • eps用于数值稳定性
  • affine 是否缩放和移动归一化值
95    def __init__(self, groups: int, channels: int, *,
96                 eps: float = 1e-5, affine: bool = True):
103        super().__init__()
104
105        assert channels % groups == 0, "Number of channels should be evenly divisible by the number of groups"
106        self.groups = groups
107        self.channels = channels
108
109        self.eps = eps
110        self.affine = affine

为缩放和移位创建参数

112        if self.affine:
113            self.scale = nn.Parameter(torch.ones(channels))
114            self.shift = nn.Parameter(torch.zeros(channels))

x 是形状张量[batch_size, channels, *]* 表示任意数量(可能为 0)的维度。例如,在图像(2D)卷积中,这将是[batch_size, channels, height, width]

116    def forward(self, x: torch.Tensor):

保持原始形状

124        x_shape = x.shape

获取批次大小

126        batch_size = x_shape[0]

进行健全性检查以确保要素数量相同

128        assert self.channels == x.shape[1]

重塑成[batch_size, groups, n]

131        x = x.view(batch_size, self.groups, -1)

计算最后一个维度的均值;即每个样本和通道组的均值

135        mean = x.mean(dim=[-1], keepdim=True)

计算最后一个维度的均方值;即每个样本和通道组的均值

138        mean_x2 = (x ** 2).mean(dim=[-1], keepdim=True)

每个样本和特征组的方差

141        var = mean_x2 - mean ** 2

规范化

146        x_norm = (x - mean) / torch.sqrt(var + self.eps)

按通道缩放和移动

150        if self.affine:
151            x_norm = x_norm.view(batch_size, self.channels, -1)
152            x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)

重塑为原始形状然后返回

155        return x_norm.view(x_shape)

简单测试

158def _test():
162    from labml.logger import inspect
163
164    x = torch.zeros([2, 6, 2, 4])
165    inspect(x.shape)
166    bn = GroupNorm(2, 6)
167
168    x = bn(x)
169    inspect(x.shape)

173if __name__ == '__main__':
174    _test()