This is the PyTorch implementation of the optimizer introduced in the paper Attention Is All You Need: Adam with the warmup learning-rate schedule commonly known as the Noam schedule.
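It keeps Adam's update rule but replaces the fixed learning rate with the schedule from the paper; get_lr below computes, at step number step,

$$\alpha \cdot d_{\text{model}}^{-0.5} \cdot \min\big(\text{step}^{-0.5},\ \text{step} \cdot \text{warmup}^{-1.5}\big)$$

where $\alpha$ is the constructor's lr argument (the paper's schedule corresponds to $\alpha = 1$), $d_{\text{model}}$ is the model size, and warmup is the number of warmup steps. The rate grows linearly over the first warmup steps, peaks at step = warmup, and then decays proportionally to $\text{step}^{-0.5}$.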
from typing import Dict

from labml_nn.optimizers import WeightDecay
from labml_nn.optimizers.amsgrad import AMSGrad
The constructor accepts the following parameters:

- params is the list of parameters
- lr is the learning rate $\alpha$
- betas is a tuple of ($\beta_1$, $\beta_2$)
- eps is $\hat{\epsilon}$ or $\epsilon$ based on optimized_update
- weight_decay is an instance of class WeightDecay defined in __init__.py
- optimized_update is a flag whether to optimize the bias correction of the second moment by doing it after adding $\epsilon$
- amsgrad is a flag indicating whether to use AMSGrad or fall back to plain Adam
- warmup is the number of warmup steps
- d_model is the model size, i.e. the number of dimensions in the transformer
- defaults is a dictionary of default values for parameter groups. This is useful when you want to extend the class Noam.
class Noam(AMSGrad):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
                 weight_decay: WeightDecay = WeightDecay(),
                 optimized_update: bool = True,
                 amsgrad=False,
                 warmup=0, d_model=512, defaults=None):
        # Store the number of warmup steps as a per-parameter-group default
        defaults = {} if defaults is None else defaults
        defaults.update(dict(warmup=warmup))
        super().__init__(params, lr, betas, eps, weight_decay, optimized_update, amsgrad, defaults)
        self.d_model = d_model
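As a usage sketch (the toy model and data here are hypothetical, and it assumes the class is importable as labml_nn.optimizers.noam.Noam): the optimizer follows the standard torch.optim interface inherited through AMSGrad, and the learning rate at each step is computed internally by get_lr below.

```python
import torch
from torch import nn

from labml_nn.optimizers.noam import Noam

# Hypothetical toy model and data, just to show the call pattern
model = nn.Linear(16, 1)
optimizer = Noam(model.parameters(), lr=1.0, d_model=512, warmup=4000)

x, y = torch.randn(8, 16), torch.randn(8, 1)
for _ in range(10):
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()  # applies the Adam update with the warmup-scaled learning rate
```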
    def get_lr(self, state: Dict[str, any], group: Dict[str, any]):
        # lr * d_model^(-0.5) * min(step^(-0.5), step * warmup^(-1.5))
        factor = min(state['step'] ** (-0.5), state['step'] * group['warmup'] ** (-1.5))
        return group['lr'] * self.d_model ** (-0.5) * factor
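A quick numeric sanity check of get_lr; this is a sketch assuming the class above, and the values in the comments are approximate, computed from the formula with lr=1:

```python
from torch import nn

opt = Noam(nn.Linear(2, 2).parameters(), lr=1., d_model=512, warmup=4000)
print(opt.get_lr({'step': 100}, opt.defaults))    # ~1.7e-5, still warming up
print(opt.get_lr({'step': 4000}, opt.defaults))   # ~7.0e-4, the peak at step == warmup
print(opt.get_lr({'step': 16000}, opt.defaults))  # ~3.5e-4, decaying as step^(-0.5)
```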
def _test_noam_lr():
    """Plot the learning rate for different warmup steps and model sizes."""
    import matplotlib.pyplot as plt
    import numpy as np
    from torch import nn

    model = nn.Linear(10, 10)
    opts = [Noam(model.parameters(), d_model=512, warmup=4000, lr=1),
            Noam(model.parameters(), d_model=512, warmup=8000, lr=1),
            Noam(model.parameters(), d_model=2048, warmup=2000, lr=1)]
    plt.plot(np.arange(1, 20000), [[opt.get_lr({'step': i}, opt.defaults) for opt in opts] for i in range(1, 20000)])
    plt.legend(["512:4000", "512:8000", "2048:2000"])
    plt.title("Learning Rate")
    plt.show()


if __name__ == '__main__':
    _test_noam_lr()