This is implemented in PyTorch as an extension to RAdam.

The main difference between the Adam optimizer and AdaBelief is how the adaptive learning rate is calculated: instead of dividing by the exponential moving average of the squared gradients, AdaBelief divides by the exponential moving average of the variance.
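To make the difference concrete, here is a minimal scalar sketch in plain Python (not the PyTorch implementation in this file) that tracks Adam's $v_t$ and AdaBelief's $s_t$ side by side. The helper name `second_moments` and the constant-gradient input are assumptions made for illustration; bias correction and $\epsilon$ are left out.

```python
from typing import Iterable, Tuple


def second_moments(grads: Iterable[float], beta1: float = 0.9,
                   beta2: float = 0.999) -> Tuple[float, float, float]:
    """Return (m_t, v_t, s_t) after processing `grads` (hypothetical helper)."""
    m = v = s = 0.0
    for g in grads:
        # First moment, shared by Adam and AdaBelief
        m = beta1 * m + (1 - beta1) * g
        # Adam: exponential moving average of squared gradients
        v = beta2 * v + (1 - beta2) * g * g
        # AdaBelief: exponential moving average of (g_t - m_t)^2
        s = beta2 * s + (1 - beta2) * (g - m) ** 2
    return m, v, s


# A steady gradient: m_t approaches g, so the residual g_t - m_t shrinks
m, v, s = second_moments([1.0] * 100)
print(f"m={m:.4f}  v={v:.4f}  s={s:.4f}")
```

With a steady gradient, $s_t$ ends up well below $v_t$, so dividing by $\sqrt{s_t}$ gives a larger step than dividing by $\sqrt{v_t}$.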

🤔 The paper calculates variance as $(g_t - m_t)^2$, but I feel it should use the bias-corrected momentum: $(g_t - \color{orange}{\hat{m}_t})^2$. I guess this doesn’t affect things much because the bias correction factor is $\approx 1$ after the initial training steps.
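A quick check of that intuition, as a sketch: assuming the usual Adam-style correction $\hat{m}_t = m_t / (1 - \beta_1^t)$, the helper below (a hypothetical name, not part of `labml_nn`) returns the factor $1 - \beta_1^t$ so you can see how fast it approaches 1.

```python
def bias_correction(beta: float, step: int) -> float:
    """Return 1 - beta^step, the denominator of the bias-corrected moment."""
    if step < 1:
        raise ValueError("step counts start at 1")
    return 1.0 - beta ** step


# Compare early and late steps for beta_1 = 0.9
for t in (1, 10, 50, 100):
    print(t, bias_correction(0.9, t))
```

For $\beta_1 = 0.9$ the factor is 0.1 at the first step and within about 1% of 1 by step 50, so the two variance definitions differ mainly during the first few dozen updates.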

from typing import Dict, Any

import torch
from torch import nn

from labml_nn.optimizers import WeightDecay
from labml_nn.optimizers.radam import RAdam

This class extends the RAdam optimizer defined in radam.py.

class AdaBelief(RAdam):

### Initialize the optimizer

• params is the list of parameters
• lr is the learning rate $\alpha$
• betas is a tuple of ($\beta_1$, $\beta_2$)
• eps is $\hat{\epsilon}$ or $\epsilon$ based on optimized_update
• weight_decay is an instance of class WeightDecay defined in __init__.py
• optimized_update is a flag indicating whether to optimize the bias correction of the second moment by doing it after adding $\epsilon$
• amsgrad is a flag indicating whether to use AMSGrad or fall back to plain Adam
• degenerate_to_sgd is whether to use SGD when the rectification term $r_t$ is intractable
• rectify is whether to use the RAdam update
• defaults is a dictionary of defaults for group values. This is useful when you want to extend the class AdaBelief.

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
                 weight_decay: WeightDecay = WeightDecay(), amsgrad=False,
                 degenerate_to_sgd=True,
                 rectify=True, defaults=None):

        defaults = {} if defaults is None else defaults
        super().__init__(params, lr, betas, eps, weight_decay, amsgrad, degenerate_to_sgd, defaults)
        self.rectify = rectify

### Initialize a parameter state

• state is the optimizer state of the parameter (tensor)
• group stores optimizer attributes of the parameter group
• param is the parameter tensor $\theta_{t-1}$

    def init_state(self, state: Dict[str, Any], group: Dict[str, Any], param: nn.Parameter):

        state['step'] = 0

Exponential moving average of gradient values

        state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)

Exponential moving average of variance

        state['exp_avg_var'] = torch.zeros_like(param, memory_format=torch.preserve_format)

If the amsgrad flag is True for this parameter group, we maintain the maximum of the exponential moving average of variance

        if group['amsgrad']:

Maintains the maximum of all exponential moving averages of the variance

            state['max_exp_avg_var'] = torch.zeros_like(param, memory_format=torch.preserve_format)

### Calculate $m_t$ and $s_t$ or $\max(s_1, s_2, \dots, s_{t-1}, s_t)$

• state is the optimizer state of the parameter (tensor)
• group stores optimizer attributes of the parameter group
• grad is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$

    def get_ms(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor):

Get $\beta_1$ and $\beta_2$

        beta1, beta2 = group['betas']

Get $m_{t-1}$ and $s_{t-1}$

        m, s = state['exp_avg'], state['exp_avg_var']

In-place calculation of $m_t$

        m.mul_(beta1).add_(grad, alpha=1 - beta1)

Difference between gradient and momentum

        grad_residual = grad - m

In-place calculation of $s_t$

        s.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1 - beta2)

If this parameter group is using amsgrad

        if group['amsgrad']:

Get $\max(s_1, s_2, \dots, s_{t-1})$.

            s_max = state['max_exp_avg_var']

Calculate $\max(s_1, s_2, \dots, s_{t-1}, s_t)$.

            torch.maximum(s_max, s, out=s_max)

            return m, s_max
        else:

$m_t$ and $s_t$ otherwise

            return m, s

### Take an update step for a given parameter tensor

• state is the optimizer state of the parameter (tensor)
• group stores optimizer attributes of the parameter group
• grad is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
• param is the parameter tensor $\theta_{t-1}$

    def step_param(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor, param: torch.nn.Parameter):

Calculate weight decay

        grad = self.weight_decay(param, grad, group)

Get $m_t$ and $s_t$

        m, s = self.get_ms(state, group, grad)

Increment $t$, the number of optimizer steps

        state['step'] += 1

        if not self.rectify:

Perform Adam update, defined in adam.py, with $\color{cyan}{s_t} + \color{red}{\epsilon}$ in place of $v_t$.

            self.adam_update(state, group, param, m, s + group['eps'])
        else:

Perform Rectified Adam update, defined in radam.py, with $\color{cyan}{s_t} + \color{red}{\epsilon}$ in place of $v_t$.

            self.r_adam_update(state, group, param, m, s + group['eps'])
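
Putting the pieces together, the following is a minimal scalar sketch of one non-rectified AdaBelief step in plain Python. It assumes the textbook Adam update with bias correction for both moments, which may differ in detail from `adam_update` in adam.py; weight decay, AMSGrad and rectification are omitted, and `adabelief_step` and its `state` dictionary are hypothetical names.

```python
import math
from typing import Dict


def adabelief_step(theta: float, grad: float, state: Dict[str, float],
                   lr: float = 1e-3, beta1: float = 0.9, beta2: float = 0.999,
                   eps: float = 1e-16) -> float:
    """Return the updated parameter after one step (illustrative only)."""
    # Update m_t and s_t, mirroring get_ms above
    state['m'] = beta1 * state['m'] + (1 - beta1) * grad
    residual = grad - state['m']
    state['s'] = beta2 * state['s'] + (1 - beta2) * residual * residual
    state['step'] += 1
    t = state['step']

    # Bias-correct both moments; s_t + eps stands in for Adam's v_t
    m_hat = state['m'] / (1 - beta1 ** t)
    s_hat = (state['s'] + eps) / (1 - beta2 ** t)

    return theta - lr * m_hat / (math.sqrt(s_hat) + eps)


# Minimize f(theta) = theta^2, whose gradient is 2 * theta
state = {'m': 0.0, 's': 0.0, 'step': 0}
theta = 1.0
for _ in range(50):
    theta = adabelief_step(theta, 2 * theta, state)
print(theta)
```

As in `step_param`, the moments are updated before the step counter is incremented, and $\epsilon$ is added to $s_t$ only when it is used in the update, so the stored average itself stays free of $\epsilon$.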