12from typing import Dict
13
14from labml_nn.optimizers import WeightDecay
15from labml_nn.optimizers.amsgrad import AMSGrad
18class AdamWarmup(AMSGrad):
params
はパラメータのリストですlr
は学習率 betas
(,) のタプルです eps
またはそれに基づいている optimized_update
weight_decay
WeightDecay
で定義されているクラスのインスタンスです __init__.py
amsgrad
amsGradを使用するか、プレーンなAdamにフォールバックするかを示すフラグですwarmup
ウォームアップステップ数defaults
グループ値のデフォルト辞書です。これは、クラスを拡張する場合に便利ですAdamWarmup
。24 def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
25 weight_decay: WeightDecay = WeightDecay(),
26 optimized_update: bool = True,
27 amsgrad=False, warmup=0, defaults=None):
44 defaults = {} if defaults is None else defaults
45 defaults.update(dict(warmup=warmup))
46 super().__init__(params, lr, betas, eps, weight_decay, optimized_update, amsgrad, defaults)
48 def get_lr(self, state: Dict[str, any], group: Dict[str, any]):
ウォームアップ段階の場合
56 if group['warmup'] > state['step']:
学習率が 1 から 1 に直線的に増加している
58 return 1e-8 + state['step'] * group['lr'] / group['warmup']
59 else:
一定の学習率
61 return group['lr']