14from typing import Dict
15
16from labml_nn.optimizers import WeightDecay
17from labml_nn.optimizers.amsgrad import AMSGrad
params
はパラメータのリストですlr
は学習率 betas
(,) のタプルです eps
またはそれに基づいている optimized_update
weight_decay
WeightDecay
で定義されているクラスのインスタンスです __init__.py
amsgrad
amsGradを使用するか、プレーンなAdamにフォールバックするかを示すフラグですwarmup
ウォームアップステップ数d_model
モデルサイズ、つまり変圧器の次元数defaults
グループ値のデフォルト辞書です。これは、クラスを拡張する場合に便利ですAdamWarmup
。27 def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
28 weight_decay: WeightDecay = WeightDecay(),
29 optimized_update: bool = True,
30 amsgrad=False,
31 warmup=0, d_model=512, defaults=None):
49 defaults = {} if defaults is None else defaults
50 defaults.update(dict(warmup=warmup))
51 super().__init__(params, lr, betas, eps, weight_decay, optimized_update, amsgrad, defaults)
52 self.d_model = d_model
54 def get_lr(self, state: Dict[str, any], group: Dict[str, any]):
62 factor = min(state['step'] ** (-0.5), state['step'] * group['warmup'] ** (-1.5))
64 return group['lr'] * self.d_model ** (-0.5) * factor
67def _test_noam_lr():
73 import matplotlib.pyplot as plt
74 import numpy as np
75 from torch import nn
76
77 model = nn.Linear(10, 10)
78 opts = [Noam(model.parameters(), d_model=512, warmup=4000, lr=1),
79 Noam(model.parameters(), d_model=512, warmup=8000, lr=1),
80 Noam(model.parameters(), d_model=2048, warmup=2000, lr=1)]
81 plt.plot(np.arange(1, 20000), [[opt.get_lr({'step': i}, opt.defaults) for opt in opts] for i in range(1, 20000)])
82 plt.legend(["512:4000", "512:8000", "2048:2000"])
83 plt.title("Learning Rate")
84 plt.show()
85
86
87if __name__ == '__main__':
88 _test_noam_lr()