Adam Optimizer 带热身

这扩展了 AmsGrad 优化器并增加了预热阶段。

12from typing import Dict
13
14from labml_nn.optimizers import WeightDecay
15from labml_nn.optimizers.amsgrad import AMSGrad

Adam Optimizer 带热身

这个类是从中定义的 AmsGrad 优化器扩展而来的amsgrad.py

18class AdamWarmup(AMSGrad):

初始化优化器

  • params 是参数列表
  • lr 是学习率
  • betas 是 (,) 的元组
  • eps基于optimized_update
  • weight_decay 是在中WeightDecay 定义的类的实例 __init__.py
  • “optimized_update” 是一个标志,在添加后是否要优化第二个时刻的偏差校正
  • amsgrad 是一个标志,指示是使用 AmsGrad 还是回退到普通的 Adam
  • warmup 预热步数
  • defaults 是组值的默认字典。当你想扩展类时,这很有用AdamWarmup
24    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
25                 weight_decay: WeightDecay = WeightDecay(),
26                 optimized_update: bool = True,
27                 amsgrad=False, warmup=0, defaults=None):
44        defaults = {} if defaults is None else defaults
45        defaults.update(dict(warmup=warmup))
46        super().__init__(params, lr, betas, eps, weight_decay, optimized_update, amsgrad, defaults)

获取学习率

其中是预热步骤的数量。

48    def get_lr(self, state: Dict[str, any], group: Dict[str, any]):

如果我们处于热身阶段

56        if group['warmup'] > state['step']:

学习率从线性增加

58            return 1e-8 + state['step'] * group['lr'] / group['warmup']
59        else:

持续的学习速率

61            return group['lr']