ウォームアップ機能付き Adam オプティマイザー

これにより AMSgrad オプティマイザが拡張され、ウォームアップステージが追加されます。

12from typing import Dict
13
14from labml_nn.optimizers import WeightDecay
15from labml_nn.optimizers.amsgrad import AMSGrad

ウォームアップ機能付き Adam オプティマイザー

このクラスは、で定義されている AMSGrad オプティマイザを拡張したものです。amsgrad.py

18class AdamWarmup(AMSGrad):

オプティマイザを初期化

  • params はパラメータのリストです
  • lr は学習率
  • betas (,) のタプルです
  • eps またはそれに基づいている optimized_update
  • weight_decay WeightDecay で定義されているクラスのインスタンスです __init__.py
  • 'optimized_update'は追加後に行うことでセカンドモーメントのバイアス補正を最適化するかどうかのフラグです
  • amsgrad amsGradを使用するか、プレーンなAdamにフォールバックするかを示すフラグです
  • warmup ウォームアップステップ数
  • defaults グループ値のデフォルト辞書です。これは、クラスを拡張する場合に便利ですAdamWarmup
24    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
25                 weight_decay: WeightDecay = WeightDecay(),
26                 optimized_update: bool = True,
27                 amsgrad=False, warmup=0, defaults=None):
44        defaults = {} if defaults is None else defaults
45        defaults.update(dict(warmup=warmup))
46        super().__init__(params, lr, betas, eps, weight_decay, optimized_update, amsgrad, defaults)

学習率を取得

はウォームアップステップの数です。

48    def get_lr(self, state: Dict[str, any], group: Dict[str, any]):

ウォームアップ段階の場合

56        if group['warmup'] > state['step']:

学習率が 1 から 1 に直線的に増加している

58            return 1e-8 + state['step'] * group['lr'] / group['warmup']
59        else:

一定の学習率

61            return group['lr']