アダブリリーフオプティマイザー

これは、「Adabeliefオプティマイザー:観測された勾配を信じてステップサイズを調整する」という論文のAdableLief公式実装に基づいています

これは RadAM の拡張機能として PyTorch に実装されています。

Adam オプティマイザーと Adabelief の主な違いは、適応型学習率の計算方法にあります。Adabelief では、勾配の 2 乗の指数移動平均で割るのではなく、指数関数的分散平均で除算されます。

🤔 論文では分散を次のように計算していますが、バイアス補正されたモメンタムを使用すべきだと思います。バイアス補正は最初のトレーニングステップの後に行われるので、これはあまり影響しないと思います。

36from typing import Dict, Any
37
38import torch
39from torch import nn
40
41from labml_nn.optimizers import WeightDecay
42from labml_nn.optimizers.radam import RAdam

アダブリリーフオプティマイザー

このクラスは、で定義されている RadAM オプティマイザを拡張したものです。radam.py

45class AdaBelief(RAdam):

オプティマイザを初期化

  • params はパラメータのリストです
  • lr は学習率
  • betas (,) のタプルです
  • eps またはそれに基づいている optimized_update
  • weight_decay WeightDecay で定義されているクラスのインスタンスです __init__.py
  • optimized_update セカンドモーメントのバイアス補正を加算してから行うことで最適化するか否かのフラグです
  • amsgrad amsGradを使用するか、プレーンなAdamにフォールバックするかを示すフラグです
  • degenerate_to_sgd 修正項が扱いにくい場合に sgd を使うかどうか
  • rectify RadAMアップデートを使用するかどうかです
  • defaults グループ値のデフォルト辞書です。これは、クラスを拡張する場合に便利ですAdaBelief
52    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
53                 weight_decay: WeightDecay = WeightDecay(), amsgrad=False,
54                 degenerate_to_sgd=True,
55                 rectify=True, defaults=None):
73        defaults = {} if defaults is None else defaults
74        super().__init__(params, lr, betas, eps, weight_decay, amsgrad, degenerate_to_sgd, defaults)
75        self.rectify = rectify

パラメータ状態を初期化

  • state はパラメーター (テンソル) のオプティマイザー状態です
  • group パラメータグループのオプティマイザ属性を格納します
  • param はパラメータテンソル
77    def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):
85        state['step'] = 0

勾配値の指数移動平均

87        state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)

指数移動平均偏差

89        state['exp_avg_var'] = torch.zeros_like(param, memory_format=torch.preserve_format)

amsgrad True このパラメータグループにフラグを指定すると、指数移動平均の最大分散値が維持されます。

93        if group['amsgrad']:

すべての許容偏差移動平均値の最大値を維持

95            state['max_exp_avg_var'] = torch.zeros_like(param, memory_format=torch.preserve_format)

計算または

  • state はパラメーター (テンソル) のオプティマイザー状態です
  • group パラメータグループのオプティマイザ属性を格納します
  • grad パラメータの現在の勾配テンソルです
97    def get_ms(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor):

取得して

107        beta1, beta2 = group['betas']

取得して

110        m, s = state['exp_avg'], state['exp_avg_var']

のインプレース計算

114        m.mul_(beta1).add_(grad, alpha=1 - beta1)

勾配と運動量の違い

116        grad_residual = grad - m

のインプレース計算

119        s.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1 - beta2)

このパラメータグループが使用している場合 amsgrad

122        if group['amsgrad']:

取得

124            s_max = state['max_exp_avg_var']

計算

126            torch.maximum(s_max, s, out=s_max)
127
128            return m, s_max
129        else:

それ以外は

131            return m, s

与えられたパラメータテンソルの更新ステップを実行する

  • state はパラメーター (テンソル) のオプティマイザー状態です
  • group パラメータグループのオプティマイザ属性を格納します
  • grad パラメータの現在の勾配テンソルです
  • param はパラメータテンソル
133    def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):

体重減少の計算

144        grad = self.weight_decay(param, grad, group)

取得して

147        m, s = self.get_ms(state, group, grad)

オプティマイザーのステップ数を増やす

150        state['step'] += 1
151
152        if not self.rectify:

の代わりにadam.py で定義されている Adam 更新を実行します

155            self.adam_update(state, group, param, m, s + group['eps'])
156        else:

で定義されている修正済みの Adam 更新をradam.py の代わりにで実行します。

159            self.r_adam_update(state, group, param, m, s + group['eps'])