# 组归一化（Group Normalization）

## 公式

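### 组归一化

组归一化将每个样本的 $C$ 个通道划分为 $G$ 组，在每一组内（覆盖该组的全部通道以及其余所有维度，例如空间维度）计算均值 $\mu$ 和方差 $\sigma^2$，然后进行归一化：

$$\hat{x} = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}}$$

若启用仿射变换，再对每个通道 $c$ 应用可学习的缩放 $\gamma_c$ 和平移 $\beta_c$：$y = \gamma_c \hat{x} + \beta_c$。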

84import torch
85from torch import nn
86
87from labml_helpers.module import Module

## 组归一化层

class GroupNorm(Module):
    """
    Group Normalization layer.

    The channels of each sample are split into `groups` equal groups, and each
    group is normalized with its own mean and variance, computed over the
    group's channels and all remaining (e.g. spatial) dimensions. An optional
    learnable per-channel scale and shift is applied afterwards.
    """

    def __init__(self, groups: int, channels: int, *,
                 eps: float = 1e-5, affine: bool = True):
        """
        * `groups` is the number of groups the features are divided into
        * `channels` is the number of features in the input
        * `eps` is added to the variance for numerical stability
        * `affine` is whether to scale and shift the normalized value
        """
        super().__init__()

        assert channels % groups == 0, "Number of channels should be evenly divisible by the number of groups"
        self.groups = groups
        self.channels = channels

        self.eps = eps
        self.affine = affine

        # Learnable per-channel scale and shift, initialised to the identity transform
        if self.affine:
            self.scale = nn.Parameter(torch.ones(channels))
            self.shift = nn.Parameter(torch.zeros(channels))

    def forward(self, x: torch.Tensor):
        """
        * `x` is a tensor of shape `[batch_size, channels, *]`, where `*` denotes any
          number (possibly 0) of dimensions. For example, in an image (2D)
          convolution this would be `[batch_size, channels, height, width]`.

        Returns the normalized tensor, with the same shape as `x`.
        """
        x_shape = x.shape
        batch_size = x_shape[0]
        # The channel dimension must match the layer's configuration
        assert self.channels == x.shape[1]

        # Product of the trailing (`*`) dimensions; 1 when there are none
        spatial = x_shape[2:].numel()
        # Elements per group. An explicit size (instead of -1) also lets an empty batch reshape.
        group_size = (self.channels // self.groups) * spatial

        # Group the channels: [batch_size, groups, group_size].
        # `reshape` (unlike `view`) also accepts non-contiguous inputs.
        x = x.reshape(batch_size, self.groups, group_size)

        # Per-group mean, E[x]
        mean = x.mean(dim=[-1], keepdim=True)
        # Per-group (biased) variance, computed in two passes as E[(x - E[x])^2].
        # The single-pass form E[x^2] - E[x]^2 loses all precision when |mean| >> std
        # and can even come out negative, which turns into NaN after the sqrt.
        diff = x - mean
        var = (diff ** 2).mean(dim=[-1], keepdim=True)

        # Normalize each group
        x_norm = diff / torch.sqrt(var + self.eps)

        # Per-channel affine transform
        if self.affine:
            x_norm = x_norm.reshape(batch_size, self.channels, spatial)
            x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)

        # Restore the original shape
        return x_norm.reshape(x_shape)

def _test():
    """
    Smoke test: push a dummy image-shaped batch through a `GroupNorm` layer
    and print the tensor shape before and after.
    """
    from labml.logger import inspect

    group_norm = GroupNorm(2, 6)

    inputs = torch.zeros([2, 6, 2, 4])
    inspect(inputs.shape)

    outputs = group_norm(inputs)
    inspect(outputs.shape)


if __name__ == '__main__':
    _test()