This is based on the official AdaBelief implementation of the paper AdaBelief Optimizer: Adapting Stepsizes by the Belief in Observed Gradients.
It is implemented in PyTorch as an extension to RAdam.
The main difference between the Adam optimizer and AdaBelief is how they calculate the adaptive learning rate. Instead of dividing by the exponential moving average of the squared gradients, AdaBelief divides by the exponential moving average of the variance.
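For reference, here is a sketch of the update rules as this file appears to implement them on the non-rectified path, with Adam's second moment $v_t$ shown for contrast. It is reconstructed from `get_ms` and from the call `adam_update(..., s + group['eps'])`. How `adam_update` applies the bias corrections and the outer $\epsilon$ is defined in the parent Adam class, which is not shown here, so the last two lines are an assumption.

```latex
\begin{align}
m_t &\leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t \\
v_t &\leftarrow \beta_2 v_{t-1} + (1 - \beta_2) \cdot g_t^2 && \text{(Adam)} \\
s_t &\leftarrow \beta_2 s_{t-1} + (1 - \beta_2) \cdot (g_t - m_t)^2 && \text{(AdaBelief)} \\
\hat{m}_t &\leftarrow \frac{m_t}{1 - \beta_1^t} \\
\hat{s}_t &\leftarrow \frac{s_t + \epsilon}{1 - \beta_2^t} \\
\theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{s}_t} + \epsilon}
\end{align}
```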
🤔 The paper calculates the variance term as $(g_t - m_t)^2$, but I feel it should use the bias-corrected momentum, $(g_t - \hat{m}_t)^2$. I guess this doesn't affect things much because the bias correction is $\approx 1$ after the initial training steps.
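A minimal pure-Python sketch of the point above, assuming a constant scalar gradient $g_t = 1$ and $\beta_1 = 0.9$ (illustrative values, not taken from the implementation). `variance_terms` is a hypothetical helper: it tracks $m_t$ and compares the paper's term $(g_t - m_t)^2$ with the bias-corrected $(g_t - \hat{m}_t)^2$ at every step.

```python
from typing import List, Tuple


def variance_terms(grad: float, beta1: float, steps: int) -> List[Tuple[int, float, float]]:
    """Hypothetical helper: return (t, (g - m_t)^2, (g - m_hat_t)^2) for a constant gradient."""
    m = 0.0
    rows = []
    for t in range(1, steps + 1):
        # First-moment EMA: m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
        m = beta1 * m + (1 - beta1) * grad
        # Bias-corrected momentum: m_hat_t = m_t / (1 - beta1^t)
        m_hat = m / (1 - beta1 ** t)
        rows.append((t, (grad - m) ** 2, (grad - m_hat) ** 2))
    return rows


for t, paper_term, corrected_term in variance_terms(grad=1.0, beta1=0.9, steps=50):
    if t in (1, 10, 50):
        print(f"t={t:2d}  (g - m)^2 = {paper_term:.6f}  (g - m_hat)^2 = {corrected_term:.6f}")
```

With a constant gradient the bias-corrected momentum equals the gradient, so the second term stays at (numerically) zero. The paper's term starts at $0.9^2 = 0.81$ and decays like $\beta_1^{2t}$, so the two only differ noticeably during the early steps.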
from typing import Dict, Any

import torch
from torch import nn

from labml_nn.optimizers import WeightDecay
from labml_nn.optimizers.radam import RAdam
class AdaBelief(RAdam):
params is the list of parameters
lr is the learning rate $\alpha$
betas is a tuple of ($\beta_1$, $\beta_2$)
eps is $\hat{\epsilon}$ or $\epsilon$ based on optimized_update
weight_decay is an instance of class WeightDecay defined in __init__.py
optimized_update is a flag whether to optimize the bias correction of the second moment by doing it after adding $\epsilon$
amsgrad is a flag indicating whether to use AMSGrad or fall back to plain Adam
degenerate_to_sgd is whether to use SGD when the rectification term is intractable
rectify is whether to use the RAdam update
defaults is a dictionary of defaults for group values. This is useful when you want to extend the class AdaBelief.
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
                 weight_decay: WeightDecay = WeightDecay(), amsgrad=False,
                 degenerate_to_sgd=True,
                 rectify=True, defaults=None):
        defaults = {} if defaults is None else defaults
        super().__init__(params, lr, betas, eps, weight_decay, amsgrad, degenerate_to_sgd, defaults)
        self.rectify = rectify
state is the optimizer state of the parameter (tensor)
group stores optimizer attributes of the parameter group
param is the parameter tensor $\theta_{t-1}$
    def init_state(self, state: Dict[str, Any], group: Dict[str, Any], param: nn.Parameter):
        state['step'] = 0
Exponential moving average of gradient values
        state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)
Exponential moving average of variance
        state['exp_avg_var'] = torch.zeros_like(param, memory_format=torch.preserve_format)
If the amsgrad flag is True for this parameter group, we maintain the maximum of the exponential moving average of variance
        if group['amsgrad']:
Maintains the maximum of all exponential moving averages of variance
            state['max_exp_avg_var'] = torch.zeros_like(param, memory_format=torch.preserve_format)
state is the optimizer state of the parameter (tensor)
group stores optimizer attributes of the parameter group
grad is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
    def get_ms(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor):
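Get $\beta_1$ and $\beta_2$
        beta1, beta2 = group['betas']
Get $m_{t-1}$ and $s_{t-1}$
        m, s = state['exp_avg'], state['exp_avg_var']
In-place calculation of $m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t$
        m.mul_(beta1).add_(grad, alpha=1 - beta1)
Difference between the gradient and the momentum, $g_t - m_t$
        grad_residual = grad - m
In-place calculation of $s_t \leftarrow \beta_2 s_{t-1} + (1 - \beta_2) \cdot (g_t - m_t)^2$
        s.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1 - beta2)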
If this parameter group is using amsgrad
        if group['amsgrad']:
Get $\max(s_1, s_2, \ldots, s_{t-1})$.
            s_max = state['max_exp_avg_var']
Calculate $\max(s_1, s_2, \ldots, s_{t-1}, s_t)$ in place.
            torch.maximum(s_max, s, out=s_max)

            return m, s_max
        else:
Return $m_t$ and $s_t$ otherwise
            return m, s
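To make the data flow of `get_ms` concrete, here is a scalar pure-Python mirror of it. `ScalarState`, `get_ms` and `adam_second_moment` below are hypothetical stand-ins, not labml_nn APIs: floats replace tensors and a dataclass replaces the `state` dict. Running both second-moment estimates on a constant gradient illustrates the difference described at the top: Adam's $v_t$ grows toward $g_t^2$, while AdaBelief's $s_t$ stays small because the gradient agrees with its momentum.

```python
from dataclasses import dataclass
from typing import Iterable, Tuple


@dataclass
class ScalarState:
    """Hypothetical scalar stand-in for the per-parameter `state` dict."""
    exp_avg: float = 0.0          # m_t
    exp_avg_var: float = 0.0      # s_t
    max_exp_avg_var: float = 0.0  # max(s_1, ..., s_t), only used with AMSGrad


def get_ms(state: ScalarState, grad: float, beta1: float, beta2: float,
           amsgrad: bool) -> Tuple[float, float]:
    """Scalar mirror of `AdaBelief.get_ms`: returns (m_t, s_t) or (m_t, max s)."""
    # m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
    state.exp_avg = beta1 * state.exp_avg + (1 - beta1) * grad
    # Residual between the gradient and the (uncorrected) momentum
    residual = grad - state.exp_avg
    # s_t = beta2 * s_{t-1} + (1 - beta2) * (g_t - m_t)^2
    state.exp_avg_var = beta2 * state.exp_avg_var + (1 - beta2) * residual * residual
    if amsgrad:
        # Running maximum of s_t, like torch.maximum(s_max, s, out=s_max)
        state.max_exp_avg_var = max(state.max_exp_avg_var, state.exp_avg_var)
        return state.exp_avg, state.max_exp_avg_var
    return state.exp_avg, state.exp_avg_var


def adam_second_moment(grads: Iterable[float], beta2: float) -> float:
    """Adam's v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2, for comparison."""
    v = 0.0
    for g in grads:
        v = beta2 * v + (1 - beta2) * g * g
    return v


state = ScalarState()
for _ in range(1000):
    m, s = get_ms(state, grad=1.0, beta1=0.9, beta2=0.999, amsgrad=False)
v = adam_second_moment([1.0] * 1000, beta2=0.999)
print(f"AdaBelief s_t = {s:.6f}, Adam v_t = {v:.6f}")
```

Since AdaBelief divides by $\sqrt{\hat{s}_t}$, a small $s_t$ translates into a larger effective step when the observed gradients match the momentum, whereas Adam's denominator depends only on the gradient magnitude.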
state is the optimizer state of the parameter (tensor)
group stores optimizer attributes of the parameter group
grad is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
param is the parameter tensor $\theta_{t-1}$
    def step_param(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor, param: torch.nn.Parameter):
Calculate weight decay
        grad = self.weight_decay(param, grad, group)
Get $m_t$ and $s_t$
        m, s = self.get_ms(state, group, grad)
Increment the number of optimizer steps, $t$
        state['step'] += 1

        if not self.rectify:
            self.adam_update(state, group, param, m, s + group['eps'])
        else: