图层归一化是一种更简单的归一化方法,适用于更广泛的设置。图层归一化会将输入变换为各要素的均值和单位方差为零。请注意,批量归一化修复了每个元素的零均值和单位方差。层归一化对所有元素的每个批次执行此操作。
层归一化通常用于 NLP 任务。
我们在大多数变压器实现中都使用了层归一化。
35from typing import Union, List
36
37import torch
38from torch import nn, Size
39
40from labml_helpers.module import Module
图层归一化将输入归一化,如下所示:
当输入是一批嵌入时,其中是批次大小,是要素的数量。和。
当 input 是嵌入序列中的一批时,其中是批次大小,是通道数,是顺序。和。
当输入是一批图像表示时,其中是批次大小,是通道数,是高度和是宽度。这不是一个广泛使用的场景。和。
43class LayerNorm(Module):
normalized_shape
是元素的形状(批次除外)。那么输入应该是eps
是,用于数值稳定性elementwise_affine
是否缩放和移动归一化值我们已经尝试使用与 PyTorchLayerNorm
实现相同的参数名称。
72 def __init__(self, normalized_shape: Union[int, List[int], Size], *,
73 eps: float = 1e-5,
74 elementwise_affine: bool = True):
84 super().__init__()
转换normalized_shape
为torch.Size
87 if isinstance(normalized_shape, int):
88 normalized_shape = torch.Size([normalized_shape])
89 elif isinstance(normalized_shape, list):
90 normalized_shape = torch.Size(normalized_shape)
91 assert isinstance(normalized_shape, torch.Size)
94 self.normalized_shape = normalized_shape
95 self.eps = eps
96 self.elementwise_affine = elementwise_affine
为增益和偏置创建参数和参数
98 if self.elementwise_affine:
99 self.gain = nn.Parameter(torch.ones(normalized_shape))
100 self.bias = nn.Parameter(torch.zeros(normalized_shape))
x
是形状张量[*, S[0], S[1], ..., S[n]]
。*
可以是任意数量的维度。例如,在 NLP 任务中,这将是[seq_len, batch_size, features]
102 def forward(self, x: torch.Tensor):
进行健全性检查以确保形状匹配
110 assert self.normalized_shape == x.shape[-len(self.normalized_shape):]
用于计算均值和方差的维度
113 dims = [-(i + 1) for i in range(len(self.normalized_shape))]
计算所有元素的均值;即每个元素的均值
117 mean = x.mean(dim=dims, keepdim=True)
计算所有元素的均方值;即每个元素的均值
120 mean_x2 = (x ** 2).mean(dim=dims, keepdim=True)
所有元素的方差
122 var = mean_x2 - mean ** 2
规范化
125 x_norm = (x - mean) / torch.sqrt(var + self.eps)
缩放和移动
127 if self.elementwise_affine:
128 x_norm = self.gain * x_norm + self.bias
131 return x_norm
简单测试
134def _test():
138 from labml.logger import inspect
139
140 x = torch.zeros([2, 3, 2, 4])
141 inspect(x.shape)
142 ln = LayerNorm(x.shape[2:])
143
144 x = ln(x)
145 inspect(x.shape)
146 inspect(ln.gain.shape)
150if __name__ == '__main__':
151 _test()