This is a PyTorch implementation of the popular Adam optimizer from the paper Adam: A Method for Stochastic Optimization.
The Adam update is,

$$\begin{align}
m_t &\leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t \\
v_t &\leftarrow \beta_2 v_{t-1} + (1 - \beta_2) \cdot g_t^2 \\
\hat{m}_t &\leftarrow \frac{m_t}{1 - \beta_1^t} \\
\hat{v}_t &\leftarrow \frac{v_t}{1 - \beta_2^t} \\
\theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
\end{align}$$

where $\alpha$, $\beta_1$, $\beta_2$ and $\epsilon$ are scalar hyper-parameters. $m_t$ and $v_t$ are the first and second order moments, and $\hat{m}_t$ and $\hat{v}_t$ are the bias-corrected moments. $\epsilon$ is used as a fix for division by zero, but it also acts as a form of hyper-parameter that acts against variance in gradients.

The effective step taken assuming $\epsilon = 0$ is
$$\Delta t = \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t}}$$
This is bounded by
$$\vert \Delta t \vert \le \alpha \cdot \frac{1 - \beta_1}{\sqrt{1 - \beta_2}}$$
when $1 - \beta_1 \gt \sqrt{1 - \beta_2}$, and by
$$\vert \Delta t \vert \le \alpha$$
otherwise. In most common scenarios,
$$\vert \Delta t \vert \approx \alpha$$
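To make the update rule concrete, here is a minimal standalone sketch of a single Adam step on a tensor, in plain PyTorch and independent of the class defined below (the function name and values are illustrative only):

import torch

def adam_step(param, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    # m_t <- beta1 * m_{t-1} + (1 - beta1) * g_t
    m = beta1 * m + (1 - beta1) * grad
    # v_t <- beta2 * v_{t-1} + (1 - beta2) * g_t^2
    v = beta2 * v + (1 - beta2) * grad ** 2
    # Bias-corrected moments
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    # theta_t <- theta_{t-1} - lr * m_hat / (sqrt(v_hat) + eps)
    param = param - lr * m_hat / (v_hat.sqrt() + eps)
    return param, m, v

param = torch.tensor([1.0, -2.0])
m, v = torch.zeros_like(param), torch.zeros_like(param)
param, m, v = adam_step(param, torch.tensor([0.5, -0.1]), m, v, t=1)
# On the first step m_hat / sqrt(v_hat) = g_1 / |g_1|, so each coordinate moves
# by roughly lr in magnitude, illustrating the |Δt| ≈ α regime above.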
import math
from typing import Dict, Any, Tuple, Optional

import torch
from labml import tracker
from torch import nn

from labml_nn.optimizers import GenericAdaptiveOptimizer, WeightDecay
We extend the class GenericAdaptiveOptimizer, defined in __init__.py, to implement the Adam optimizer.
class Adam(GenericAdaptiveOptimizer):
* params is the list of parameters
* lr is the learning rate $\alpha$
* betas is a tuple of ($\beta_1$, $\beta_2$)
* eps is $\hat{\epsilon}$ or $\epsilon$ based on optimized_update
* weight_decay is an instance of class WeightDecay defined in __init__.py
* optimized_update is a flag indicating whether to optimize the bias correction of the second moment by doing it after adding $\epsilon$
* defaults is a dictionary of default values for parameter groups. This is useful when you want to extend the class Adam.
    def __init__(self, params,
                 lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-16,
                 weight_decay: WeightDecay = WeightDecay(),
                 optimized_update: bool = True,
                 defaults: Optional[Dict[str, Any]] = None):
        defaults = {} if defaults is None else defaults
        defaults.update(weight_decay.defaults())
        super().__init__(params, defaults, lr, betas, eps)

        self.weight_decay = weight_decay
        self.optimized_update = optimized_update
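As a rough usage sketch (the module path labml_nn.optimizers.adam is assumed, and the hyper-parameter values are illustrative only), the optimizer is constructed like any torch.optim optimizer:

import torch.nn as nn
from labml_nn.optimizers import WeightDecay
from labml_nn.optimizers.adam import Adam  # module path assumed

model = nn.Linear(4, 2)
optimizer = Adam(model.parameters(),
                 lr=2.5e-4,
                 betas=(0.9, 0.999),
                 eps=1e-16,  # interpreted as eps-hat because optimized_update=True
                 weight_decay=WeightDecay(),
                 optimized_update=True)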
* state is the optimizer state of the parameter (tensor)
* group stores optimizer attributes of the parameter group
* param is the parameter tensor $\theta_{t-1}$

    def init_state(self, state: Dict[str, Any], group: Dict[str, Any], param: nn.Parameter):
This is the number of optimizer steps taken on the parameter, $t$.

        state['step'] = 0

Exponential moving average of gradients, $m_t$.

        state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)

Exponential moving average of squared gradient values, $v_t$.

        state['exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format)
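memory_format=torch.preserve_format keeps the moment buffers in the same memory layout as the parameter, which avoids layout conversions for, e.g., channels-last parameters. A quick standalone check of that PyTorch behaviour:

import torch

p = torch.randn(2, 3, 4, 4).to(memory_format=torch.channels_last)
buf = torch.zeros_like(p, memory_format=torch.preserve_format)
# The zero buffer inherits the channels-last strides of the parameter
assert buf.is_contiguous(memory_format=torch.channels_last)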
* state is the optimizer state of the parameter (tensor)
* group stores optimizer attributes of the parameter group
* grad is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$

    def get_mv(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor):
Get $\beta_1$ and $\beta_2$

        beta1, beta2 = group['betas']

Get $m_{t-1}$ and $v_{t-1}$

        m, v = state['exp_avg'], state['exp_avg_sq']

In-place calculation of $m_t$, $m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t$

        m.mul_(beta1).add_(grad, alpha=1 - beta1)

In-place calculation of $v_t$, $v_t \leftarrow \beta_2 v_{t-1} + (1 - \beta_2) \cdot g_t^2$

        v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        return m, v
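A small standalone trace of these two in-place updates, showing how the exponential moving averages accumulate over two steps (values are illustrative):

import torch

beta1, beta2 = 0.9, 0.999
m, v = torch.zeros(2), torch.zeros(2)
for grad in (torch.tensor([1.0, -1.0]), torch.tensor([0.5, 0.5])):
    m.mul_(beta1).add_(grad, alpha=1 - beta1)            # m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
    v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)  # v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
# After two steps: m = 0.9 * 0.1 * g_1 + 0.1 * g_2 and v = 0.999 * 0.001 * g_1^2 + 0.001 * g_2^2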
This returns the modified learning rate based on the state. For Adam this is just the specified learning rate for the parameter group, $\alpha$.

    def get_lr(self, state: Dict[str, Any], group: Dict[str, Any]):
        return group['lr']
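Because the learning-rate logic is isolated in get_lr, schedules can be added by subclassing. This is a hypothetical sketch (not part of the library, and assuming the imports at the top of this file) of a linear warmup variant that carries its extra setting through the defaults dictionary described above:

class AdamWithWarmup(Adam):
    def __init__(self, params, warmup: int = 1_000, **kwargs):
        # `warmup` is a hypothetical extra group value, passed via `defaults`
        super().__init__(params, defaults=dict(warmup=warmup), **kwargs)

    def get_lr(self, state: Dict[str, Any], group: Dict[str, Any]):
        # Scale the learning rate linearly for the first `warmup` steps
        if group['warmup'] > state['step']:
            return group['lr'] * state['step'] / group['warmup']
        return group['lr']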
* state is the optimizer state of the parameter (tensor)
* group stores optimizer attributes of the parameter group
* param is the parameter tensor $\theta_{t-1}$
* m and v are the uncorrected first and second moments $m_t$ and $v_t$

This computes the following update,

$$\theta_t \leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$

Since $\alpha$, $\beta_1$, $\beta_2$ and $\epsilon$ are scalars and the others are tensors, we modify this calculation to optimize the computation,

$$\begin{align}
\theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} \\
&= \theta_{t-1} - \alpha \cdot \frac{m_t / (1 - \beta_1^t)}{\sqrt{v_t / (1 - \beta_2^t)} + \epsilon} \\
&= \theta_{t-1} - \alpha \cdot \frac{\sqrt{1 - \beta_2^t}}{1 - \beta_1^t} \cdot \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}}
\end{align}$$

where $\hat{\epsilon} = \sqrt{1 - \beta_2^t} \cdot \epsilon$ is what we should specify as the hyper-parameter.

    def adam_update(self, state: Dict[str, Any], group: Dict[str, Any], param: torch.nn.Parameter,
                    m: torch.Tensor, v: torch.Tensor):
Get $\beta_1$ and $\beta_2$

        beta1, beta2 = group['betas']

Bias correction term for $\hat{m}_t$, $1 - \beta_1^t$

        bias_correction1 = 1 - beta1 ** state['step']

Bias correction term for $\hat{v}_t$, $1 - \beta_2^t$

        bias_correction2 = 1 - beta2 ** state['step']

Get learning rate

        lr = self.get_lr(state, group)

Whether to optimize the computation

        if self.optimized_update:

Denominator $\sqrt{v_t} + \hat{\epsilon}$

            denominator = v.sqrt().add_(group['eps'])

Step size $\alpha \cdot \frac{\sqrt{1 - \beta_2^t}}{1 - \beta_1^t}$

            step_size = lr * math.sqrt(bias_correction2) / bias_correction1

Update the parameter, $\theta_t \leftarrow \theta_{t-1} - \alpha \cdot \frac{\sqrt{1 - \beta_2^t}}{1 - \beta_1^t} \cdot \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}}$

            param.data.addcdiv_(m, denominator, value=-step_size)

Computation without optimization

        else:

Denominator $\sqrt{v_t / (1 - \beta_2^t)} + \epsilon$

            denominator = (v.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])

Step size $\frac{\alpha}{1 - \beta_1^t}$

            step_size = lr / bias_correction1

Update the parameter, $\theta_t \leftarrow \theta_{t-1} - \frac{\alpha}{1 - \beta_1^t} \cdot \frac{m_t}{\sqrt{v_t / (1 - \beta_2^t)} + \epsilon}$

            param.data.addcdiv_(m, denominator, value=-step_size)
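To see that the two branches agree, here is a small standalone numeric check: when the unoptimized form uses $\epsilon = \hat{\epsilon} / \sqrt{1 - \beta_2^t}$, both compute the same update (the tensors and values below are arbitrary):

import math
import torch

beta1, beta2, t = 0.9, 0.999, 10
m, v = torch.tensor([0.3, -0.2]), torch.tensor([0.09, 0.04])
lr, eps_hat = 1e-3, 1e-16
bias_correction1, bias_correction2 = 1 - beta1 ** t, 1 - beta2 ** t

# Optimized form: eps_hat added to sqrt(v), sqrt(bias_correction2) folded into the step size
step_opt = lr * math.sqrt(bias_correction2) / bias_correction1 * m / (v.sqrt() + eps_hat)
# Plain form with eps = eps_hat / sqrt(1 - beta2^t)
eps = eps_hat / math.sqrt(bias_correction2)
step_plain = lr / bias_correction1 * m / (v.sqrt() / math.sqrt(bias_correction2) + eps)

assert torch.allclose(step_opt, step_plain)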
* state is the optimizer state of the parameter (tensor)
* group stores optimizer attributes of the parameter group
* grad is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
* param is the parameter tensor $\theta_{t-1}$

    def step_param(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor, param: torch.nn.Parameter):
Calculate weight decay

        grad = self.weight_decay(param, grad, group)

Get $m_t$ and $v_t$

        m, v = self.get_mv(state, group, grad)

Increment $t$, the number of optimizer steps

        state['step'] += 1

Perform the Adam update

        self.adam_update(state, group, param, m, v)
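Finally, a rough end-to-end sketch of using the optimizer in a training loop, assuming GenericAdaptiveOptimizer provides the standard torch.optim.Optimizer interface (step() dispatching to step_param for every parameter with a gradient); the module path and data are illustrative:

import torch
import torch.nn as nn
from labml_nn.optimizers.adam import Adam  # module path assumed

model = nn.Linear(8, 1)
optimizer = Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

x, y = torch.randn(32, 8), torch.randn(32, 1)
for _ in range(5):
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()  # runs step_param on each parameter, updating it in place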