Adam Optimizer

This is a PyTorch implementation of the popular Adam optimizer from the paper Adam: A Method for Stochastic Optimization.

The Adam update is,

$$\begin{align}
m_t &\leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t \\
v_t &\leftarrow \beta_2 v_{t-1} + (1 - \beta_2) \cdot g_t^2 \\
\hat{m}_t &\leftarrow \frac{m_t}{1 - \beta_1^t} \\
\hat{v}_t &\leftarrow \frac{v_t}{1 - \beta_2^t} \\
\theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
\end{align}$$

where $g_t$ is the gradient at step $t$; $\alpha$, $\beta_1$, $\beta_2$, and $\epsilon$ are scalar hyper-parameters; $m_t$ and $v_t$ are the first and second order moments; and $\hat{m}_t$ and $\hat{v}_t$ are the bias-corrected moments. $\epsilon$ is used as a fix for the division-by-zero error, but it also acts as a form of hyper-parameter that counteracts the variance in gradients.

The effective step taken, assuming $\epsilon = 0$, is

$$\Delta_t = \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t}}$$

This is bounded by

$$\vert \Delta_t \vert \le \alpha \cdot \frac{1 - \beta_1}{\sqrt{1 - \beta_2}}$$

when $1 - \beta_1 \gt \sqrt{1 - \beta_2}$, and by

$$\vert \Delta_t \vert \le \alpha$$

otherwise. And in most common scenarios,

$$\vert \Delta_t \vert \approx \alpha$$
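For example, with the default hyper-parameters used below, $\beta_1 = 0.9$ and $\beta_2 = 0.999$, the first case applies since $1 - \beta_1 = 0.1 \gt \sqrt{1 - \beta_2} \approx 0.0316$, and the bound works out to

$$\vert \Delta_t \vert \le \alpha \cdot \frac{0.1}{\sqrt{0.001}} \approx 3.16 \, \alpha$$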

import math
from typing import Dict, Any, Tuple, Optional

import torch
from labml import tracker
from torch import nn

from labml_nn.optimizers import GenericAdaptiveOptimizer, WeightDecay

Adam Optimizer

We extend the class GenericAdaptiveOptimizer defined in __init__.py to implement the Adam optimizer.

class Adam(GenericAdaptiveOptimizer):

Initialize the optimizer

  • params is the list of parameters
  • lr is the learning rate $\alpha$
  • betas is a tuple of ($\beta_1$, $\beta_2$)
  • eps is $\epsilon$ or $\hat{\epsilon}$ based on optimized_update
  • weight_decay is an instance of class WeightDecay defined in __init__.py
  • optimized_update is a flag whether to optimize the bias correction of the second moment by doing it after adding $\epsilon$
  • defaults is a dictionary of default values for parameter groups. This is useful when you want to extend the class Adam .
    def __init__(self, params,
                 lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-16,
                 weight_decay: WeightDecay = WeightDecay(),
                 optimized_update: bool = True,
                 defaults: Optional[Dict[str, Any]] = None):
        defaults = {} if defaults is None else defaults
        defaults.update(weight_decay.defaults())
        super().__init__(params, defaults, lr, betas, eps)

        self.weight_decay = weight_decay
        self.optimized_update = optimized_update
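As a quick usage sketch (not part of the source), assuming the class is importable as labml_nn.optimizers.adam.Adam and that the base class's step() drives step_param as in the rest of the library; the toy model and data here are made up purely for illustration:

import torch
from torch import nn
from labml_nn.optimizers.adam import Adam

# A made-up toy model and batch, just to exercise the optimizer interface
model = nn.Linear(4, 1)
optimizer = Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999))

x, y = torch.randn(32, 4), torch.randn(32, 1)
for _ in range(10):
    optimizer.zero_grad()
    loss = ((model(x) - y) ** 2).mean()
    loss.backward()
    optimizer.step()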

Initialize a parameter state

  • state is the optimizer state of the parameter (tensor)
  • group stores optimizer attributes of the parameter group
  • param is the parameter tensor
    def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):

This is the number of optimizer steps taken on the parameter, $t$

        state['step'] = 0

Exponential moving average of gradients, $m_t$

        state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)

Exponential moving average of squared gradient values, $v_t$

        state['exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format)
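As an aside, a small standalone illustration (not from the source) of what memory_format=torch.preserve_format buys: the moment buffers inherit the parameter's memory layout, e.g. channels_last.

import torch

# Illustration only: zeros_like with preserve_format keeps a channels_last layout
p = torch.randn(8, 3, 16, 16).to(memory_format=torch.channels_last)
buf = torch.zeros_like(p, memory_format=torch.preserve_format)
print(buf.is_contiguous(memory_format=torch.channels_last))  # True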

Calculate $m_t$ and $v_t$

  • state is the optimizer state of the parameter (tensor)
  • group stores optimizer attributes of the parameter group
  • grad is the current gradient tensor for the parameter
    def get_mv(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor):

Get $\beta_1$ and $\beta_2$

        beta1, beta2 = group['betas']

Get $m_{t-1}$ and $v_{t-1}$

        m, v = state['exp_avg'], state['exp_avg_sq']

In-place calculation of $m_t$,

$$m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t$$

        m.mul_(beta1).add_(grad, alpha=1 - beta1)

In-place calculation of $v_t$,

$$v_t \leftarrow \beta_2 v_{t-1} + (1 - \beta_2) \cdot g_t^2$$

        v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        return m, v
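As a standalone sanity check (not part of the source), the two in-place operations above are equivalent to the closed-form exponential-moving-average updates:

import torch

# Illustration only: the in-place ops match the EMA formulas
beta1, beta2 = 0.9, 0.999
m, v, grad = torch.rand(5), torch.rand(5), torch.randn(5)

expected_m = beta1 * m + (1 - beta1) * grad
expected_v = beta2 * v + (1 - beta2) * grad * grad

m.mul_(beta1).add_(grad, alpha=1 - beta1)
v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

assert torch.allclose(m, expected_m)
assert torch.allclose(v, expected_v)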

Get learning-rate

This returns the modified learning rate based on the state. For Adam this is just the specified learning rate for the parameter group, $\alpha$.

    def get_lr(self, state: Dict[str, any], group: Dict[str, any]):
        return group['lr']
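This hook makes it easy for subclasses to implement per-parameter learning-rate schedules. A hypothetical sketch (not part of the library; the class name and warmup rule are made up) of overriding it for a linear warmup:

from typing import Any, Dict

class LinearWarmupAdam(Adam):
    # Hypothetical example: linearly warm up the learning rate over `warmup` steps
    def __init__(self, params, warmup: int = 1_000, **kwargs):
        self.warmup = warmup
        super().__init__(params, **kwargs)

    def get_lr(self, state: Dict[str, Any], group: Dict[str, Any]):
        # `state['step']` is incremented before the update, so it is at least 1 here
        return group['lr'] * min(1.0, state['step'] / self.warmup)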

Do the Adam parameter update

  • state is the optimizer state of the parameter (tensor)
  • group stores optimizer attributes of the parameter group
  • param is the parameter tensor
  • m and v are the uncorrected first and second moments $m_t$ and $v_t$.

This computes the following update:

$$\theta_t \leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$

Since $\alpha$, $1 - \beta_1^t$, and $1 - \beta_2^t$ are scalars and the others are tensors, we rearrange the calculation to optimize the computation:

$$\begin{align}
\theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} \\
&= \theta_{t-1} - \alpha \cdot \frac{m_t / (1 - \beta_1^t)}{\sqrt{v_t / (1 - \beta_2^t)} + \epsilon} \\
&= \theta_{t-1} - \frac{\alpha \sqrt{1 - \beta_2^t}}{1 - \beta_1^t} \cdot \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}}
\end{align}$$

where $\hat{\epsilon} = \sqrt{1 - \beta_2^t} \cdot \epsilon$ is what we should specify as the hyper-parameter.

    def adam_update(self, state: Dict[str, any], group: Dict[str, any], param: torch.nn.Parameter,
                    m: torch.Tensor, v: torch.Tensor):

Get $\beta_1$ and $\beta_2$

        beta1, beta2 = group['betas']

Bias correction term for $\hat{m}_t$, $1 - \beta_1^t$

        bias_correction1 = 1 - beta1 ** state['step']

Bias correction term for $\hat{v}_t$, $1 - \beta_2^t$

        bias_correction2 = 1 - beta2 ** state['step']

Get learning rate

        lr = self.get_lr(state, group)

Whether to optimize the computation

        if self.optimized_update:

The denominator, $\sqrt{v_t} + \hat{\epsilon}$

            denominator = v.sqrt().add_(group['eps'])

The step size, $\frac{\alpha \sqrt{1 - \beta_2^t}}{1 - \beta_1^t}$

            step_size = lr * math.sqrt(bias_correction2) / bias_correction1

Update the parameters, $\theta_t \leftarrow \theta_{t-1} - \frac{\alpha \sqrt{1 - \beta_2^t}}{1 - \beta_1^t} \cdot \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}}$

            param.data.addcdiv_(m, denominator, value=-step_size)

Computation without optimization

        else:

The denominator, $\sqrt{v_t / (1 - \beta_2^t)} + \epsilon$

            denominator = (v.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])

The step size, $\frac{\alpha}{1 - \beta_1^t}$

            step_size = lr / bias_correction1

Update the parameters, $\theta_t \leftarrow \theta_{t-1} - \frac{\alpha}{1 - \beta_1^t} \cdot \frac{m_t}{\sqrt{v_t / (1 - \beta_2^t)} + \epsilon}$

            param.data.addcdiv_(m, denominator, value=-step_size)
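Both branches compute the same update when $\hat{\epsilon} = \sqrt{1 - \beta_2^t} \cdot \epsilon$ is used in the optimized path; a quick standalone check (not part of the source):

import math
import torch

# Illustration only: the optimized and plain forms give identical steps
# when eps_hat = eps * sqrt(1 - beta2^t)
m, v = torch.rand(5), torch.rand(5)
lr, eps, step = 1e-3, 1e-8, 10
beta1, beta2 = 0.9, 0.999
bc1, bc2 = 1 - beta1 ** step, 1 - beta2 ** step

plain = lr / bc1 * m / ((v / bc2).sqrt() + eps)

eps_hat = eps * math.sqrt(bc2)
optimized = lr * math.sqrt(bc2) / bc1 * m / (v.sqrt() + eps_hat)

assert torch.allclose(plain, optimized)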

Take an update step for a given parameter tensor

  • state is the optimizer state of the parameter (tensor)
  • group stores optimizer attributes of the parameter group
  • grad is the current gradient tensor for the parameter
  • param is the parameter tensor
    def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):

Calculate weight decay

        grad = self.weight_decay(param, grad, group)

Get $m_t$ and $v_t$

        m, v = self.get_mv(state, group, grad)

Increment the number of optimizer steps

        state['step'] += 1

Perform Adam update

        self.adam_update(state, group, param, m, v)
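As a final, hedged sanity check (not part of the source): with optimized_update=False, a matching eps, and the default zero weight decay, this implementation should track torch.optim.Adam closely on a toy parameter.

import torch
from torch import nn
from labml_nn.optimizers.adam import Adam

torch.manual_seed(0)

# Two copies of the same parameter, one per optimizer
p1 = nn.Parameter(torch.randn(10))
p2 = nn.Parameter(p1.detach().clone())

opt1 = Adam([p1], lr=1e-3, eps=1e-8, optimized_update=False)
opt2 = torch.optim.Adam([p2], lr=1e-3, eps=1e-8)

for _ in range(5):
    for p, opt in ((p1, opt1), (p2, opt2)):
        opt.zero_grad()
        (p ** 2).sum().backward()
        opt.step()

print(torch.allclose(p1, p2, atol=1e-6))  # expected: True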