この MNIST の例では、これらのオプティマイザーを使用しています。
このファイルは、Adam の共通基本クラスとその拡張を定義しています。基本クラスは、再利用が可能なため、最小限のコードで他のオプティマイザを実装するのに役立ちます
。また、L2の重み減衰用の特別なクラスを定義しているので、各オプティマイザー内に実装する必要がなく、オプティマイザーを変更せずにL1のような他の重み減衰にも簡単に拡張できます。
PyTorch オプティマイザの概念は次のとおりです。
PyTorch オプティマイザーは、パラメーターをグループと呼ばれるセットにグループ化します。各グループには、学習率などの独自のハイパーパラメータを設定できます
。たいていの場合、グループが 1 つしかありません。このとき、オプティマイザを次のように初期化します
。Optimizer(model.parameters())
オプティマイザを初期化するときに、複数のパラメータグループを定義できます。
Optimizer([{'params': model1.parameters()}, {'params': model2.parameters(), 'lr': 2}])
ここにグループのリストを渡します。各グループは辞書で、パラメータは 'params' です。任意のハイパーパラメータも指定します。ハイパーパラメータが定義されていない場合は、デフォルトでオプティマイザレベルのデフォルトになります
。を使用してこれらのグループとそのハイパーパラメータにアクセスしたり、変更したりすることができます。optimizer.param_groups
私が出会ったほとんどの学習率スケジュールの実装は、これにアクセスして「lr」を変更します
オプティマイザーは、各パラメーター (テンソル) の状態 (辞書) を辞書に保持します。optimizer.state
ここで、オプティマイザーは指数平均などを管理します
62from typing import Dict, Tuple, Any
63
64import torch
65from torch import nn
66from torch.optim.optimizer import Optimizer
69class GenericAdaptiveOptimizer(Optimizer):
params
パラメータのコレクションまたはパラメータグループのセットです。defaults
デフォルトのハイパーパラメータの辞書lr
は学習率 betas
はタプルです eps
は 74 def __init__(self, params, defaults: Dict[str, Any], lr: float, betas: Tuple[float, float], eps: float):
ハイパーパラメータを確認
86 if not 0.0 <= lr:
87 raise ValueError(f"Invalid learning rate: {lr}")
88 if not 0.0 <= eps:
89 raise ValueError(f"Invalid epsilon value: {eps}")
90 if not 0.0 <= betas[0] < 1.0:
91 raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
92 if not 0.0 <= betas[1] < 1.0:
93 raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
ハイパーパラメータをデフォルトに追加
96 defaults.update(dict(lr=lr, betas=betas, eps=eps))
PyTorch オプティマイザーを初期化します。これにより、デフォルトのハイパーパラメータを使用してパラメータグループが作成されます
99 super().__init__(params, defaults)
state
これをオーバーライドしてパラメータを初期化するコードを使うべきです。param
group
param
が属するパラメータグループディクショナリです。
101 def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):
108 pass
これをオーバーライドして、param
テンソルで最適化ステップを実行する必要があります。ここでgrad
、はそのパラメーターの勾配、はそのパラメーターのオプティマイザー状態ディクショナリ、state
group
はディクショナリが属するパラメーターグループです。 param
110 def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.Tensor):
119 pass
121 @torch.no_grad()
122 def step(self, closure=None):
損失を計算します。
🤔 いつこれが必要なのかわかりません。自分で呼び出すのではなく、loss.backward
損失を計算して損失を出して返す関数を定義すれば、その関数を渡せると思いますoptimizer.step
。🤷♂️
133 loss = None
134 if closure is not None:
135 with torch.enable_grad():
136 loss = closure()
パラメータグループを繰り返し処理する
139 for group in self.param_groups:
パラメータグループ内のパラメータを繰り返し処理します
141 for param in group['params']:
パラメータにグラデーションがない場合はスキップ
143 if param.grad is None:
144 continue
勾配テンソルを取得
146 grad = param.grad.data
スパースグラデーションは扱いません
148 if grad.is_sparse:
149 raise RuntimeError('GenericAdaptiveOptimizer does not support sparse gradients,'
150 ' please consider SparseAdam instead')
パラメータの状態を取得
153 state = self.state[param]
状態が初期化されていない場合は状態を初期化します
156 if len(state) == 0:
157 self.init_state(state, group, param)
パラメータの最適化手順を実行してください
160 self.step_param(state, group, grad, param)
決済から計算した損失額を返金
163 return loss
166class WeightDecay:
weight_decay
は減衰係数weight_decouple
グラデーションにウェイトディケイを追加するか、パラメータから直接ディケイを加えるかを示すフラグです。グラデーションに追加すると、通常のオプティマイザーの更新が行われますabsolute
このフラグは重量減衰係数が絶対値かどうかを示します。これは、ディケイをパラメータに直接適用する場合に適用できます。これが false の場合、実際の減衰は weight_decay
learning_rate
。171 def __init__(self, weight_decay: float = 0., weight_decouple: bool = True, absolute: bool = False):
ハイパーパラメータをチェック
184 if not 0.0 <= weight_decay:
185 raise ValueError(f"Invalid weight_decay value: {weight_decay}")
186
187 self.absolute = absolute
188 self.weight_decouple = weight_decouple
189 self.weight_decay = weight_decay
パラメータグループのデフォルト値を返す
191 def defaults(self):
195 return dict(weight_decay=self.weight_decay)
197 def __call__(self, param: torch.nn.Parameter, grad: torch.Tensor, group: Dict[str, any]):
パラメータで直接ディケイを行う場合
203 if self.weight_decouple:
重量減衰係数が絶対値の場合
205 if self.absolute:
206 param.data.mul_(1.0 - group['weight_decay'])
それ以外の場合は、
208 else:
209 param.data.mul_(1.0 - group['lr'] * group['weight_decay'])
変更されていないグラデーションを返す
211 return grad
212 else:
213 if group['weight_decay'] != 0:
グラデーションにウェイトディケイを追加し、変更したグラデーションを返します。
215 return grad.add(param.data, alpha=group['weight_decay'])
216 else:
217 return grad