Adam Optimizer with Warmup

This extends the AMSGrad optimizer and adds a warmup stage.

from typing import Dict

from labml_nn.optimizers import WeightDecay
from labml_nn.optimizers.amsgrad import AMSGrad

Adam Optimizer with Warmup

This class extends the AMSGrad optimizer defined in labml_nn.optimizers.amsgrad.

class AdamWarmup(AMSGrad):

Initialize the optimizer

  • params is the list of parameters
  • lr is the learning rate $\alpha$
  • betas is a tuple of ($\beta_1$, $\beta_2$)
  • eps is $\hat{\epsilon}$ or $\epsilon$ based on optimized_update
  • weight_decay is an instance of class WeightDecay defined in labml_nn.optimizers
  • optimized_update is a flag indicating whether to optimize the bias correction of the second moment by doing it after adding $\epsilon$
  • amsgrad is a flag indicating whether to use AMSGrad or fall back to plain Adam
  • warmup is the number of warmup steps
  • defaults is a dictionary of defaults for group values. This is useful when you want to extend the class AdamWarmup.
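For reference, the two meanings of eps can be related through the standard Adam update. Assuming the usual notation ($m_t$ and $v_t$ for the first and second moment estimates at step $t$, $\theta$ for the parameters), plain Adam computes

$$\hat{m}_t = \frac{m_t}{1-\beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1-\beta_2^t}, \qquad \theta_t = \theta_{t-1} - \alpha \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$

Multiplying numerator and denominator by $\sqrt{1-\beta_2^t}$ moves the second-moment bias correction outside the square root, so it is applied after $\hat{\epsilon}$ has been added:

$$\theta_t = \theta_{t-1} - \alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}}, \qquad \hat{\epsilon} = \epsilon \sqrt{1-\beta_2^t}$$

Our reading of the parameter list is that optimized_update selects the second form, in which case eps is interpreted as $\hat{\epsilon}$; the two updates agree exactly when $\hat{\epsilon}$ and $\epsilon$ are related as shown.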
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
                 weight_decay: WeightDecay = WeightDecay(),
                 optimized_update: bool = True,
                 amsgrad=False, warmup=0, defaults=None):
        defaults = {} if defaults is None else defaults
        defaults.update(dict(warmup=warmup))
        super().__init__(params, lr, betas, eps, weight_decay, optimized_update, amsgrad, defaults)
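The defaults argument lets a subclass inject its own per-group hyper-parameters before the parent constructor runs, which is how the warmup key reaches every parameter group. The pure-Python sketch below mimics that chaining with stand-in classes; HypotheticalBaseOptimizer, HypotheticalAdamWarmup, HypotheticalScaledWarmup and the scale key are all made up for illustration and are not labml_nn or PyTorch classes. It assumes, as in PyTorch's optimizer base class, that a value set explicitly on a parameter group overrides the default.

```python
from typing import Any, Dict, Iterable, List, Optional


class HypotheticalBaseOptimizer:
    """Stand-in for a PyTorch-style optimizer base class (illustration only)."""

    def __init__(self, param_groups: Iterable[Dict[str, Any]], defaults: Dict[str, Any]):
        self.defaults = defaults
        self.param_groups: List[Dict[str, Any]] = []
        for group in param_groups:
            merged = dict(group)
            # Fill in a default only when the group does not set that key itself
            for name, value in defaults.items():
                merged.setdefault(name, value)
            self.param_groups.append(merged)


class HypotheticalAdamWarmup(HypotheticalBaseOptimizer):
    def __init__(self, param_groups, lr: float = 1e-3, warmup: int = 0,
                 defaults: Optional[Dict[str, Any]] = None):
        # Same pattern as AdamWarmup.__init__: keep the subclass defaults, add our own keys
        defaults = {} if defaults is None else defaults
        defaults.update(dict(lr=lr, warmup=warmup))
        super().__init__(param_groups, defaults)


class HypotheticalScaledWarmup(HypotheticalAdamWarmup):
    """A further subclass that contributes its own made-up 'scale' hyper-parameter."""

    def __init__(self, param_groups, lr: float = 1e-3, warmup: int = 0, scale: float = 1.0):
        super().__init__(param_groups, lr, warmup, defaults=dict(scale=scale))


opt = HypotheticalScaledWarmup(
    [{"params": ["w"]}, {"params": ["b"], "lr": 0.1}],  # the second group overrides lr
    lr=1e-3, warmup=100, scale=2.0,
)
for group in opt.param_groups:
    print(group)
```

Each level of the hierarchy only adds the keys it owns, so a subclass never needs to know the full list of hyper-parameters its ancestors use.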

Get learning-rate

$$\alpha \min \bigg(1, \frac{t}{w}\bigg)$$

where $t$ is the current step and $w$ is the number of warmup steps.
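As a quick check of the schedule $\alpha \min(1, t/w)$, here is a standalone Python sketch of the same rule. The function name warmup_lr is hypothetical and not part of labml_nn; it copies the branch structure and the $10^{-8}$ offset of get_lr below, but takes plain numbers instead of the state and group dictionaries.

```python
def warmup_lr(step: int, lr: float, warmup: int) -> float:
    """Linear warmup to lr over `warmup` steps, then constant lr (hypothetical helper)."""
    if warmup > step:
        # Linear ramp; the 1e-8 offset keeps the rate strictly positive
        return 1e-8 + step * lr / warmup
    # After warmup (or when warmup == 0) the rate is the constant alpha
    return lr


for t in (1, 250, 500, 1000, 5000):
    print(t, warmup_lr(t, lr=1e-3, warmup=1000))
```

Note that the comparison is strict, so the rate reaches exactly $\alpha$ at $t = w$, and warmup=0 disables the ramp entirely.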

    def get_lr(self, state: Dict[str, any], group: Dict[str, any]):

If we are in the warmup stage

        if group['warmup'] > state['step']:

A learning rate that increases linearly from $0$ towards $\alpha$; the small offset $10^{-8}$ keeps it strictly positive

            return 1e-8 + state['step'] * group['lr'] / group['warmup']
        else:

Constant learning rate $\alpha$

            return group['lr']
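To see where the warmup learning rate enters the parameter update, the following sketch minimizes a one-dimensional function with textbook Adam, plus an optional AMSGrad variant that keeps the running maximum of the second moment. It is a scalar illustration under our own assumptions (the names adam_warmup_minimize and warmup_lr, the test function and the choice to bias-correct the AMSGrad maximum are hypothetical), not the labml_nn implementation; the AdamWarmup class above also handles weight decay and the optimized_update option, which are omitted here.

```python
import math


def warmup_lr(step, lr, warmup):
    # Same rule as get_lr: linear ramp while step < warmup, then constant lr
    return 1e-8 + step * lr / warmup if warmup > step else lr


def adam_warmup_minimize(grad_fn, x0, lr=0.1, betas=(0.9, 0.999), eps=1e-8,
                         warmup=10, amsgrad=False, steps=200):
    """Minimize a 1-D function with Adam/AMSGrad and a linear warmup (sketch)."""
    beta1, beta2 = betas
    x, m, v, v_max = x0, 0.0, 0.0, 0.0
    for step in range(1, steps + 1):  # steps are counted from 1
        g = grad_fn(x)
        m = beta1 * m + (1 - beta1) * g          # first moment estimate
        v = beta2 * v + (1 - beta2) * g * g      # second moment estimate
        if amsgrad:
            v_max = max(v_max, v)                # AMSGrad: never let the second moment shrink
            v_used = v_max
        else:
            v_used = v
        m_hat = m / (1 - beta1 ** step)          # bias corrections
        v_hat = v_used / (1 - beta2 ** step)
        alpha = warmup_lr(step, lr, warmup)      # the only place warmup matters
        x -= alpha * m_hat / (math.sqrt(v_hat) + eps)
    return x


# Minimize (z - 3) ** 2, whose gradient is 2 * (z - 3), starting from 0
x = adam_warmup_minimize(lambda z: 2 * (z - 3), x0=0.0, lr=0.1, warmup=10, steps=500)
print(x)  # should end up near the minimum at 3
```

On the first step $\hat{m}_1 / \sqrt{\hat{v}_1}$ is essentially $\pm 1$, so with warmup=10 the parameter moves by about lr/10 = 0.01 instead of lr = 0.1; the ramp damps the early, poorly estimated updates.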