オプティマイザー

オプティマイザーの実装

この MNIST の例では、これらのオプティマイザーを使用しています。

汎用アダプティブオプティマイザー基本クラスとウェイトディケイ

このファイルは、Adam の共通基本クラスとその拡張を定義しています。基本クラスは、再利用が可能なため、最小限のコードで他のオプティマイザを実装するのに役立ちます

また、L2の重み減衰用の特別なクラスを定義しているので、各オプティマイザー内に実装する必要がなく、オプティマイザーを変更せずにL1のような他の重み減衰にも簡単に拡張できます。

PyTorch オプティマイザの概念は次のとおりです。

パラメータグループ

PyTorch オプティマイザーは、パラメーターをグループと呼ばれるセットにグループ化します。各グループには、学習率などの独自のハイパーパラメータを設定できます

たいていの場合、グループが 1 つしかありません。このとき、オプティマイザを次のように初期化します

Optimizer(model.parameters())

オプティマイザを初期化するときに、複数のパラメータグループを定義できます。

Optimizer([{'params': model1.parameters()}, {'params': model2.parameters(), 'lr': 2}])

ここにグループのリストを渡します。各グループは辞書で、パラメータは 'params' です。任意のハイパーパラメータも指定します。ハイパーパラメータが定義されていない場合は、デフォルトでオプティマイザレベルのデフォルトになります

を使用してこれらのグループとそのハイパーパラメータにアクセスしたり、変更したりすることができます。optimizer.param_groups 私が出会ったほとんどの学習率スケジュールの実装は、これにアクセスして「lr」を変更します

オプティマイザーは、各パラメーター (テンソル) の状態 (辞書) を辞書に保持します。optimizer.state ここで、オプティマイザーは指数平均などを管理します

62from typing import Dict, Tuple, Any
63
64import torch
65from torch import nn
66from torch.optim.optimizer import Optimizer

Adam と拡張機能の基底クラス

69class GenericAdaptiveOptimizer(Optimizer):

[初期化]

  • params パラメータのコレクションまたはパラメータグループのセットです。
  • defaults デフォルトのハイパーパラメータの辞書
  • lr は学習率
  • betas はタプルです
  • eps
  • 74    def __init__(self, params, defaults: Dict[str, Any], lr: float, betas: Tuple[float, float], eps: float):

    ハイパーパラメータを確認

    86        if not 0.0 <= lr:
    87            raise ValueError(f"Invalid learning rate: {lr}")
    88        if not 0.0 <= eps:
    89            raise ValueError(f"Invalid epsilon value: {eps}")
    90        if not 0.0 <= betas[0] < 1.0:
    91            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
    92        if not 0.0 <= betas[1] < 1.0:
    93            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")

    ハイパーパラメータをデフォルトに追加

    96        defaults.update(dict(lr=lr, betas=betas, eps=eps))

    PyTorch オプティマイザーを初期化します。これにより、デフォルトのハイパーパラメータを使用してパラメータグループが作成されます

    99        super().__init__(params, defaults)

    与えられたパラメータテンソルの状態を初期化

    state これをオーバーライドしてパラメータを初期化するコードを使うべきです。param group param が属するパラメータグループディクショナリです。

    101    def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):
    108        pass

    パラメーターテンソルでオプティマイザーステップを実行する

    これをオーバーライドして、param テンソルで最適化ステップを実行する必要があります。ここでgrad 、はそのパラメーターの勾配、はそのパラメーターのオプティマイザー状態ディクショナリ、state group はディクショナリが属するパラメーターグループです。 param

    110    def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.Tensor):
    119        pass

    オプティマイザーステップ

    すべてのAdamベースのオプティマイザーが必要とする一般的な処理を行うテンプレートメソッドを作成しました

    121    @torch.no_grad()
    122    def step(self, closure=None):

    損失を計算します。

    🤔 いつこれが必要なのかわかりません。自分で呼び出すのではなく、loss.backward 損失を計算して損失を出して返す関数を定義すれば、その関数を渡せると思いますoptimizer.step 。🤷‍♂️

    133        loss = None
    134        if closure is not None:
    135            with torch.enable_grad():
    136                loss = closure()

    パラメータグループを繰り返し処理する

    139        for group in self.param_groups:

    パラメータグループ内のパラメータを繰り返し処理します

    141            for param in group['params']:

    パラメータにグラデーションがない場合はスキップ

    143                if param.grad is None:
    144                    continue

    勾配テンソルを取得

    146                grad = param.grad.data

    スパースグラデーションは扱いません

    148                if grad.is_sparse:
    149                    raise RuntimeError('GenericAdaptiveOptimizer does not support sparse gradients,'
    150                                       ' please consider SparseAdam instead')

    パラメータの状態を取得

    153                state = self.state[param]

    状態が初期化されていない場合は状態を初期化します

    156                if len(state) == 0:
    157                    self.init_state(state, group, param)

    パラメータの最適化手順を実行してください

    160                self.step_param(state, group, grad, param)

    決済から計算した損失額を返金

    163        return loss

    L2 ウェイト・ディケイ

    166class WeightDecay:

    体重減衰を初期化

    • weight_decay は減衰係数
    • weight_decouple グラデーションにウェイトディケイを追加するか、パラメータから直接ディケイを加えるかを示すフラグです。グラデーションに追加すると、通常のオプティマイザーの更新が行われます
    • absolute このフラグは重量減衰係数が絶対値かどうかを示します。これは、ディケイをパラメータに直接適用する場合に適用できます。これが false の場合、実際の減衰は weight_decay
  • learning_rate
  • 171    def __init__(self, weight_decay: float = 0., weight_decouple: bool = True, absolute: bool = False):

    ハイパーパラメータをチェック

    184        if not 0.0 <= weight_decay:
    185            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
    186
    187        self.absolute = absolute
    188        self.weight_decouple = weight_decouple
    189        self.weight_decay = weight_decay

    パラメータグループのデフォルト値を返す

    191    def defaults(self):
    195        return dict(weight_decay=self.weight_decay)

    ウェイトディケイを実行してグラデーションを戻す

    197    def __call__(self, param: torch.nn.Parameter, grad: torch.Tensor, group: Dict[str, any]):

    パラメータで直接ディケイを行う場合

    203        if self.weight_decouple:

    重量減衰係数が絶対値の場合

    205            if self.absolute:
    206                param.data.mul_(1.0 - group['weight_decay'])

    それ以外の場合は、

    208            else:
    209                param.data.mul_(1.0 - group['lr'] * group['weight_decay'])

    変更されていないグラデーションを返す

    211            return grad
    212        else:
    213            if group['weight_decay'] != 0:

    グラデーションにウェイトディケイを追加し、変更したグラデーションを返します。

    215                return grad.add(param.data, alpha=group['weight_decay'])
    216            else:
    217                return grad