这是基于 AdaBeLief Optimizer 论文《AdaBeLief Optimizer:通过对观察到的梯度的信念调整步长》的官方实现。
这是在 PyTorch 中作为对 RadAM 的扩展实现的。
Adam optimizer 和 AdaBeLief 之间的主要区别在于,它如何计算自适应学习率;AdaBeLief 不是除以梯度平方的指数移动平均值,而是除以方差的指数均值。
🤔 本文将方差计算为,但我认为它应该使用偏差校正的动量。我想这对事情的影响不大,因为偏差校正是在最初的训练步骤之后进行的。
36from typing import Dict, Any
37
38import torch
39from torch import nn
40
41from labml_nn.optimizers import WeightDecay
42from labml_nn.optimizers.radam import RAdam
params
是参数列表lr
是学习率betas
是 (,) 的元组eps
是或基于optimized_update
weight_decay
是在中WeightDecay
定义的类的实例 __init__.py
optimized_update
是一个标志,是否在添加后通过这样做来优化第二个时刻的偏差校正amsgrad
是一个标志,指示是使用 AmsGrad 还是回退到普通的 Adamdegenerate_to_sgd
纠正条款难以处理时是否使用 sgdrectify
是否使用 raDAM 更新defaults
是组值的默认字典。当你想扩展类时,这很有用AdaBelief
。52 def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
53 weight_decay: WeightDecay = WeightDecay(), amsgrad=False,
54 degenerate_to_sgd=True,
55 rectify=True, defaults=None):
73 defaults = {} if defaults is None else defaults
74 super().__init__(params, lr, betas, eps, weight_decay, amsgrad, degenerate_to_sgd, defaults)
75 self.rectify = rectify
77 def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):
85 state['step'] = 0
梯度值的指数移动平均线
87 state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)
方差的指数移动平均线
89 state['exp_avg_var'] = torch.zeros_like(param, memory_format=torch.preserve_format)
如果 famsgrad
lagTrue
用于此参数组,则我们维持方差的指数移动平均线的最大值
93 if group['amsgrad']:
保持所有 exp. 移动平均 sq. grad. 值的最大值
95 state['max_exp_avg_var'] = torch.zeros_like(param, memory_format=torch.preserve_format)
97 def get_ms(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor):
获取和
107 beta1, beta2 = group['betas']
获取和
110 m, s = state['exp_avg'], state['exp_avg_var']
就地计算
114 m.mul_(beta1).add_(grad, alpha=1 - beta1)
梯度和动量之间的区别
116 grad_residual = grad - m
就地计算
119 s.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1 - beta2)
如果此参数组正在使用amsgrad
122 if group['amsgrad']:
得到。
124 s_max = state['max_exp_avg_var']
计算。
126 torch.maximum(s_max, s, out=s_max)
127
128 return m, s_max
129 else:
否则
131 return m, s
133 def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):
计算体重衰减
144 grad = self.weight_decay(param, grad, group)
获取和
147 m, s = self.get_ms(state, group, grad)
增加优化器步数
150 state['step'] += 1
151
152 if not self.rectify:
155 self.adam_update(state, group, param, m, s + group['eps'])
156 else: