# 实例归一化（Instance Normalization）

import torch
from torch import nn

from labml_helpers.module import Module

## 实例归一化层

class InstanceNorm(Module):
    """Instance normalization layer.

    Every channel of every sample is normalized independently: for an input
    of shape ``[batch_size, channels, *]`` the mean and the (biased) variance
    are computed over all trailing ``*`` dimensions of each
    ``(sample, channel)`` pair.

    When ``affine`` is true the layer owns two learnable per-channel
    parameters, ``scale`` (initialized to ones) and ``shift`` (initialized to
    zeros), which are applied after normalization. When ``affine`` is false
    neither attribute is created.
    """

    def __init__(self, channels: int, *,
                 eps: float = 1e-5, affine: bool = True):
        """Create the layer.

        Args:
            channels: Number of features (channels) in the input.
            eps: Value added to the variance for numerical stability.
            affine: Whether to scale and shift the normalized values.
        """
        super().__init__()

        self.channels = channels
        self.eps = eps
        self.affine = affine

        if self.affine:
            self.scale = nn.Parameter(torch.ones(channels))
            self.shift = nn.Parameter(torch.zeros(channels))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` per sample and per channel.

        Args:
            x: Tensor of shape ``[batch_size, channels, *]``, where ``*``
                denotes any number of (possibly zero) extra dimensions. For
                example, for an image (2D) convolution this is
                ``[batch_size, channels, height, width]``.

        Returns:
            A tensor with the same shape as ``x``.

        Raises:
            AssertionError: If ``x.shape[1]`` differs from ``self.channels``.
        """
        # Remember the original shape so it can be restored at the end.
        x_shape = x.shape
        batch_size = x_shape[0]

        # Sanity check that the number of features matches.
        assert self.channels == x_shape[1], \
            f'expected {self.channels} channels, got {x_shape[1]}'

        # Flatten all trailing dimensions into one. `reshape` is used rather
        # than `view` so that non-contiguous inputs (e.g. transposed tensors)
        # are accepted too.
        x = x.reshape(batch_size, self.channels, -1)

        # Mean over the flattened dimension, one value per (sample, channel).
        mean = x.mean(dim=-1, keepdim=True)

        # Biased variance in its centered form E[(x - mean)^2]. The algebraic
        # shortcut E[x^2] - E[x]^2 cancels catastrophically in floating point
        # and can turn negative, which would make the square root below NaN.
        centered = x - mean
        var = (centered ** 2).mean(dim=-1, keepdim=True)

        # Normalize.
        x_norm = centered / torch.sqrt(var + self.eps)

        # Per-channel scale and shift.
        if self.affine:
            x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)

        # Restore the original shape.
        return x_norm.reshape(x_shape)

def _test():
    """Smoke test: run a zero tensor through the layer and log its shape.

    The shape is logged before and after the forward pass, so a mismatch
    between the two is visible in the output.
    """
    # Imported locally so the logging helper is only needed when this file
    # is run as a script.
    from labml.logger import inspect

    x = torch.zeros([2, 6, 2, 4])
    inspect(x.shape)
    instance_norm = InstanceNorm(6)

    x = instance_norm(x)
    inspect(x.shape)
# Run the smoke test only when this file is executed as a script.
if __name__ == '__main__':
    _test()