亚当优化器

这是论文《:随机优化方法》中流行的优化器 Adam 的 Py Torch 实现。

亚当的更新是,

其中是标量超级参数。是一阶和二阶时刻。并且是有偏差的校正时刻。用作除以零误差的修复,但也用作对梯度方差起作用的超参数的一种形式。

假设采取的有效步骤是,这受限于、何时以及其他方面。在大多数常见情况下,

40import math
41from typing import Dict, Any, Tuple, Optional
42
43import torch
44from labml import tracker
45from torch import nn
46
47from labml_nn.optimizers import GenericAdaptiveOptimizer, WeightDecay

亚当优化器

我们扩展了中GenericAdaptiveOptimizer 定义的类__init__.py 来实现 Adam 优化器。

50class Adam(GenericAdaptiveOptimizer):

初始化优化器

  • params 是参数列表
  • lr 是学习率
  • betas 是 (,) 的元组
  • eps基于optimized_update
  • weight_decay 是在中WeightDecay 定义的类的实例 __init__.py
  • optimized_update 是一个标志,是否在添加后通过这样做来优化第二个时刻的偏差校正
  • defaults 是组值的默认字典。当你想扩展类时,这很有用Adam
58    def __init__(self, params,
59                 lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-16,
60                 weight_decay: WeightDecay = WeightDecay(),
61                 optimized_update: bool = True,
62                 defaults: Optional[Dict[str, Any]] = None):
76        defaults = {} if defaults is None else defaults
77        defaults.update(weight_decay.defaults())
78        super().__init__(params, defaults, lr, betas, eps)
79
80        self.weight_decay = weight_decay
81        self.optimized_update = optimized_update

初始化参数状态

  • state 是参数(张量)的优化器状态
  • group 存储参数组的优化程序属性
  • param 是参数张量
83    def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):

这是优化器对参数采取的步骤数,

93        state['step'] = 0

梯度的指数移动平均线,

95        state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)

梯度平方值的指数移动平均线,

97        state['exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format)

计算和和

  • state 是参数(张量)的优化器状态
  • group 存储参数组的优化程序属性
  • grad 是参数的当前梯度张量
99    def get_mv(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor):

获取

109        beta1, beta2 = group['betas']

获取

112        m, v = state['exp_avg'], state['exp_avg_sq']

就地计算

116        m.mul_(beta1).add_(grad, alpha=1 - beta1)

就地计算

119        v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
120
121        return m, v

获取学习率

这将根据状态返回修改后的学习速率。对于 Adam 来说,这只是参数组的指定学习速率

123    def get_lr(self, state: Dict[str, any], group: Dict[str, any]):
131        return group['lr']

Adam 参数是否更新

  • state 是参数(张量)的优化器状态
  • group 存储参数组的优化程序属性
  • param 是参数张量
  • m 并且v 是未校正的第一和第二时刻,以及.

这计算出以下内容

是标量,其他是张量,因此我们将此计算修改为优化计算。

wher e 是我们应该指定为超参数的内容。

133    def adam_update(self, state: Dict[str, any], group: Dict[str, any], param: torch.nn.Parameter,
134                    m: torch.Tensor, v: torch.Tensor):

获取

166        beta1, beta2 = group['betas']

偏差校正术语

168        bias_correction1 = 1 - beta1 ** state['step']

偏差校正术语

170        bias_correction2 = 1 - beta2 ** state['step']

获取学习率

173        lr = self.get_lr(state, group)

是否优化计算

176        if self.optimized_update:

178            denominator = v.sqrt().add_(group['eps'])

180            step_size = lr * math.sqrt(bias_correction2) / bias_correction1

183            param.data.addcdiv_(m, denominator, value=-step_size)

无需优化的计算

185        else:

187            denominator = (v.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])

189            step_size = lr / bias_correction1

192            param.data.addcdiv_(m, denominator, value=-step_size)

对给定参数张量执行更新步骤

  • state 是参数(张量)的优化器状态
  • group 存储参数组的优化程序属性
  • grad 是参数的当前梯度张量
  • param 是参数张量
194    def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):

计算体重衰减

205        grad = self.weight_decay(param, grad, group)

获取

208        m, v = self.get_mv(state, group, grad)

增加优化器步数

211        state['step'] += 1

执行 Adam 更新

214        self.adam_update(state, group, param, m, v)