This MNIST example uses these optimizers.
This file defines a common base class for Adam and its extensions. The base class helps us implement other optimizers with minimal code, because of re-usability.
We also define a special class for L2 weight decay, so that we don't have to implement it inside each of the optimizers, and can easily extend it to other weight decays like L1 without changing the optimizers.
Here are some concepts about PyTorch optimizers:
PyTorch optimizers group parameters into sets called groups. Each group can have its own hyper-parameters like learning rates.
In the most common case there will be only one group. This is when you initialize your optimizer with:
Optimizer(model.parameters())
You can define multiple parameter groups when initializing the optimizer:
Optimizer([{'params': model1.parameters()}, {'params': model2.parameters(), 'lr': 2}])
Here we pass a list of groups. Each group is a dictionary with its parameters under the key 'params'. You can specify any hyper-parameters per group as well. If a hyper-parameter is not defined for a group, it falls back to the optimizer-level default.
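For instance, with a stock PyTorch optimizer you can check that a group without its own 'lr' falls back to the optimizer-level default (a minimal sketch using torch.optim.SGD):

import torch
from torch import nn

model1, model2 = nn.Linear(10, 10), nn.Linear(10, 1)
optimizer = torch.optim.SGD(
    [{'params': model1.parameters()},
     {'params': model2.parameters(), 'lr': 2.}],
    lr=0.1)

# The first group uses the optimizer-level default, the second keeps its own learning rate
print(optimizer.param_groups[0]['lr'])  # 0.1
print(optimizer.param_groups[1]['lr'])  # 2.0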
You can access (and even change) these groups and their hyper-parameters with optimizer.param_groups. Most learning rate schedule implementations I've come across access this and change 'lr'.
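As an illustration, a hand-rolled learning rate schedule could simply scale each group's 'lr' between steps (a minimal sketch; the decay factor is arbitrary):

import torch
from torch import nn

model = nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Decay every group's learning rate by 10%, as a scheduler might do after each epoch
for group in optimizer.param_groups:
    group['lr'] *= 0.9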
The optimizer maintains a state (a dictionary) for each parameter (a tensor), in the dictionary optimizer.state. This is where the optimizer keeps things like exponential averages.
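For example, after one step with PyTorch's built-in Adam you can inspect this per-parameter state; a minimal sketch (the exact keys depend on the optimizer and PyTorch version):

import torch
from torch import nn

model = nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters())

loss = model(torch.randn(4, 10)).sum()
loss.backward()
optimizer.step()

# Each parameter tensor keys its own state dictionary,
# which holds the exponential averages Adam maintains
for param in model.parameters():
    print(optimizer.state[param].keys())  # e.g. dict_keys(['step', 'exp_avg', 'exp_avg_sq'])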
from typing import Dict, Tuple, Any

import torch
from torch import nn
from torch.optim.optimizer import Optimizer
Base class of Adam and its extensions

class GenericAdaptiveOptimizer(Optimizer):
    def __init__(self, params, defaults: Dict[str, Any], lr: float, betas: Tuple[float, float], eps: float):

Check the hyper-parameters

        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")

Add the hyper-parameters to the defaults

        defaults.update(dict(lr=lr, betas=betas, eps=eps))

Initialize the PyTorch optimizer. This creates parameter groups with the default hyper-parameters.

        super().__init__(params, defaults)
Initialize the state for a given parameter tensor. This is meant to be overridden by the derived optimizer class.

    def init_state(self, state: Dict[str, Any], group: Dict[str, Any], param: nn.Parameter):
        pass

Take the optimization step on a parameter tensor, given its gradient, parameter group and state. This is also meant to be overridden by the derived optimizer class.

    def step_param(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor, param: torch.Tensor):
        pass
step is a template method that does the common work every Adam-like optimizer needs; the actual update for each parameter is delegated to step_param.

    @torch.no_grad()
    def step(self, closure=None):

Calculate the loss from the closure, if one was given

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

Iterate through the parameter groups

        for group in self.param_groups:

Iterate through the parameters in the group

            for param in group['params']:

Skip if the parameter has no gradient

                if param.grad is None:
                    continue

Get the gradient tensor

                grad = param.grad.data

We don't handle sparse gradients

                if grad.is_sparse:
                    raise RuntimeError('GenericAdaptiveOptimizer does not support sparse gradients,'
                                       ' please consider SparseAdam instead')

Get the state for the parameter

                state = self.state[param]

Initialize the state if it hasn't been initialized yet

                if len(state) == 0:
                    self.init_state(state, group, param)

Take the optimization step on the parameter

                self.step_param(state, group, grad, param)

Return the loss computed from the closure

        return loss
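To make the template concrete, here is a minimal sketch of how a derived optimizer could fill in the two hooks. The SimpleAdam class and the exact update written here are illustrative assumptions, not the optimizers implemented in this repository:

class SimpleAdam(GenericAdaptiveOptimizer):
    def __init__(self, params, lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-8):
        super().__init__(params, {}, lr, betas, eps)

    def init_state(self, state: Dict[str, Any], group: Dict[str, Any], param: nn.Parameter):
        # Number of steps taken, and exponential moving averages of the gradient and squared gradient
        state['step'] = 0
        state['exp_avg'] = torch.zeros_like(param.data)
        state['exp_avg_sq'] = torch.zeros_like(param.data)

    def step_param(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor, param: torch.Tensor):
        beta1, beta2 = group['betas']
        state['step'] += 1
        # Update the biased first and second moment estimates
        state['exp_avg'].mul_(beta1).add_(grad, alpha=1 - beta1)
        state['exp_avg_sq'].mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        # Bias corrections
        bias_correction1 = 1 - beta1 ** state['step']
        bias_correction2 = 1 - beta2 ** state['step']
        # Adam update: param <- param - lr / bc1 * exp_avg / (sqrt(exp_avg_sq / bc2) + eps)
        denom = (state['exp_avg_sq'] / bias_correction2).sqrt().add_(group['eps'])
        param.data.addcdiv_(state['exp_avg'], denom, value=-group['lr'] / bias_correction1)

It then behaves like any other PyTorch optimizer:

model = nn.Linear(10, 1)
optimizer = SimpleAdam(model.parameters(), lr=1e-3)
loss = model(torch.randn(4, 10)).sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()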
L2 weight decay

class WeightDecay:

Initialize weight decay.
weight_decay is the decay coefficient.
weight_decouple is a flag indicating whether to add the weight decay to the gradient or to decay the parameter directly. If it is added to the gradient, it goes through the normal optimizer update.
absolute is a flag indicating whether the weight decay coefficient is an absolute value. This is applicable when the decay is performed directly on the parameter. If this is false, the actual decay is weight_decay * learning_rate.

    def __init__(self, weight_decay: float = 0., weight_decouple: bool = True, absolute: bool = False):
Check the hyper-parameters

        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        self.absolute = absolute
        self.weight_decouple = weight_decouple
        self.weight_decay = weight_decay
Return the defaults for parameter groups

    def defaults(self):
        return dict(weight_decay=self.weight_decay)
Perform the weight decay and return the gradient to use for the update

    def __call__(self, param: torch.nn.Parameter, grad: torch.Tensor, group: Dict[str, Any]):

If we are decaying the parameter directly

        if self.weight_decouple:

If the weight decay coefficient is absolute

            if self.absolute:
                param.data.mul_(1.0 - group['weight_decay'])

Otherwise,

            else:
                param.data.mul_(1.0 - group['lr'] * group['weight_decay'])

Return the gradient unmodified

            return grad
        else:
            if group['weight_decay'] != 0:

Add the weight decay to the gradient and return the modified gradient

                return grad.add(param.data, alpha=group['weight_decay'])
            else:
                return grad
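Finally, a sketch of how the two classes are meant to fit together: a derived optimizer merges the weight decay defaults into its parameter groups and calls the WeightDecay instance on each parameter before its update. The DecayedSGD class and its plain gradient descent update are illustrative assumptions, not optimizers from this repository:

class DecayedSGD(GenericAdaptiveOptimizer):
    def __init__(self, params, lr: float = 0.1, weight_decay: WeightDecay = WeightDecay()):
        # Merge the weight decay defaults so every parameter group gets a 'weight_decay' entry
        defaults = {}
        defaults.update(weight_decay.defaults())
        super().__init__(params, defaults, lr, (0.9, 0.999), 1e-8)
        self.weight_decay = weight_decay

    def init_state(self, state: Dict[str, Any], group: Dict[str, Any], param: nn.Parameter):
        pass

    def step_param(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor, param: torch.Tensor):
        # Decay the parameter directly, or get a gradient with the decay added to it
        grad = self.weight_decay(param, grad, group)
        # A plain gradient descent update, standing in for an Adam-style update
        param.data.add_(grad, alpha=-group['lr'])

model = nn.Linear(10, 1)
optimizer = DecayedSGD(model.parameters(), lr=0.1, weight_decay=WeightDecay(weight_decay=1e-2))
loss = model(torch.randn(4, 10)).sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()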