アダム・オプティマイザー

これは、論文「アダム:確率的最適化の方法」に掲載された人気のオプティマイザーAdamをPyTorchで実装したものです

アダムのアップデートは

ここでおよびはスカラーのハイパーパラメーターです。ファーストオーダー、セカンドオーダーの瞬間です 偏り修正されたモーメントです。ゼロエラーによる除算の修正として使われますが、勾配のばらつきに対して作用するハイパーパラメータの形式としても機能します

有効な手順は、「This が制限される」、「いつ」、「それ以外の場合」を前提としています。そして、最も一般的なシナリオでは、

40import math
41from typing import Dict, Any, Tuple, Optional
42
43import torch
44from labml import tracker
45from torch import nn
46
47from labml_nn.optimizers import GenericAdaptiveOptimizer, WeightDecay

アダム・オプティマイザー

GenericAdaptiveOptimizer __init__.py で定義したクラスを拡張して Adam オプティマイザを実装します。

50class Adam(GenericAdaptiveOptimizer):

オプティマイザを初期化

  • params はパラメータのリストです
  • lr は学習率
  • betas (,) のタプルです
  • eps またはそれに基づいている optimized_update
  • weight_decay WeightDecay で定義されているクラスのインスタンスです __init__.py
  • optimized_update セカンドモーメントのバイアス補正を加算してから行うことで最適化するか否かのフラグです
  • defaults グループ値のデフォルト辞書です。これは、クラスを拡張する場合に便利ですAdam
58    def __init__(self, params,
59                 lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-16,
60                 weight_decay: WeightDecay = WeightDecay(),
61                 optimized_update: bool = True,
62                 defaults: Optional[Dict[str, Any]] = None):
76        defaults = {} if defaults is None else defaults
77        defaults.update(weight_decay.defaults())
78        super().__init__(params, defaults, lr, betas, eps)
79
80        self.weight_decay = weight_decay
81        self.optimized_update = optimized_update

パラメータ状態を初期化

  • state はパラメーター (テンソル) のオプティマイザー状態です
  • group パラメータグループのオプティマイザ属性を格納します
  • param はパラメータテンソル
83    def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):

これは、パラメーターに対して実行されたオプティマイザーステップの数です。

93        state['step'] = 0

勾配の指数移動平均、

95        state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)

二乗勾配値の指数移動平均、

97        state['exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format)

計算と

  • state はパラメーター (テンソル) のオプティマイザー状態です
  • group パラメータグループのオプティマイザ属性を格納します
  • grad パラメータの現在の勾配テンソルです
99    def get_mv(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor):

取得して

109        beta1, beta2 = group['betas']

取得して

112        m, v = state['exp_avg'], state['exp_avg_sq']

のインプレース計算

116        m.mul_(beta1).add_(grad, alpha=1 - beta1)

のインプレース計算

119        v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
120
121        return m, v

学習率を取得

これにより、状態に基づいて修正された学習率が返されます。Adam の場合、これはパラメータグループに指定されている学習率にすぎません

123    def get_lr(self, state: Dict[str, any], group: Dict[str, any]):
131        return group['lr']

Adam パラメータを更新してください

  • state はパラメーター (テンソル) のオプティマイザー状態です
  • group パラメータグループのオプティマイザ属性を格納します
  • param はパラメータテンソル
  • m v そして未修正の第一瞬間と第二瞬間と

これにより、以下が計算されます

はスカラーで、その他はテンソルなので、この計算を変更して計算を最適化します。

ここで、ハイパーパラメータとして指定する必要があります。

133    def adam_update(self, state: Dict[str, any], group: Dict[str, any], param: torch.nn.Parameter,
134                    m: torch.Tensor, v: torch.Tensor):

取得して

166        beta1, beta2 = group['betas']

のバイアス補正用語

168        bias_correction1 = 1 - beta1 ** state['step']

のバイアス補正用語

170        bias_correction2 = 1 - beta2 ** state['step']

学習率を取得

173        lr = self.get_lr(state, group)

計算を最適化するかどうか

176        if self.optimized_update:

178            denominator = v.sqrt().add_(group['eps'])

180            step_size = lr * math.sqrt(bias_correction2) / bias_correction1

183            param.data.addcdiv_(m, denominator, value=-step_size)

最適化なしの計算

185        else:

187            denominator = (v.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])

189            step_size = lr / bias_correction1

192            param.data.addcdiv_(m, denominator, value=-step_size)

与えられたパラメータテンソルの更新ステップを実行する

  • state はパラメーター (テンソル) のオプティマイザー状態です
  • group パラメータグループのオプティマイザ属性を格納します
  • grad パラメータの現在の勾配テンソルです
  • param はパラメータテンソル
194    def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):

体重減少の計算

205        grad = self.weight_decay(param, grad, group)

取得して

208        m, v = self.get_mv(state, group, grad)

オプティマイザーのステップ数を増やす

211        state['step'] += 1

Adam アップデートを実行

214        self.adam_update(state, group, param, m, v)