Optimizers

Optimizer Implementations

This MNIST example uses these optimizers.

Generic Adaptive Optimizer Base class and Weight Decay

This file defines a common base class for Adam and its extensions. The base class helps us implement other optimizers with minimal code, because most of it can be reused.

We also define a special class for L2 weight decay, so that we don't have to implement it inside each of the optimizers, and can easily extend to other weight decays like L1 without changing the optimizers.

Here are some concepts on PyTorch optimizers:

Parameter groups

PyTorch optimizers group parameters into sets called groups. Each group can have its own hyper-parameters like learning rates.

In most common cases there will be only one group. This is when you initialize your optimizer with,

Optimizer(model.parameters())

You can define multiple parameter groups when initializing the optimizer:

Optimizer([{'params': model1.parameters()}, {'params': model2.parameters(), 'lr': 2}])

Here we pass a list of groups. Each group is a dictionary with its parameters under the key 'params'. You can specify any hyper-parameters as well. If a hyper-parameter is not defined for a group, it falls back to the optimizer-level defaults.

You can access (and even change) these groups, and their hyper-parameters with optimizer.param_groups . Most learning rate schedule implementations I've come across do access this and change 'lr'.
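For instance, here is a minimal sketch of how a learning rate schedule might shrink 'lr' by reading and writing optimizer.param_groups (the model and the 0.9 decay factor are just placeholders for illustration):

import torch
from torch import nn

model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Multiply the learning rate of every parameter group by 0.9
for group in optimizer.param_groups:
    group['lr'] *= 0.9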

States

The optimizer maintains a state (a dictionary) for each parameter (a tensor), in the dictionary optimizer.state. This is where the optimizer keeps things like exponential averages.
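For example, continuing the snippet above, after a single optimizer step each parameter's state holds its step count and exponential averages (the exact keys depend on the optimizer; this sketch just prints whatever is stored):

# Run one step so that Adam populates its per-parameter state
model(torch.randn(1, 4)).sum().backward()
optimizer.step()

# Each parameter tensor maps to its own state dictionary
for param, state in optimizer.state.items():
    print(param.shape, list(state.keys()))  # e.g. ['step', 'exp_avg', 'exp_avg_sq']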

from typing import Dict, Tuple, Any

import torch
from torch import nn
from torch.optim.optimizer import Optimizer

Base class for Adam and extensions

class GenericAdaptiveOptimizer(Optimizer):

Initialize

  • params is the collection of parameters, or a set of parameter groups.
  • defaults is a dictionary of default hyper-parameters.
  • lr is the learning rate, α.
  • betas is the tuple (β₁, β₂).
  • eps is ε.

    def __init__(self, params, defaults: Dict[str, Any], lr: float, betas: Tuple[float, float], eps: float):

Check the hyper-parameters

        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")

Add the hyper-parameters to the defaults

        defaults.update(dict(lr=lr, betas=betas, eps=eps))

Initialize the PyTorch optimizer. This will create parameter groups with the default hyper-parameters.

        super().__init__(params, defaults)

Initialize state for a given parameter tensor

This should be overridden with code that initializes state for the parameter param. group is the parameter group dictionary that param belongs to.

    def init_state(self, state: Dict[str, Any], group: Dict[str, Any], param: nn.Parameter):
        pass

Take an optimizer step on a parameter tensor

This should be overridden to take the optimization step on the param tensor, where grad is the gradient for that parameter, state is the optimizer state dictionary for that parameter, and group is the parameter group dictionary that param belongs to.

    def step_param(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor, param: torch.Tensor):
        pass

Optimizer step

We have created a template method that does the common things every Adam-based optimizer needs.

    @torch.no_grad()
    def step(self, closure=None):

Calculate the loss.

🤔 I'm not sure when you would need this. I guess it's for when you define a function that calculates the loss, calls loss.backward, and returns the loss; instead of calling it yourself, you can pass it to optimizer.step. 🤷‍♂️

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
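As a usage sketch (not part of this class, and with placeholder model and data), a closure passed to step could look like this:

# Usage sketch: passing a closure to `optimizer.step`
model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters())
x, y = torch.randn(8, 4), torch.randn(8, 2)

def closure():
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    return loss

loss = optimizer.step(closure)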

Iterate through the parameter groups

        for group in self.param_groups:

Iterate through the parameters in the parameter group

            for param in group['params']:

Skip if the parameter has no gradient

                if param.grad is None:
                    continue

Get the gradient tensor

                grad = param.grad.data

We do not handle sparse gradients

                if grad.is_sparse:
                    raise RuntimeError('GenericAdaptiveOptimizer does not support sparse gradients,'
                                       ' please consider SparseAdam instead')

Get the state for the parameter

                state = self.state[param]

Initialize the state if it has not been initialized yet

                if len(state) == 0:
                    self.init_state(state, group, param)

Take the optimization step on the parameter

                self.step_param(state, group, grad, param)

Return the loss calculated from the closure

        return loss
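To make the intended usage of the base class concrete, here is a hypothetical minimal Adam-style subclass. It is only a sketch of how init_state and step_param plug into the template step method; the actual optimizers in this package are implemented in their own files:

class MinimalAdam(GenericAdaptiveOptimizer):
    def __init__(self, params, lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.999),
                 eps: float = 1e-8, defaults=None):
        super().__init__(params, {} if defaults is None else defaults, lr, betas, eps)

    def init_state(self, state: Dict[str, Any], group: Dict[str, Any], param: nn.Parameter):
        # Step count and exponential moving averages of gradients and squared gradients
        state['step'] = 0
        state['exp_avg'] = torch.zeros_like(param.data)
        state['exp_avg_sq'] = torch.zeros_like(param.data)

    def step_param(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor, param: torch.Tensor):
        beta1, beta2 = group['betas']
        state['step'] += 1
        # m_t = β₁ · m_{t-1} + (1 - β₁) · g_t
        state['exp_avg'].mul_(beta1).add_(grad, alpha=1 - beta1)
        # v_t = β₂ · v_{t-1} + (1 - β₂) · g_t²
        state['exp_avg_sq'].mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        # Bias corrections for the moving averages
        bias_correction1 = 1 - beta1 ** state['step']
        bias_correction2 = 1 - beta2 ** state['step']
        # θ_t = θ_{t-1} - α / (1 - β₁^t) · m_t / (√(v_t / (1 - β₂^t)) + ε)
        denom = (state['exp_avg_sq'] / bias_correction2).sqrt().add_(group['eps'])
        param.data.addcdiv_(state['exp_avg'], denom, value=-group['lr'] / bias_correction1)

The base class's step handles iterating over groups and parameters, so a subclass only has to describe its per-parameter state and update rule.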

L2 Weight decay

class WeightDecay:

Initialize weight decay

  • weight_decay is the decay coefficient.
  • weight_decouple is a flag indicating whether to add the weight decay to the gradient, or to decay the parameter directly. If it is added to the gradient, it goes through the normal optimizer update.
  • absolute is a flag indicating whether the weight decay coefficient is absolute. This applies when the decay is performed directly on the parameter. If this is false, the actual decay is weight_decay * learning_rate.

    def __init__(self, weight_decay: float = 0., weight_decouple: bool = True, absolute: bool = False):

Check the hyper-parameters

        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        self.absolute = absolute
        self.weight_decouple = weight_decouple
        self.weight_decay = weight_decay

Return defaults for parameter groups

    def defaults(self):
        return dict(weight_decay=self.weight_decay)

Perform the weight decay and return the gradient

    def __call__(self, param: torch.nn.Parameter, grad: torch.Tensor, group: Dict[str, Any]):

If we are decaying the parameter directly

        if self.weight_decouple:

If the weight decay coefficient is absolute

            if self.absolute:
                param.data.mul_(1.0 - group['weight_decay'])

Otherwise,

            else:
                param.data.mul_(1.0 - group['lr'] * group['weight_decay'])

Return the gradient unmodified

            return grad
        else:
            if group['weight_decay'] != 0:

Add the weight decay to the gradient and return the modified gradient

                return grad.add(param.data, alpha=group['weight_decay'])
            else:
                return grad
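As a sketch of how this helper could be wired into an optimizer (the class below extends the hypothetical MinimalAdam sketch from earlier and is not one of the optimizers in this package): the decay coefficient is merged into the parameter group defaults through defaults(), and __call__ is applied to the gradient inside step_param.

class DecayedMinimalAdam(MinimalAdam):
    def __init__(self, params, lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.999),
                 eps: float = 1e-8, weight_decay: WeightDecay = WeightDecay(weight_decay=0.01)):
        # `weight_decay.defaults()` adds the decay coefficient to every parameter group
        super().__init__(params, lr, betas, eps, defaults=weight_decay.defaults())
        self.weight_decay = weight_decay

    def step_param(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor, param: torch.Tensor):
        # Decay the parameter in place (decoupled) or get a gradient with the decay added,
        # then run the usual Adam update on the returned gradient
        grad = self.weight_decay(param, grad, group)
        super().step_param(state, group, grad, param)

model = nn.Linear(4, 2)
optimizer = DecayedMinimalAdam(model.parameters())
model(torch.randn(1, 4)).sum().backward()
optimizer.step()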