重量标准化

这是 PyTorch 实现的权重标准化,源自论文《具有批处理通道标准化和权重标准化的微批量训练》。我们还有一个带注释的批处理通道标准化实现

批量归一化提供了平滑的损失格局避免了消除奇异性。消除奇点是网络中变得毫无用处的节点(例如,一直给出 0 的 ReLU)。

但是,当批量大小太小时,批量标准化效果不佳,由于设备内存限制,在训练大型网络时会发生这种情况。本文介绍了采用批处理信道标准化的权重标准化作为更好的替代方案。

重量标准化:1.归一化梯度 2.平滑地形(降低了 Lipschitz 常数)3.避免消除奇点

Lipschitz 常量是函数在两点之间的最大斜率。也就是说,是 Lipschitz 常数,其中是满足的最小值,其中

避免了消除奇异性,因为它使输出的统计数据与输入的统计数据相似。因此,只要输入呈正态分布,输出就保持接近正常水平。这样可以避免节点的输出总是超出激活函数的有效范围(例如,ReLU 的输入总是负数)。

有关样张,请参阅论文

以下是训练 VGG 网络的训练代码,该网络使用权重标准化对 CIFAR-10 数据进行分类。这使用了具有权重标准化功能的二维卷积层

Open In Colab

48import torch

重量标准化

在哪里,

对于 2D 卷积层,是输出通道数 (),是输入通道数乘以内核大小 (

)
51def weight_standardization(weight: torch.Tensor, eps: float):

获取和内核形状

70    c_out, c_in, *kernel_shape = weight.shape

重塑

72    weight = weight.view(c_out, -1)

计算

79    var, mean = torch.var_mean(weight, dim=1, keepdim=True)

规范化

82    weight = (weight - mean) / (torch.sqrt(var + eps))

改回原始形状并返回

84    return weight.view(c_out, c_in, *kernel_shape)