これは、論文「アダム:確率的最適化の方法」に掲載された人気のオプティマイザーAdamをPyTorchで実装したものです。
アダムのアップデートは、
ここで、、およびはスカラーのハイパーパラメーターです。ファーストオーダー、セカンドオーダーの瞬間です 偏り修正されたモーメントです。ゼロエラーによる除算の修正として使われますが、勾配のばらつきに対して作用するハイパーパラメータの形式としても機能します
。有効な手順は、「This が制限される」、「いつ」、「それ以外の場合」を前提としています。そして、最も一般的なシナリオでは、
40import math
41from typing import Dict, Any, Tuple, Optional
42
43import torch
44from labml import tracker
45from torch import nn
46
47from labml_nn.optimizers import GenericAdaptiveOptimizer, WeightDecay
50class Adam(GenericAdaptiveOptimizer):
params
はパラメータのリストですlr
は学習率 betas
(,) のタプルです eps
またはそれに基づいている optimized_update
weight_decay
WeightDecay
で定義されているクラスのインスタンスです __init__.py
optimized_update
セカンドモーメントのバイアス補正を加算してから行うことで最適化するか否かのフラグです defaults
グループ値のデフォルト辞書です。これは、クラスを拡張する場合に便利ですAdam
。58 def __init__(self, params,
59 lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-16,
60 weight_decay: WeightDecay = WeightDecay(),
61 optimized_update: bool = True,
62 defaults: Optional[Dict[str, Any]] = None):
76 defaults = {} if defaults is None else defaults
77 defaults.update(weight_decay.defaults())
78 super().__init__(params, defaults, lr, betas, eps)
79
80 self.weight_decay = weight_decay
81 self.optimized_update = optimized_update
83 def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):
これは、パラメーターに対して実行されたオプティマイザーステップの数です。
93 state['step'] = 0
勾配の指数移動平均、
95 state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)
二乗勾配値の指数移動平均、
97 state['exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format)
99 def get_mv(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor):
取得して
109 beta1, beta2 = group['betas']
取得して
112 m, v = state['exp_avg'], state['exp_avg_sq']
のインプレース計算
116 m.mul_(beta1).add_(grad, alpha=1 - beta1)
のインプレース計算
119 v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
120
121 return m, v
123 def get_lr(self, state: Dict[str, any], group: Dict[str, any]):
131 return group['lr']
state
はパラメーター (テンソル) のオプティマイザー状態ですgroup
パラメータグループのオプティマイザ属性を格納しますparam
はパラメータテンソル m
v
そして未修正の第一瞬間と第二瞬間と これにより、以下が計算されます
、はスカラーで、その他はテンソルなので、この計算を変更して計算を最適化します。
ここで、ハイパーパラメータとして指定する必要があります。
133 def adam_update(self, state: Dict[str, any], group: Dict[str, any], param: torch.nn.Parameter,
134 m: torch.Tensor, v: torch.Tensor):
取得して
166 beta1, beta2 = group['betas']
のバイアス補正用語
168 bias_correction1 = 1 - beta1 ** state['step']
のバイアス補正用語
170 bias_correction2 = 1 - beta2 ** state['step']
学習率を取得
173 lr = self.get_lr(state, group)
計算を最適化するかどうか
176 if self.optimized_update:
178 denominator = v.sqrt().add_(group['eps'])
180 step_size = lr * math.sqrt(bias_correction2) / bias_correction1
183 param.data.addcdiv_(m, denominator, value=-step_size)
最適化なしの計算
185 else:
187 denominator = (v.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
189 step_size = lr / bias_correction1
192 param.data.addcdiv_(m, denominator, value=-step_size)
state
はパラメーター (テンソル) のオプティマイザー状態ですgroup
パラメータグループのオプティマイザ属性を格納しますgrad
パラメータの現在の勾配テンソルです param
はパラメータテンソル 194 def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):
体重減少の計算
205 grad = self.weight_decay(param, grad, group)
取得して
208 m, v = self.get_mv(state, group, grad)
オプティマイザーのステップ数を増やす
211 state['step'] += 1
Adam アップデートを実行
214 self.adam_update(state, group, param, m, v)