# 批量归一化

## 推断

推断时同样需要均值和方差来进行归一化。常见做法是在训练阶段计算均值和方差的指数移动平均，并在推断时使用这些估计值，而不是使用当前批量的统计量。

97import torch
98from torch import nn
99
100from labml_helpers.module import Module

## 批量归一化层

class BatchNorm(Module):
    """
    Batch normalization layer.

    Each channel is normalized with the mean and variance computed over the batch
    dimension and all trailing dimensions (e.g. height and width). In training mode,
    exponential moving averages of these statistics can be tracked; they are used
    instead of the batch statistics when the module is in eval mode.
    """

    def __init__(self, channels: int, *,
                 eps: float = 1e-5, momentum: float = 0.1,
                 affine: bool = True, track_running_stats: bool = True):
        """
        * `channels` is the number of features (channels) in the input
        * `eps` is added to the variance for numerical stability
        * `momentum` is the momentum used for the exponential moving averages
        * `affine` is whether to scale and shift the normalized value
        * `track_running_stats` is whether to track exponential moving averages of the
          mean and variance (used in eval mode), or always use the batch statistics
        """
        super().__init__()

        self.channels = channels

        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats

        # Learnable per-channel scale and shift applied after normalization
        if self.affine:
            self.scale = nn.Parameter(torch.ones(channels))
            self.shift = nn.Parameter(torch.zeros(channels))

        # Per-channel moving averages of mean and variance, registered as buffers
        # (module state, not trainable parameters)
        if self.track_running_stats:
            self.register_buffer('exp_mean', torch.zeros(channels))
            self.register_buffer('exp_var', torch.ones(channels))

    def forward(self, x: torch.Tensor):
        """
        `x` is a tensor of shape `[batch_size, channels, *]`, where `*` denotes any
        number (possibly 0) of further dimensions. For example, for image (2D)
        convolutions it is `[batch_size, channels, height, width]`.

        Returns the normalized tensor, with the same shape as `x`.
        """
        # Keep the original shape so it can be restored at the end
        x_shape = x.shape
        batch_size = x_shape[0]
        # Sanity check that the number of channels matches
        assert self.channels == x.shape[1], \
            f"expected {self.channels} channels, got {x.shape[1]}"

        # Flatten trailing dimensions into one: [batch_size, channels, n].
        # `reshape` (unlike `view`) also accepts non-contiguous inputs,
        # e.g. transposed or channels-last tensors.
        x = x.reshape(batch_size, self.channels, -1)

        if self.training or not self.track_running_stats:
            # Per-channel mean over the batch and the flattened trailing dimensions
            mean = x.mean(dim=[0, 2])
            # Per-channel (biased) variance computed from centered values.
            # E[x^2] - E[x]^2 suffers catastrophic cancellation when |mean| >> std
            # and can even become negative, which would make sqrt() return NaN.
            var = ((x - mean.view(1, -1, 1)) ** 2).mean(dim=[0, 2])

            # Update the moving averages in place and outside autograd: otherwise the
            # buffers would be replaced by graph-connected tensors, each chaining to the
            # previous batch's graph and accumulating across training steps.
            if self.training and self.track_running_stats:
                with torch.no_grad():
                    self.exp_mean.mul_(1 - self.momentum).add_(self.momentum * mean)
                    self.exp_var.mul_(1 - self.momentum).add_(self.momentum * var)
        else:
            # Eval mode with tracked statistics: use the moving averages
            mean = self.exp_mean
            var = self.exp_var

        # Normalize to zero mean and unit variance per channel
        x_norm = (x - mean.view(1, -1, 1)) / torch.sqrt(var + self.eps).view(1, -1, 1)

        # Optional learnable per-channel affine transform
        if self.affine:
            x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)

        # Restore the original shape
        return x_norm.reshape(x_shape)

def _test():
    """
    Smoke test: run a `BatchNorm` over a small 4D tensor and print the shapes involved.
    """
    from labml.logger import inspect

    # Batch of 2, 3 channels, 2x4 trailing dimensions
    inputs = torch.zeros(2, 3, 2, 4)
    inspect(inputs.shape)

    layer = BatchNorm(3)
    outputs = layer(inputs)
    inspect(outputs.shape)

    # Shape of the tracked moving-average variance buffer
    inspect(layer.exp_var.shape)


if __name__ == '__main__':
    _test()