这是论文《Dee pNe t:将变形金刚扩展到 1,000 层》中 DeepN orm 的 PyTorch 实现。
本文提出了一种通过新的归一化函数取代LayerNorm和权重初始化方案,稳定极深变压器的方法。这结合了 postLayerNorm 的性能和 pre-LayerNorm 的稳定性。即使没有学习速率预热,带有DeepNorms的变形金刚也应该保持稳定。
本文首先表明,在稳定训练期间,图层输出(针对相同输入)的变化会逐渐变化;当不稳定时,它在最初的训练步骤中会迅速变化。这种情况发生在将权重初始化为小值以及训练稳定时进行学习率预热时。他们使用保持对图层输出的更改较小的想法来推导出新的标准化和权重初始化机制。
通常,权重是使用 Xavier 或 Kaiming 初始化进行初始化的。这张纸根据变压器的大小将权重缩放(设置增益)一个常数。
DeepNorm 建议缩放前馈网络中两个线性变换、价值投影变换和注意力层输出投影变换的权重。这些变换的权重按比例缩放(增益等于)。
扩展是在中实现的
其中,是取决于变压器深度的常数,是层归一化,是第 -8 个变压器子层的函数(FFN 或注意力)。
此函数用于替换 postLayerNorm。
其中是编码器中的层数,是解码器中的层数。
请参考论文进行推导。
73from typing import Union, List
74
75import torch
76from torch import nn, Size
77
78from labml_nn.normalization.layer_norm import LayerNorm
79from labml_nn.transformers import MultiHeadAttention
80from labml_nn.transformers.feed_forward import FeedForward
81from labml_nn.transformers.utils import subsequent_mask
84class DeepNorm(nn.Module):
alpha
是normalized_shape
是 LayerNorm 的形状eps
是为 LayerNorm 准备的elementwise_affine
是一个标志,指示是否在 LayerNorm 中进行逐元素转换91 def __init__(self, alpha: float, normalized_shape: Union[int, List[int], Size], *,
92 eps: float = 1e-5,
93 elementwise_affine: bool = True):
100 super().__init__()
101
102 self.alpha = alpha
初始化
104 self.layer_norm = LayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine)
x
是前一层的输出gx
是当前子层的输出106 def forward(self, x: torch.Tensor, gx: torch.Tensor):
112 return self.layer_norm(x + self.alpha * gx)
115class DeepNormTransformerLayer(nn.Module):
d_model
是令牌嵌入的大小self_attn
是自我关注模块feed_forward
是前馈模块deep_norm_alpha
是 DeepNorm 中的系数deep_norm_beta
对于缩放权重初始化来说是常数122 def __init__(self, *,
123 d_model: int,
124 self_attn: MultiHeadAttention,
125 feed_forward: FeedForward,
126 deep_norm_alpha: float,
127 deep_norm_beta: float,
128 ):
136 super().__init__()
137
138 self.self_attn = self_attn
139 self.feed_forward = feed_forward
DeepNorms 追随关注和前馈网络
141 self.self_attn_norm = DeepNorm(deep_norm_alpha, [d_model])
142 self.feed_forward_norm = DeepNorm(deep_norm_alpha, [d_model])
初始化后缩放权重
145 with torch.no_grad():
前馈网络线性变换
147 feed_forward.layer1.weight *= deep_norm_beta
148 feed_forward.layer2.weight *= deep_norm_beta
注意力值预测
151 self_attn.value.linear.weight *= deep_norm_beta
注意输出项目
153 self_attn.output.weight *= deep_norm_beta
掩码将在第一次调用时初始化
156 self.mask = None
x
是形状的嵌入[seq_len, batch_size, d_model]
158 def forward(self, x: torch.Tensor):
创建因果面具
163 if self.mask is None or self.mask.size(0) != len(x):
后续的掩码,将掩盖令牌以免看到未来的代币
165 self.mask = subsequent_mask(len(x)).to(x.device)
通过自我关注,即关键和价值来自自我
168 x = self.self_attn_norm(x, self.self_attn(query=x, key=x, value=x, mask=self.mask))
通过前馈网络
170 x = self.feed_forward_norm(x, self.feed_forward(x))
173 return x