AMSGrad

This is a PyTorch implementation of the paper On the Convergence of Adam and Beyond.

We implement this as an extension to our Adam optimizer implementation. The implementation itself is really small since it's very similar to Adam.

We also have an implementation of the synthetic example described in the paper where Adam fails to converge.

from typing import Dict

import torch
from torch import nn

from labml_nn.optimizers import WeightDecay
from labml_nn.optimizers.adam import Adam

AMSGrad Optimizer

This class extends the Adam optimizer defined in adam.py. The Adam optimizer itself extends the class GenericAdaptiveOptimizer defined in __init__.py.

class AMSGrad(Adam):

Initialize the optimizer

  • params is the list of parameters
  • lr is the learning rate
  • betas is a tuple of ($\beta_1$, $\beta_2$)
  • eps is $\hat{\epsilon}$ or $\epsilon$ based on optimized_update
  • weight_decay is an instance of class WeightDecay defined in __init__.py
  • optimized_update is a flag indicating whether to optimize the bias correction of the second moment by doing it after adding $\epsilon$
  • amsgrad is a flag indicating whether to use AMSGrad or fall back to plain Adam
  • defaults is a dictionary of default values for parameter groups. This is useful when you want to extend the class Adam.
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
                 weight_decay: WeightDecay = WeightDecay(),
                 optimized_update: bool = True,
                 amsgrad=True, defaults=None):
        defaults = {} if defaults is None else defaults
        defaults.update(dict(amsgrad=amsgrad))

        super().__init__(params, lr, betas, eps, weight_decay, optimized_update, defaults)

Initialize a parameter state

  • state is the optimizer state of the parameter (tensor)
  • group stores optimizer attributes of the parameter group
  • param is the parameter tensor
    def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):

Call init_state of Adam optimizer which we are extending

        super().init_state(state, group, param)

If the amsgrad flag is True for this parameter group, we maintain the maximum of the exponential moving average of the squared gradient

        if group['amsgrad']:
            state['max_exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format)

Calculate $m_t$ and $v_t$ or $\max(v_1, v_2, ..., v_{t-1}, v_t)$

  • state is the optimizer state of the parameter (tensor)
  • group stores optimizer attributes of the parameter group
  • grad is the current gradient tensor for the parameter
    def get_mv(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor):

Get $m_t$ and $v_t$ from Adam

        m, v = super().get_mv(state, group, grad)

If this parameter group is using amsgrad

        if group['amsgrad']:

Get $\max(v_1, v_2, ..., v_{t-1})$.

🗒 The paper uses the notation $\hat{v}_t$ for this, which we don't use here because it is easily confused with Adam's use of the same notation for the bias-corrected exponential moving average.

            v_max = state['max_exp_avg_sq']

Calculate $\max(v_1, v_2, ..., v_{t-1}, v_t)$.

🤔 I feel you should be taking / maintaining the max of the bias-corrected exponential moving average of the squared gradient. But this is how it's implemented in PyTorch as well. I guess it doesn't really matter, since bias correction only increases the value and only makes an actual difference during the first few steps of training.
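To get a feel for how quickly the bias correction stops mattering, here is a small sketch in plain Python. It is not part of the implementation on this page. It prints the standard Adam correction factor $1 / (1 - \beta_2^t)$ for a few step counts. The helper name `bias_correction_factor` is made up for this illustration, and $\beta_2 = 0.999$ is simply the default from the constructor.

```python
def bias_correction_factor(beta2: float, t: int) -> float:
    """Factor that multiplies v_t in Adam's bias correction: 1 / (1 - beta2^t)."""
    return 1.0 / (1.0 - beta2 ** t)


# The factor is always greater than 1 and approaches 1 as t grows
for t in (1, 10, 100, 1_000, 10_000):
    print(t, bias_correction_factor(0.999, t))
```

The factor is large only while $\beta_2^t$ is still close to 1, which is the case in the first few thousand steps for $\beta_2 = 0.999$.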

            torch.maximum(v_max, v, out=v_max)

            return m, v_max
        else:

Fall back to Adam if the parameter group is not using amsgrad

            return m, v
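The effect of keeping the running maximum can be seen on scalars, without any optimizer machinery. The sketch below is plain Python and is not the implementation above. It applies the standard update $v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2$ to a made-up gradient sequence with one large spike, and tracks the AMSGrad-style running maximum next to it. The function name `second_moments` and the gradient values are assumptions chosen for illustration.

```python
from typing import List, Tuple


def second_moments(grads: List[float], beta2: float = 0.99) -> List[Tuple[float, float]]:
    """Return (v_t, max(v_1..v_t)) for every step of a scalar gradient sequence."""
    v = 0.0
    v_max = 0.0
    history = []
    for g in grads:
        # Exponential moving average of squared gradients, as in Adam
        v = beta2 * v + (1 - beta2) * g * g
        # AMSGrad keeps the largest value seen so far
        v_max = max(v_max, v)
        history.append((v, v_max))
    return history


# One large gradient followed by many small ones
grads = [100.0] + [1.0] * 200
history = second_moments(grads)
for t in (1, 50, 201):
    v, v_max = history[t - 1]
    print(f"t={t}: v={v:.4f} v_max={v_max:.4f}")
```

After the spike, the moving average decays back towards the scale of the small gradients, while the running maximum stays where the spike left it. Since the step size is divided by the square root of this quantity, the running maximum keeps the effective step size from growing again.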

Synthetic Experiment

This is the synthetic experiment described in the paper, which shows a scenario where Adam fails.

The paper (and Adam) formulates the problem of optimization as minimizing the expected value of a function, $\mathbb{E}[f(\theta)]$, with respect to the parameters $\theta$. In the stochastic training setting we do not get hold of the function $f$ itself; that is, when you are optimizing a NN, $f$ would be the function on the entire batch of data. What we actually evaluate is a mini-batch, so the actual function is a realization of the stochastic $f$. This is why we are talking about an expected value. So let the function realizations be $f_1, f_2, ..., f_T$ for each time step of training.

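We measure the performance of the optimizer as the regret,

$$R(T) = \sum_{t=1}^T \big[ f_t(\theta_t) - f_t(\theta^*) \big]$$

where $\theta_t$ is the parameters at time step $t$, and $\theta^*$ is the optimal parameters that minimize $\mathbb{E}[f(\theta)]$.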

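Now let's define the synthetic problem,

$$f_t(x) = \begin{cases} 1010 x, & \text{for } t \bmod 101 = 1 \\ -10 x, & \text{otherwise} \end{cases}$$

where $-1 \le x \le +1$. The optimal solution is $x = -1$.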
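A quick check of why $x = -1$ is optimal: any 101 consecutive steps contain exactly one step with $f_t(x) = 1010x$ and 100 steps with $f_t(x) = -10x$, so the sum over a period is $10x$, which is smallest at the lower bound. The sketch below is plain Python, and the helper names `f` and `total_regret` are made up for this illustration. It verifies the arithmetic and computes the regret of a point that stays at $x = +1$.

```python
def f(t: int, x: float) -> float:
    """Synthetic function realization at time step t."""
    return 1010 * x if t % 101 == 1 else -10 * x


def total_regret(x: float, x_star: float, steps: int) -> float:
    """Regret of holding a fixed x for `steps` steps, relative to x_star."""
    return sum(f(t, x) - f(t, x_star) for t in range(1, steps + 1))


# Sum over one period of 101 steps is 10 * x
print(sum(f(t, 1) for t in range(1, 102)))   # 10
print(sum(f(t, -1) for t in range(1, 102)))  # -10

# Staying at x = +1 costs 20 per period relative to x* = -1
print(total_regret(1, -1, 101 * 5))  # 100
```

The average regret of such a point therefore stays near $20 / 101 \approx 0.198$ instead of going to zero.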

This code will try running Adam and AMSGrad on this problem.

def _synthetic_experiment(is_adam: bool):

Define parameter

    x = nn.Parameter(torch.tensor([.0]))

Optimal, $x^* = -1$

    x_star = nn.Parameter(torch.tensor([-1]), requires_grad=False)

    def func(t: int, x_: nn.Parameter):
        if t % 101 == 1:
            return (1010 * x_).sum()
        else:
            return (-10 * x_).sum()

Initialize the relevant optimizer

    if is_adam:
        optimizer = Adam([x], lr=1e-2, betas=(0.9, 0.99))
    else:
        optimizer = AMSGrad([x], lr=1e-2, betas=(0.9, 0.99))

    total_regret = 0

    from labml import monit, tracker, experiment

Create experiment to record results

    with experiment.record(name='synthetic', comment='Adam' if is_adam else 'AMSGrad'):

Run for $10^7$ steps

        for step in monit.loop(10_000_000):

            regret = func(step, x) - func(step, x_star)

            total_regret += regret.item()

Track results every 1,000 steps

            if (step + 1) % 1000 == 0:
                tracker.save(loss=regret, x=x, regret=total_regret / (step + 1))

Calculate gradients

            regret.backward()

Optimize

            optimizer.step()

Clear gradients

            optimizer.zero_grad()

Make sure $-1 \le x \le +1$

            x.data.clamp_(-1., +1.)


if __name__ == '__main__':

Run the synthetic experiment with Adam. You can see that Adam converges to $x = +1$

    _synthetic_experiment(True)

Run the synthetic experiment with AMSGrad. You can see that AMSGrad converges to the true optimum $x = -1$

    _synthetic_experiment(False)