これは、「Adabeliefオプティマイザー:観測された勾配を信じてステップサイズを調整する」という論文のAdableLief公式実装に基づいています。
これは RadAM の拡張機能として PyTorch に実装されています。
Adam オプティマイザーと Adabelief の主な違いは、適応型学習率の計算方法にあります。Adabelief では、勾配の 2 乗の指数移動平均で割るのではなく、指数関数的分散平均で除算されます。
🤔 論文では分散を次のように計算していますが、バイアス補正されたモメンタムを使用すべきだと思います。バイアス補正は最初のトレーニングステップの後に行われるので、これはあまり影響しないと思います。
36from typing import Dict, Any
37
38import torch
39from torch import nn
40
41from labml_nn.optimizers import WeightDecay
42from labml_nn.optimizers.radam import RAdam
params
はパラメータのリストですlr
は学習率 betas
(,) のタプルです eps
またはそれに基づいている optimized_update
weight_decay
WeightDecay
で定義されているクラスのインスタンスです __init__.py
optimized_update
セカンドモーメントのバイアス補正を加算してから行うことで最適化するか否かのフラグです amsgrad
amsGradを使用するか、プレーンなAdamにフォールバックするかを示すフラグですdegenerate_to_sgd
修正項が扱いにくい場合に sgd を使うかどうか rectify
RadAMアップデートを使用するかどうかですdefaults
グループ値のデフォルト辞書です。これは、クラスを拡張する場合に便利ですAdaBelief
。52 def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
53 weight_decay: WeightDecay = WeightDecay(), amsgrad=False,
54 degenerate_to_sgd=True,
55 rectify=True, defaults=None):
73 defaults = {} if defaults is None else defaults
74 super().__init__(params, lr, betas, eps, weight_decay, amsgrad, degenerate_to_sgd, defaults)
75 self.rectify = rectify
77 def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):
85 state['step'] = 0
勾配値の指数移動平均
87 state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)
指数移動平均偏差
89 state['exp_avg_var'] = torch.zeros_like(param, memory_format=torch.preserve_format)
amsgrad
True
このパラメータグループにフラグを指定すると、指数移動平均の最大分散値が維持されます。
93 if group['amsgrad']:
すべての許容偏差移動平均値の最大値を維持
95 state['max_exp_avg_var'] = torch.zeros_like(param, memory_format=torch.preserve_format)
97 def get_ms(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor):
取得して
107 beta1, beta2 = group['betas']
取得して
110 m, s = state['exp_avg'], state['exp_avg_var']
のインプレース計算
114 m.mul_(beta1).add_(grad, alpha=1 - beta1)
勾配と運動量の違い
116 grad_residual = grad - m
のインプレース計算
119 s.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1 - beta2)
このパラメータグループが使用している場合 amsgrad
122 if group['amsgrad']:
取得。
124 s_max = state['max_exp_avg_var']
計算。
126 torch.maximum(s_max, s, out=s_max)
127
128 return m, s_max
129 else:
それ以外は
131 return m, s
state
はパラメーター (テンソル) のオプティマイザー状態ですgroup
パラメータグループのオプティマイザ属性を格納しますgrad
パラメータの現在の勾配テンソルです param
はパラメータテンソル 133 def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):
体重減少の計算
144 grad = self.weight_decay(param, grad, group)
取得して
147 m, s = self.get_ms(state, group, grad)
オプティマイザーのステップ数を増やす
150 state['step'] += 1
151
152 if not self.rectify:
155 self.adam_update(state, group, param, m, s + group['eps'])
156 else: