这是论文《亚当:随机优化方法》中流行的优化器 Adam 的 Py Torch 实现。
亚当的更新是,
其中、和是标量超级参数。和是一阶和二阶时刻。并且是有偏差的校正时刻。用作除以零误差的修复,但也用作对梯度方差起作用的超参数的一种形式。
假设采取的有效步骤是,这受限于、何时以及其他方面。在大多数常见情况下,
40import math
41from typing import Dict, Any, Tuple, Optional
42
43import torch
44from labml import tracker
45from torch import nn
46
47from labml_nn.optimizers import GenericAdaptiveOptimizer, WeightDecay
50class Adam(GenericAdaptiveOptimizer):
params
是参数列表lr
是学习率betas
是 (,) 的元组eps
是或基于optimized_update
weight_decay
是在中WeightDecay
定义的类的实例 __init__.py
optimized_update
是一个标志,是否在添加后通过这样做来优化第二个时刻的偏差校正defaults
是组值的默认字典。当你想扩展类时,这很有用Adam
。58 def __init__(self, params,
59 lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-16,
60 weight_decay: WeightDecay = WeightDecay(),
61 optimized_update: bool = True,
62 defaults: Optional[Dict[str, Any]] = None):
76 defaults = {} if defaults is None else defaults
77 defaults.update(weight_decay.defaults())
78 super().__init__(params, defaults, lr, betas, eps)
79
80 self.weight_decay = weight_decay
81 self.optimized_update = optimized_update
83 def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):
这是优化器对参数采取的步骤数,
93 state['step'] = 0
梯度的指数移动平均线,
95 state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)
梯度平方值的指数移动平均线,
97 state['exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format)
99 def get_mv(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor):
获取和
109 beta1, beta2 = group['betas']
获取和
112 m, v = state['exp_avg'], state['exp_avg_sq']
就地计算
116 m.mul_(beta1).add_(grad, alpha=1 - beta1)
就地计算
119 v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
120
121 return m, v
123 def get_lr(self, state: Dict[str, any], group: Dict[str, any]):
131 return group['lr']
state
是参数(张量)的优化器状态group
存储参数组的优化程序属性param
是参数张量m
并且v
是未校正的第一和第二时刻,以及.这计算出以下内容
由于、和是标量,其他是张量,因此我们将此计算修改为优化计算。
wher e 是我们应该指定为超参数的内容。
133 def adam_update(self, state: Dict[str, any], group: Dict[str, any], param: torch.nn.Parameter,
134 m: torch.Tensor, v: torch.Tensor):
获取和
166 beta1, beta2 = group['betas']
偏差校正术语,
168 bias_correction1 = 1 - beta1 ** state['step']
偏差校正术语,
170 bias_correction2 = 1 - beta2 ** state['step']
获取学习率
173 lr = self.get_lr(state, group)
是否优化计算
176 if self.optimized_update:
178 denominator = v.sqrt().add_(group['eps'])
180 step_size = lr * math.sqrt(bias_correction2) / bias_correction1
183 param.data.addcdiv_(m, denominator, value=-step_size)
无需优化的计算
185 else:
187 denominator = (v.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
189 step_size = lr / bias_correction1
192 param.data.addcdiv_(m, denominator, value=-step_size)
194 def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):
计算体重衰减
205 grad = self.weight_decay(param, grad, group)
获取和
208 m, v = self.get_mv(state, group, grad)
增加优化器步数
211 state['step'] += 1
执行 Adam 更新
214 self.adam_update(state, group, param, m, v)