层规范化

这是层规范化PyTorch 实现。

批量标准化的局限性

  • 你需要保持跑步手段。
  • 对于 RNN 来说很棘手。每个步骤都需要不同的规范化吗?
  • 不适用于小批量;大型 NLP 模型通常使用小批量进行训练。
  • 需要在分布式训练中计算设备间的均值和方差。

层规范化

图层归一化是一种更简单的归一化方法,适用于更广泛的设置。图层归一化会将输入变换为各要素的均值和单位方差为零。请注意,批量归一化修复了每个元素的零均值和单位方差。层归一化对所有元素的每个批次执行此操作。

层归一化通常用于 NLP 任务。

我们在大多数变压器实现中都使用了层归一化。

35from typing import Union, List
36
37import torch
38from torch import nn, Size
39
40from labml_helpers.module import Module

层规范化

图层归一化将输入归一化,如下所示:

当输入是一批嵌入时,其中是批次大小,是要素的数量。

当 input 是嵌入序列中的一批时,其中是批次大小,是通道数,是顺序。

当输入是一批图像表示时,其中是批次大小,是通道数,是高度和是宽度。这不是一个广泛使用的场景。

43class LayerNorm(Module):
  • normalized_shape 是元素的形状(批次除外)。那么输入应该是
  • eps用于数值稳定性
  • elementwise_affine 是否缩放和移动归一化值

我们已经尝试使用与 PyTorchLayerNorm 实现相同的参数名称。

72    def __init__(self, normalized_shape: Union[int, List[int], Size], *,
73                 eps: float = 1e-5,
74                 elementwise_affine: bool = True):
84        super().__init__()

转换normalized_shapetorch.Size

87        if isinstance(normalized_shape, int):
88            normalized_shape = torch.Size([normalized_shape])
89        elif isinstance(normalized_shape, list):
90            normalized_shape = torch.Size(normalized_shape)
91        assert isinstance(normalized_shape, torch.Size)

94        self.normalized_shape = normalized_shape
95        self.eps = eps
96        self.elementwise_affine = elementwise_affine

为增益偏置创建参数和参数

98        if self.elementwise_affine:
99            self.gain = nn.Parameter(torch.ones(normalized_shape))
100            self.bias = nn.Parameter(torch.zeros(normalized_shape))

x 是形状张量[*, S[0], S[1], ..., S[n]]* 可以是任意数量的维度。例如,在 NLP 任务中,这将是[seq_len, batch_size, features]

102    def forward(self, x: torch.Tensor):

进行健全性检查以确保形状匹配

110        assert self.normalized_shape == x.shape[-len(self.normalized_shape):]

用于计算均值和方差的维度

113        dims = [-(i + 1) for i in range(len(self.normalized_shape))]

计算所有元素的均值;即每个元素的均值

117        mean = x.mean(dim=dims, keepdim=True)

计算所有元素的均方值;即每个元素的均值

120        mean_x2 = (x ** 2).mean(dim=dims, keepdim=True)

所有元素的方差

122        var = mean_x2 - mean ** 2

规范化

125        x_norm = (x - mean) / torch.sqrt(var + self.eps)

缩放和移动

127        if self.elementwise_affine:
128            x_norm = self.gain * x_norm + self.bias

131        return x_norm

简单测试

134def _test():
138    from labml.logger import inspect
139
140    x = torch.zeros([2, 3, 2, 4])
141    inspect(x.shape)
142    ln = LayerNorm(x.shape[2:])
143
144    x = ln(x)
145    inspect(x.shape)
146    inspect(ln.gain.shape)

150if __name__ == '__main__':
151    _test()