マスグラード

これは、論文「アダムの収束と彼方」をPyTorchで実装したものです

これを Adam オプティマイザー実装の拡張として実装します。Adamと非常に似ているので、実装自体は非常に小さいです。

また、論文で説明したAdamが収束しない合成例の実装もあります。

18from typing import Dict
19
20import torch
21from torch import nn
22
23from labml_nn.optimizers import WeightDecay
24from labml_nn.optimizers.adam import Adam

マスグラードオプティマイザー

このクラスは、で定義されている Adam オプティマイザを拡張したものです。adam.py Adam GenericAdaptiveOptimizer オプティマイザはで定義されているクラスを拡張しています

__init__.py
27class AMSGrad(Adam):

オプティマイザを初期化

  • params はパラメータのリストです
  • lr は学習率
  • betas (,) のタプルです
  • eps またはそれに基づいている optimized_update
  • weight_decay WeightDecay で定義されているクラスのインスタンスです __init__.py
  • 'optimized_update'は追加後に行うことでセカンドモーメントのバイアス補正を最適化するかどうかのフラグです
  • amsgrad amsGradを使用するか、プレーンなAdamにフォールバックするかを示すフラグです
  • defaults グループ値のデフォルト辞書です。これは、クラスを拡張する場合に便利ですAdam
35    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
36                 weight_decay: WeightDecay = WeightDecay(),
37                 optimized_update: bool = True,
38                 amsgrad=True, defaults=None):
53        defaults = {} if defaults is None else defaults
54        defaults.update(dict(amsgrad=amsgrad))
55
56        super().__init__(params, lr, betas, eps, weight_decay, optimized_update, defaults)

パラメータ状態を初期化

  • state はパラメーター (テンソル) のオプティマイザー状態です
  • group パラメータグループのオプティマイザ属性を格納します
  • param はパラメータテンソル
58    def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):

init_state 拡張中のCall of Adamオプティマイザー

68        super().init_state(state, group, param)

amsgrad このパラメータグループにフラグを指定すると、二乗勾配の指数移動平均の最大値が維持されます。True

72        if group['amsgrad']:
73            state['max_exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format)

計算およびまたは

  • state はパラメーター (テンソル) のオプティマイザー状態です
  • group パラメータグループのオプティマイザ属性を格納します
  • grad パラメータの現在の勾配テンソルです
75    def get_mv(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor):

アダムから入手して

85        m, v = super().get_mv(state, group, grad)

このパラメータグループが使用している場合 amsgrad

88        if group['amsgrad']:

取得

🗒 この論文ではこの表記法を使用していますが、ここでは使用しません。これは、アダムがバイアス補正された指数移動平均について同じ表記法を使用することと混同するためです。

94            v_max = state['max_exp_avg_sq']

計算

🤔 二乗勾配のバイアス補正後の第二指数平均の最大値をとる/維持すべきだと思います。しかし、PyTorchでもこのように実装されています。バイアス補正は値を増やすだけで、トレーニングの初期の数ステップで実際に違いが出るだけなので、それほど重要ではないと思います。

103            torch.maximum(v_max, v, out=v_max)
104
105            return m, v_max
106        else:

パラメータグループが使用していない場合は Adam にフォールバックします。amsgrad

108            return m, v

合成実験

これは論文で説明されている合成実験で、アダムが失敗するシナリオを示しています

論文(とアダム)は、最適化の問題を、パラメーターに関する関数の期待値を最小化することとして定式化しています。 確率的トレーニングの設定では、関数自体を把握することはできません。つまり、最適化すると、NN はデータのバッチ全体に対する関数になります。実際に評価するのはミニバッチなので、実際の関数は確率論の実現です。これが期待値について話している理由です。そこで、機能の実現をトレーニングの各タイムステップで行うとしましょう

オプティマイザの性能を後悔として測定します。ここで、はタイムステップでのパラメータは最小化する最適なパラメータです。

それでは、総合的な問題を定義しましょう。

どこ。最適な解決策はです

このコードでは、この問題に対して Adam と AmsGrad を実行してみます。

111def _synthetic_experiment(is_adam: bool):

パラメータを定義

153    x = nn.Parameter(torch.tensor([.0]))

最適、

155    x_star = nn.Parameter(torch.tensor([-1]), requires_grad=False)

157    def func(t: int, x_: nn.Parameter):
161        if t % 101 == 1:
162            return (1010 * x_).sum()
163        else:
164            return (-10 * x_).sum()

関連するオプティマイザーを初期化します

167    if is_adam:
168        optimizer = Adam([x], lr=1e-2, betas=(0.9, 0.99))
169    else:
170        optimizer = AMSGrad([x], lr=1e-2, betas=(0.9, 0.99))

172    total_regret = 0
173
174    from labml import monit, tracker, experiment

テストを作成して結果を記録する

177    with experiment.record(name='synthetic', comment='Adam' if is_adam else 'AMSGrad'):

ランニング・フォー・ステップス

179        for step in monit.loop(10_000_000):

181            regret = func(step, x) - func(step, x_star)

183            total_regret += regret.item()

1,000 ステップごとに結果をトラッキング

185            if (step + 1) % 1000 == 0:
186                tracker.save(loss=regret, x=x, regret=total_regret / (step + 1))

勾配の計算

188            regret.backward()

最適化

190            optimizer.step()

クリアグラデーション

192            optimizer.zero_grad()

確認して

195            x.data.clamp_(-1., +1.)
196
197
198if __name__ == '__main__':

合成実験を実行するのはアダムです。アダムが次の場所に収束しているのがわかります

201    _synthetic_experiment(True)

合成実験を実行するとamsGradです。amsGradが真最適に収束することがわかります

204    _synthetic_experiment(False)