adaBelief 优化器

这是基于 AdaBeLief Optimizer 论文AdaBeLief Optimizer:通过对观察到的梯度的信念调整步长》的官方实现。

这是在 PyTorch 中作为对 RadAM 的扩展实现的。

Adam optimizer 和 AdaBeLief 之间的主要区别在于,它如何计算自适应学习率;AdaBeLief 不是除以梯度平方的指数移动平均值,而是除以方差的指数均值。

🤔 本文将方差计算为,但我认为它应该使用偏差校正的动量。我想这对事情的影响不大,因为偏差校正是在最初的训练步骤之后进行的。

36from typing import Dict, Any
37
38import torch
39from torch import nn
40
41from labml_nn.optimizers import WeightDecay
42from labml_nn.optimizers.radam import RAdam

adaBelief 优化器

这个类是从中定义的 RadAM 优化器扩展而来的radam.py

45class AdaBelief(RAdam):

初始化优化器

  • params 是参数列表
  • lr 是学习率
  • betas 是 (,) 的元组
  • eps基于optimized_update
  • weight_decay 是在中WeightDecay 定义的类的实例 __init__.py
  • optimized_update 是一个标志,是否在添加后通过这样做来优化第二个时刻的偏差校正
  • amsgrad 是一个标志,指示是使用 AmsGrad 还是回退到普通的 Adam
  • degenerate_to_sgd 纠正条款难以处理时是否使用 sgd
  • rectify 是否使用 raDAM 更新
  • defaults 是组值的默认字典。当你想扩展类时,这很有用AdaBelief
52    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
53                 weight_decay: WeightDecay = WeightDecay(), amsgrad=False,
54                 degenerate_to_sgd=True,
55                 rectify=True, defaults=None):
73        defaults = {} if defaults is None else defaults
74        super().__init__(params, lr, betas, eps, weight_decay, amsgrad, degenerate_to_sgd, defaults)
75        self.rectify = rectify

初始化参数状态

  • state 是参数(张量)的优化器状态
  • group 存储参数组的优化程序属性
  • param 是参数张量
77    def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):
85        state['step'] = 0

梯度值的指数移动平均线

87        state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)

方差的指数移动平均线

89        state['exp_avg_var'] = torch.zeros_like(param, memory_format=torch.preserve_format)

如果 famsgrad lagTrue 用于此参数组,则我们维持方差的指数移动平均线的最大值

93        if group['amsgrad']:

保持所有 exp. 移动平均 sq. grad. 值的最大值

95            state['max_exp_avg_var'] = torch.zeros_like(param, memory_format=torch.preserve_format)

计算

  • state 是参数(张量)的优化器状态
  • group 存储参数组的优化程序属性
  • grad 是参数的当前梯度张量
97    def get_ms(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor):

获取

107        beta1, beta2 = group['betas']

获取

110        m, s = state['exp_avg'], state['exp_avg_var']

就地计算

114        m.mul_(beta1).add_(grad, alpha=1 - beta1)

梯度和动量之间的区别

116        grad_residual = grad - m

就地计算

119        s.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1 - beta2)

如果此参数组正在使用amsgrad

122        if group['amsgrad']:

得到

124            s_max = state['max_exp_avg_var']

计算

126            torch.maximum(s_max, s, out=s_max)
127
128            return m, s_max
129        else:

否则

131            return m, s

对给定参数张量执行更新步骤

  • state 是参数(张量)的优化器状态
  • group 存储参数组的优化程序属性
  • grad 是参数的当前梯度张量
  • param 是参数张量
133    def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):

计算体重衰减

144        grad = self.weight_decay(param, grad, group)

获取

147        m, s = self.get_ms(state, group, grad)

增加优化器步数

150        state['step'] += 1
151
152        if not self.rectify:

执行 Adam 更新,在中定义 adam.py ,用代替

155            self.adam_update(state, group, param, m, s + group['eps'])
156        else:

执行中定义的已校正的 Adam 更新 radam.py ,用代替

159            self.r_adam_update(state, group, param, m, s + group['eps'])