AdaBelief Optimizer

This is based on the official AdaBelief implementation of the paper AdaBelief Optimizer: Adapting Stepsizes by the Belief in Observed Gradients.

This is implemented in PyTorch as an extension to RAdam.

The main difference between the Adam optimizer and AdaBelief is how the adaptive learning rate is calculated: instead of dividing by the exponential moving average of the squared gradients, AdaBelief divides by the exponential moving average of the variance, i.e. the squared deviation of the gradient from its momentum.

🤔 The paper calculates the variance as $s_t = \beta_2 s_{t-1} + (1 - \beta_2)(g_t - m_t)^2$, but I feel it should use the bias corrected momentum $\hat{m}_t$. I guess this doesn't affect things much because the bias correction factor is $\approx 1$ after the initial training steps.
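To make the difference concrete, here is a minimal sketch (plain tensors, not part of this class) contrasting Adam's second-moment update with AdaBelief's:

```python
# Minimal sketch contrasting the two second-moment estimates.
# All names here are illustrative; this is not the labml_nn implementation.
import torch

beta2 = 0.999
g = torch.randn(4)           # current gradient g_t
m = torch.randn(4)           # first moment m_t (already updated with g_t)
v = torch.zeros(4)           # Adam: EMA of g_t ** 2
s = torch.zeros(4)           # AdaBelief: EMA of (g_t - m_t) ** 2

# Adam divides by the square root of the EMA of squared gradients
v = beta2 * v + (1 - beta2) * g ** 2

# AdaBelief divides by the square root of the EMA of the squared deviation
# of the gradient from its momentum (its "belief" in the observed gradient)
s = beta2 * s + (1 - beta2) * (g - m) ** 2
```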

from typing import Dict, Any

import torch
from torch import nn

from labml_nn.optimizers import WeightDecay
from labml_nn.optimizers.radam import RAdam

AdaBelief Optimizer

This class extends the RAdam optimizer defined in radam.py.

class AdaBelief(RAdam):

Initialize the optimizer

  • params is the list of parameters
  • lr is the learning rate
  • betas is a tuple of ($\beta_1$, $\beta_2$)
  • eps is $\hat{\epsilon}$ or $\epsilon$ based on optimized_update
  • weight_decay is an instance of class WeightDecay defined in __init__.py
  • optimized_update is a flag whether to optimize the bias correction of the second moment by doing it after adding $\epsilon$
  • amsgrad is a flag indicating whether to use AMSGrad or fallback to plain Adam
  • degenerate_to_sgd is whether to use SGD when the rectification term $r_t$ is intractable
  • rectify is whether to use RAdam update
  • defaults is a dictionary of defaults for group values. This is useful when you want to extend the class AdaBelief.
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
                 weight_decay: WeightDecay = WeightDecay(), amsgrad=False,
                 degenerate_to_sgd=True,
                 rectify=True, defaults=None):
        defaults = {} if defaults is None else defaults
        super().__init__(params, lr, betas, eps, weight_decay, amsgrad, degenerate_to_sgd, defaults)
        self.rectify = rectify
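For reference, a constructor call might look like the following sketch; the module path labml_nn.optimizers.ada_belief is an assumption and may differ from the actual package layout:

```python
# Hypothetical usage sketch; the module path is assumed for illustration.
import torch
from torch import nn
from labml_nn.optimizers.ada_belief import AdaBelief  # assumed path

model = nn.Linear(16, 4)
# Note the much smaller default eps (1e-16) compared to Adam's usual 1e-8
optimizer = AdaBelief(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
                      amsgrad=False, rectify=True)
```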

Initialize a parameter state

  • state is the optimizer state of the parameter (tensor)
  • group stores optimizer attributes of the parameter group
  • param is the parameter tensor
    def init_state(self, state: Dict[str, Any], group: Dict[str, Any], param: nn.Parameter):
        state['step'] = 0

Exponential moving average of gradient values

        state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)

Exponential moving average of variance

        state['exp_avg_var'] = torch.zeros_like(param, memory_format=torch.preserve_format)

If amsgrad flag is True for this parameter group, we maintain the maximum of exponential moving average of variance

        if group['amsgrad']:

Maintains the maximum of all exponential moving averages of the variance

            state['max_exp_avg_var'] = torch.zeros_like(param, memory_format=torch.preserve_format)
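Put together, the per-parameter state right after initialization looks roughly like this (a sketch with hypothetical values; the keys follow the code above):

```python
# Sketch of the state dictionary produced by init_state for one parameter.
import torch
from torch import nn

param = nn.Parameter(torch.zeros(3))
state = {
    'step': 0,
    'exp_avg': torch.zeros_like(param),      # m_0
    'exp_avg_var': torch.zeros_like(param),  # s_0
}
# Only present when the parameter group has amsgrad=True
state['max_exp_avg_var'] = torch.zeros_like(param)
```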

Calculate $m_t$ and $s_t$ or $\max(s_1, s_2, ..., s_{t-1}, s_t)$

  • state is the optimizer state of the parameter (tensor)
  • group stores optimizer attributes of the parameter group
  • grad is the current gradient tensor for the parameter
    def get_ms(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor):

Get $\beta_1$ and $\beta_2$

        beta1, beta2 = group['betas']

Get $m_{t-1}$ and $s_{t-1}$

        m, s = state['exp_avg'], state['exp_avg_var']

In-place calculation of $m_t$: $m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t$

        m.mul_(beta1).add_(grad, alpha=1 - beta1)

Difference between gradient and momentum

        grad_residual = grad - m

In-place calculation of $s_t$: $s_t \leftarrow \beta_2 s_{t-1} + (1 - \beta_2) \cdot (g_t - m_t)^2$

        s.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1 - beta2)

If this parameter group is using amsgrad

        if group['amsgrad']:

Get $\max(s_1, s_2, ..., s_{t-1})$.

            s_max = state['max_exp_avg_var']

Calculate $\max(s_1, s_2, ..., s_{t-1}, s_t)$.

            torch.maximum(s_max, s, out=s_max)

            return m, s_max
        else:

$m_t$ and $s_t$ otherwise

            return m, s
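The AMSGrad branch above maintains an element-wise running maximum; a standalone sketch of that behaviour:

```python
# torch.maximum with out=s_max keeps the element-wise maximum seen so far.
import torch

s_max = torch.tensor([0.3, 0.5, 0.1])   # max of s_1 ... s_{t-1}
s = torch.tensor([0.2, 0.7, 0.4])       # new variance estimate s_t
torch.maximum(s_max, s, out=s_max)
print(s_max)                            # tensor([0.3000, 0.7000, 0.4000])
```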

Take an update step for a given parameter tensor

  • state is the optimizer state of the parameter (tensor)
  • group stores optimizer attributes of the parameter group
  • grad is the current gradient tensor for the parameter
  • param is the parameter tensor
    def step_param(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor, param: torch.nn.Parameter):

Calculate weight decay

        grad = self.weight_decay(param, grad, group)

Get $m_t$ and $s_t$ or $\max(s_1, s_2, ..., s_{t-1}, s_t)$

        m, s = self.get_ms(state, group, grad)

Increment the number of optimizer steps

        state['step'] += 1

        if not self.rectify:

Perform an Adam update, defined in adam.py, with $s_t + \epsilon$ in place of $v_t$.

            self.adam_update(state, group, param, m, s + group['eps'])
        else:

Perform a Rectified Adam update, defined in radam.py, with $s_t + \epsilon$ in place of $v_t$.

            self.r_adam_update(state, group, param, m, s + group['eps'])
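Finally, a hypothetical end-to-end sketch of a single training step with this optimizer; the data, model, and import path are assumptions for illustration only:

```python
# Hypothetical training-step sketch; not taken from the repository.
import torch
from torch import nn
from labml_nn.optimizers.ada_belief import AdaBelief  # assumed path

model = nn.Linear(8, 1)
optimizer = AdaBelief(model.parameters(), lr=1e-3, rectify=True)

x, y = torch.randn(32, 8), torch.randn(32, 1)
loss = nn.functional.mse_loss(model(x), y)

optimizer.zero_grad()
loss.backward()
optimizer.step()  # calls step_param for every parameter with a gradient
```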