Optimizer Implementations

This MNIST example uses these optimizers.

Generic Adaptive Optimizer Base class and Weight Decay

This file defines a common base class for Adam and extensions of it. The base class lets us implement other optimizers with minimal code by reusing the common logic.

We also define a special class for L2 weight decay, so that we don't have to implement it inside each optimizer and can easily extend it to other weight decays, like L1, without changing the optimizers.
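Because an optimizer only interacts with the decay object through its `defaults()` method and its `__call__(param, grad, group)` call, an L1 variant would only need a new class with that same interface. The sketch below is hypothetical: `L1WeightDecay` is not part of this file, and it assumes the same interface as the `WeightDecay` class defined at the end of the file. It adds `weight_decay * sign(param)` to the gradient, which is a subgradient of the L1 penalty `weight_decay * |param|`.

```python
from typing import Any, Dict

import torch


class L1WeightDecay:
    """Hypothetical L1 counterpart to WeightDecay (same defaults/__call__ interface)."""

    def __init__(self, weight_decay: float = 0.):
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        self.weight_decay = weight_decay

    def defaults(self):
        # Merged into the optimizer's defaults, like WeightDecay.defaults()
        return dict(weight_decay=self.weight_decay)

    def __call__(self, param: torch.nn.Parameter, grad: torch.Tensor, group: Dict[str, Any]):
        if group['weight_decay'] != 0:
            # Subgradient of weight_decay * |param| is weight_decay * sign(param)
            return grad.add(param.data.sign(), alpha=group['weight_decay'])
        return grad


param = torch.nn.Parameter(torch.tensor([2.0, -3.0, 0.0]))
grad = torch.zeros(3)
out = L1WeightDecay(0.1)(param, grad, {'weight_decay': 0.1})
print(out)  # roughly [0.1, -0.1, 0.0]
```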

Here are some concepts on PyTorch optimizers:

Parameter groups

PyTorch optimizers group parameters into sets called groups. Each group can have its own hyper-parameters like learning rates.

In most common cases there will be only one group. This is when you initialize your optimizer with

Optimizer(model.parameters())

You can define multiple parameter groups when initializing the optimizer:

Optimizer([{'params': model1.parameters()}, {'params': model2.parameters(), 'lr': 2}])

Here we pass a list of groups. Each group is a dictionary with its parameters under the key 'params'. You can specify other hyper-parameters in a group as well; any hyper-parameters that a group does not define fall back to the optimizer-level defaults.
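A small sketch of this with PyTorch's built-in `torch.optim.SGD`, using two `nn.Linear` layers as stand-ins for `model1` and `model2`. The first group only lists its parameters, so it inherits the optimizer-level `lr` and `momentum`; the second group overrides `lr`.

```python
import torch
from torch import nn

# Stand-ins for model1 / model2 from the text above
model1 = nn.Linear(4, 2)
model2 = nn.Linear(2, 1)

# Optimizer-level defaults: lr=0.1, momentum=0.9
optimizer = torch.optim.SGD(
    [
        {'params': model1.parameters()},             # inherits lr=0.1
        {'params': model2.parameters(), 'lr': 2.0},  # overrides lr
    ],
    lr=0.1,
    momentum=0.9,
)

# Each group now carries a full set of hyper-parameters
for i, group in enumerate(optimizer.param_groups):
    print(i, group['lr'], group['momentum'], len(group['params']))
# 0 0.1 0.9 2
# 1 2.0 0.9 2
```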

You can access (and even change) these groups and their hyper-parameters through optimizer.param_groups. Most learning rate schedule implementations I've come across access this and change 'lr'.
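A minimal hand-rolled schedule in that style: halve the learning rate every 10 epochs by writing to `group['lr']` directly. The `set_lr` helper is hypothetical, written only for this illustration, and the training step itself is elided.

```python
import torch
from torch import nn

model = nn.Linear(3, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


def set_lr(optimizer: torch.optim.Optimizer, lr: float) -> None:
    """Hypothetical helper: overwrite 'lr' in every parameter group."""
    for group in optimizer.param_groups:
        group['lr'] = lr


base_lr = 1e-3
for epoch in range(30):
    # Step decay: lr = base_lr * 0.5 ** (epoch // 10)
    set_lr(optimizer, base_lr * (0.5 ** (epoch // 10)))
    # ... forward, backward and optimizer.step() would go here ...

print(optimizer.param_groups[0]['lr'])  # 0.00025
```

PyTorch's own `torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)` produces the same decay and also works by updating `'lr'` in `optimizer.param_groups`.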

States

The optimizer maintains a state (a dictionary) for each parameter (a tensor), stored in the dictionary optimizer.state. This is where the optimizer keeps things like exponential averages.
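To see what this looks like in practice, the sketch below inspects `optimizer.state` for PyTorch's built-in `torch.optim.Adam`. The state is created lazily on the first step, with one dictionary per parameter tensor. The key names shown are the ones recent PyTorch versions use for Adam without AMSGrad; treat them as an implementation detail of that optimizer rather than a stable API.

```python
import torch
from torch import nn

model = nn.Linear(3, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# No state exists until the first step
print(len(optimizer.state))  # 0

loss = model(torch.randn(8, 3)).pow(2).mean()
loss.backward()
optimizer.step()

# One state dictionary per parameter tensor (weight and bias)
print(len(optimizer.state))  # 2
state = optimizer.state[model.weight]
print(sorted(state.keys()))    # ['exp_avg', 'exp_avg_sq', 'step']
print(state['exp_avg'].shape)  # torch.Size([1, 3])
```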

from typing import Dict, Tuple, Any
import torch
from torch import nn
from torch.optim.optimizer import Optimizer

Base class for Adam and extensions

class GenericAdaptiveOptimizer(Optimizer):


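  • params is the collection of parameters or set of parameter groups.
  • defaults is a dictionary of default hyper-parameters.
  • lr is the learning rate.
  • betas is the tuple $(\beta_1, \beta_2)$ of decay rates for the first- and second-moment exponential averages.
  • eps is $\epsilon$, a small constant that keeps the update's denominator away from zero.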
    def __init__(self, params, defaults: Dict[str, Any], lr: float, betas: Tuple[float, float], eps: float):

Check the hyper-parameters

        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")

Add the hyper-parameters to the defaults

        defaults.update(dict(lr=lr, betas=betas, eps=eps))

Initialize the PyTorch optimizer. This creates the parameter groups, filling in any hyper-parameters a group does not specify from the defaults.

        super().__init__(params, defaults)

Initialize state for a given parameter tensor

This should be overridden with code to initialize the state for parameter param. group is the parameter group dictionary to which param belongs.

    def init_state(self, state: Dict[str, Any], group: Dict[str, Any], param: nn.Parameter):
        pass

Take optimizer step on a parameter tensor

This should be overridden to take the optimization step on the param tensor, where grad is the gradient for that parameter, state is the optimizer state dictionary for that parameter, and group is the parameter group dictionary param belongs to.

    def step_param(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor, param: torch.Tensor):
        pass

Optimizer step

We have created a template method that does the common work every Adam-based optimizer needs.

    @torch.no_grad()
    def step(self, closure=None):

Calculate loss.

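🤔 I'm not sure when you need this. I guess it's for when you define a function that calculates the loss, calls loss.backward and returns the loss: instead of calling it yourself, you can pass it to optimizer.step. PyTorch's documentation notes that some algorithms, such as LBFGS, need to re-evaluate the function multiple times per step, which is why they take a closure. 🤷‍♂️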

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
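A concrete case where a closure is required: PyTorch's `torch.optim.LBFGS` calls the closure several times within a single `step`. This sketch minimizes $(x - 3)^2$ from $x = 0$; the closure clears the gradient, recomputes the loss, back-propagates, and returns the loss, and `step` returns the loss from the first evaluation.

```python
import torch

# Minimize f(x) = (x - 3)^2 starting from x = 0
x = torch.zeros(1, requires_grad=True)
optimizer = torch.optim.LBFGS([x], lr=1.0, max_iter=20)


def closure():
    # May be called several times per optimizer.step()
    optimizer.zero_grad()
    loss = (x - 3.0).pow(2).sum()
    loss.backward()
    return loss


loss = optimizer.step(closure)  # loss at the starting point, f(0) = 9
print(x.item())                 # close to 3.0
```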

Iterate through the parameter groups

        for group in self.param_groups:

Iterate through the parameters in the parameter group

            for param in group['params']:

Skip if the parameter has no gradient

                if param.grad is None:
                    continue

Get the gradient tensor

                grad = param.grad.data

We don't handle sparse gradients

                if grad.is_sparse:
                    raise RuntimeError('GenericAdaptiveOptimizer does not support sparse gradients,'
                                       ' please consider SparseAdam instead')
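Sparse gradients typically come from embeddings created with `sparse=True`. The sketch below, which uses only built-in PyTorch classes, shows where such a gradient comes from and how `torch.optim.SparseAdam` handles it: only the embedding rows that were looked up (rows 1 and 4 here) are updated.

```python
import torch
from torch import nn

torch.manual_seed(0)
# An embedding with sparse=True produces sparse gradients
embedding = nn.Embedding(10, 3, sparse=True)
loss = embedding(torch.tensor([1, 4, 4])).sum()
loss.backward()
print(embedding.weight.grad.is_sparse)  # True

# SparseAdam is designed for this case and only touches rows with gradients
optimizer = torch.optim.SparseAdam(embedding.parameters(), lr=1e-2)
before = embedding.weight.detach().clone()
optimizer.step()
changed = (embedding.weight.detach() != before).any(dim=1)
print(changed.nonzero().flatten().tolist())  # [1, 4]
```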

Get the state for the parameter

                state = self.state[param]

Initialize the state if state is uninitialized

                if len(state) == 0:
                    self.init_state(state, group, param)

Take the optimization step on the parameter

                self.step_param(state, group, grad, param)

Return the loss, calculated from closure

        return loss
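To show how the two hooks fit the template, here is a sketch of a plain Adam built on it. The block repeats a condensed copy of the base class (hyper-parameter checks omitted) so it runs on its own; `SimpleAdam` is a hypothetical name, has no weight decay or AMSGrad, and is not the Adam implementation from this repository. As a sanity check it is compared against `torch.optim.Adam` from identical starting weights.

```python
from typing import Any, Dict, Tuple

import torch
from torch import nn
from torch.optim.optimizer import Optimizer


class GenericAdaptiveOptimizer(Optimizer):
    # Condensed copy of the template above; hyper-parameter checks omitted
    def __init__(self, params, defaults: Dict[str, Any], lr: float,
                 betas: Tuple[float, float], eps: float):
        defaults.update(dict(lr=lr, betas=betas, eps=eps))
        super().__init__(params, defaults)

    def init_state(self, state: Dict[str, Any], group: Dict[str, Any], param: nn.Parameter):
        pass

    def step_param(self, state: Dict[str, Any], group: Dict[str, Any],
                   grad: torch.Tensor, param: torch.Tensor):
        pass

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            for param in group['params']:
                if param.grad is None:
                    continue
                state = self.state[param]
                if len(state) == 0:  # lazy state initialization
                    self.init_state(state, group, param)
                self.step_param(state, group, param.grad, param)
        return loss


class SimpleAdam(GenericAdaptiveOptimizer):
    """Hypothetical minimal Adam on top of the template (no weight decay, no AMSGrad)."""

    def __init__(self, params, lr: float = 1e-3,
                 betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-8):
        super().__init__(params, {}, lr, betas, eps)

    def init_state(self, state: Dict[str, Any], group: Dict[str, Any], param: nn.Parameter):
        state['step'] = 0
        state['exp_avg'] = torch.zeros_like(param)     # first moment m_t
        state['exp_avg_sq'] = torch.zeros_like(param)  # second moment v_t

    def step_param(self, state: Dict[str, Any], group: Dict[str, Any],
                   grad: torch.Tensor, param: torch.Tensor):
        beta1, beta2 = group['betas']
        state['step'] += 1
        t = state['step']
        m, v = state['exp_avg'], state['exp_avg_sq']
        # Update the exponential moving averages in place
        m.mul_(beta1).add_(grad, alpha=1 - beta1)
        v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        # Bias correction, then param <- param - lr * m_hat / (sqrt(v_hat) + eps)
        m_hat = m / (1 - beta1 ** t)
        v_hat = v / (1 - beta2 ** t)
        param.addcdiv_(m_hat, v_hat.sqrt().add_(group['eps']), value=-group['lr'])


# Compare against torch.optim.Adam from identical starting weights
torch.manual_seed(0)
model_a, model_b = nn.Linear(3, 1), nn.Linear(3, 1)
model_b.load_state_dict(model_a.state_dict())
opt_a = SimpleAdam(model_a.parameters(), lr=1e-2)
opt_b = torch.optim.Adam(model_b.parameters(), lr=1e-2)
x = torch.randn(16, 3)
for _ in range(5):
    for model, opt in ((model_a, opt_a), (model_b, opt_b)):
        opt.zero_grad()
        model(x).pow(2).mean().backward()
        opt.step()
print(torch.allclose(model_a.weight, model_b.weight, atol=1e-6))
```

The subclass only describes what is specific to Adam (its state and its per-tensor update); looping over groups, skipping parameters without gradients and lazily creating state all stay in the template.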

L2 Weight decay

class WeightDecay:

Initialize weight decay

  • weight_decay is the decay coefficient.
  • weight_decouple is a flag indicating whether to decay the parameter directly or to add the weight decay to the gradient. If it is true, the parameter is decayed directly; otherwise the decay term is added to the gradient and goes through the normal optimizer update.
  • absolute is a flag indicating whether the weight decay coefficient is absolute. This only applies when the decay is performed directly on the parameter. If it is false, the actual decay is weight_decay × learning_rate.
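Writing $\eta$ for group['lr'], $\lambda$ for group['weight_decay'], $\theta$ for the parameter and $g$ for its gradient, the three cases handled by __call__ are:

```latex
\begin{aligned}
\text{decoupled, absolute:} \quad & \theta \leftarrow \theta\,(1 - \lambda) \\
\text{decoupled, relative:} \quad & \theta \leftarrow \theta\,(1 - \eta\lambda) \\
\text{coupled (L2 in the gradient):} \quad & g \leftarrow g + \lambda\,\theta
\end{aligned}
```

For example, with $\eta = 0.1$ and $\lambda = 0.01$, the absolute form multiplies the parameter by $0.99$ per step, while the relative form multiplies it by $0.999$.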
    def __init__(self, weight_decay: float = 0., weight_decouple: bool = True, absolute: bool = False):

Check hyper-parameters

        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        self.absolute = absolute
        self.weight_decouple = weight_decouple
        self.weight_decay = weight_decay

Return defaults for parameter groups

    def defaults(self):
        return dict(weight_decay=self.weight_decay)

Perform weight decay and return the gradient

    def __call__(self, param: torch.nn.Parameter, grad: torch.Tensor, group: Dict[str, Any]):

If we are doing the decay on the parameter directly

        if self.weight_decouple:

If the weight decay coefficient is absolute

            if self.absolute:
                param.data.mul_(1.0 - group['weight_decay'])


            else:
                param.data.mul_(1.0 - group['lr'] * group['weight_decay'])

Return the unmodified gradient

            return grad
        else:
            if group['weight_decay'] != 0:

Add the weight decay to the gradient and return the modified gradient

                return grad.add(param.data, alpha=group['weight_decay'])
            else:
                return grad
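A quick check of the three modes on a single parameter equal to 1.0 with gradient 0.5 and lr = 0.1. The block repeats a condensed copy of WeightDecay (the hyper-parameter check is omitted) so that it runs on its own; `demo` is a hypothetical helper that builds a parameter group from `defaults()` the way an optimizer would.

```python
from typing import Any, Dict

import torch


class WeightDecay:
    # Condensed copy of the class above; hyper-parameter check omitted
    def __init__(self, weight_decay: float = 0., weight_decouple: bool = True, absolute: bool = False):
        self.absolute = absolute
        self.weight_decouple = weight_decouple
        self.weight_decay = weight_decay

    def defaults(self):
        return dict(weight_decay=self.weight_decay)

    def __call__(self, param: torch.nn.Parameter, grad: torch.Tensor, group: Dict[str, Any]):
        if self.weight_decouple:
            if self.absolute:
                param.data.mul_(1.0 - group['weight_decay'])
            else:
                param.data.mul_(1.0 - group['lr'] * group['weight_decay'])
            return grad
        if group['weight_decay'] != 0:
            return grad.add(param.data, alpha=group['weight_decay'])
        return grad


def demo(decay: WeightDecay):
    """Hypothetical helper: apply the decay once, return (param, grad) as floats."""
    param = torch.nn.Parameter(torch.ones(1))
    grad = torch.full((1,), 0.5)
    group = dict(lr=0.1, **decay.defaults())  # as an optimizer's param group would hold
    new_grad = decay(param, grad, group)
    return param.item(), new_grad.item()


relative = demo(WeightDecay(0.01, weight_decouple=True, absolute=False))  # param about 0.999
absolute = demo(WeightDecay(0.01, weight_decouple=True, absolute=True))   # param about 0.99
coupled = demo(WeightDecay(0.01, weight_decouple=False))                  # grad about 0.51
print(relative, absolute, coupled)
```

In the decoupled modes the gradient comes back unchanged and the parameter shrinks; in the coupled mode the parameter is untouched and the decay rides along in the gradient. An optimizer built on the template could, for instance, call `grad = self.weight_decay(param, grad, group)` at the start of its step_param; that wiring is an assumption here, not something this file defines.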