Sophia Optimizer

This is a PyTorch implementation of Sophia-G from the paper Sophia: A Scalable Stochastic Second-order Optimizer for Language Model Pre-training. The official implementation is available at Liuhong99/Sophia.

Sophia is more adaptive to heterogeneous curvatures than Adam, more resistant to non-convexity and to rapid changes of the Hessian than Newton’s method, and uses a low-cost pre-conditioner.

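Sophia keeps a diagonal Hessian estimate as an EMA across iterations. The diagonal Hessian $\hat{h}_t$ is calculated every $k$ steps.

$$h_t = \begin{cases} \beta_2 h_{t-k} + (1 - \beta_2) \hat{h}_t & \text{if } t \bmod k = 1 \\ h_{t-1} & \text{otherwise} \end{cases}$$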
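A minimal sketch of this schedule on scalars, assuming a fresh estimate $\hat{h}_t$ is folded in when $t \bmod k = 1$ and $h_t$ is otherwise carried over unchanged (so $h_{t-1} = h_{t-k}$ at refresh time). The helper name hessian_ema_schedule is hypothetical and not part of labml_nn.

```python
def hessian_ema_schedule(h_hats, k, beta2):
    """Return [h_1, ..., h_T] given per-step estimates h_hats[t - 1], starting from h_0 = 0."""
    h = 0.0
    history = []
    for t, h_hat in enumerate(h_hats, start=1):
        # (t - 1) % k == 0 is the same as t mod k = 1 for k > 1, and also handles k = 1
        if (t - 1) % k == 0:
            h = beta2 * h + (1 - beta2) * h_hat
        # Between refreshes h is simply carried over
        history.append(h)
    return history
```

For example, with $k = 3$, $\beta_2 = 0.5$ and a constant estimate of 4.0, $h$ jumps to 2.0 at $t = 1$, stays there for two steps, and moves to 3.0 at $t = 4$.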

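Sophia uses an EMA of the gradients, $m_t$, only considers the positive entries of the diagonal Hessian, and applies per-coordinate clipping to the update.

$$m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t$$

$$\theta_{t+1} = \theta_t - \eta \cdot \operatorname{clip}\left(\frac{m_t}{\max\{h_t, \epsilon\}}, \rho\right)$$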

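where $\epsilon$ is a very small value to prevent division by $0$, and $\operatorname{clip}(z, \rho) = \max\{\min\{z, \rho\}, -\rho\}$ is applied per coordinate.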
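To see what the max and the clipping do, here is a minimal sketch of a single update in the form $\theta - \eta \cdot \operatorname{clip}(m / \max\{h, \epsilon\}, \rho)$. The function sophia_update is a hypothetical illustration that operates on plain tensors, not the labml_nn API.

```python
import torch


def sophia_update(theta: torch.Tensor, m: torch.Tensor, h: torch.Tensor,
                  eta: float, rho: float, eps: float = 1e-12) -> torch.Tensor:
    """Return theta - eta * clip(m / max(h, eps), rho), computed per coordinate."""
    # max{h, eps}: zero or negative curvature entries are replaced by eps
    denom = torch.clamp(h, min=eps)
    # clip(z, rho) = max(min(z, rho), -rho)
    ratio = torch.clamp(m / denom, -rho, rho)
    return theta - eta * ratio


theta = torch.zeros(3)
m = torch.tensor([0.1, 0.1, -0.1])
h = torch.tensor([10.0, 1e-8, -5.0])  # large, tiny, and negative curvature
print(sophia_update(theta, m, h, eta=1.0, rho=0.5))
```

The first coordinate takes a Newton-like step $m / h = 0.01$. The tiny-curvature and negative-curvature coordinates would produce huge ratios, so both are capped at magnitude $\rho = 0.5$, still moving against the sign of $m_t$.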

Gauss-Newton-Bartlett (GNB) estimator

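$$\hat{L}(\theta) = \frac{1}{B} \sum_{b=1}^{B} \ell_{CE}\big(f(\theta, x_b), \hat{y}_b\big)$$

$$\hat{h}_t = B \cdot \nabla_\theta \hat{L}(\theta) \odot \nabla_\theta \hat{L}(\theta)$$

where $x_b$ are the inputs, $B$ is the batch size (number of inputs/tokens), $\ell_{CE}$ is the cross-entropy loss, and $\hat{y}_b$ are sampled from the logits $f(\theta, x_b)$.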
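A minimal sketch of the GNB estimate for a classifier whose model maps a `[B, features]` batch to `[B, C]` logits. The function name gnb_hessian_estimate is hypothetical. For a language model one would presumably flatten the logits to `[B·T, V]` so that $B$ counts tokens; that flattening is an assumption and is not shown here.

```python
import torch
import torch.nn.functional as F
from torch import nn


def gnb_hessian_estimate(model: nn.Module, x: torch.Tensor):
    """Return B * grad(L_hat) ⊙ grad(L_hat) for each trainable parameter tensor."""
    logits = model(x)            # f(theta, x_b), shape [B, C]
    batch_size = logits.shape[0]  # B
    # Sample y_hat_b from the model's own predictive distribution (no gradient through sampling)
    y_hat = torch.distributions.Categorical(logits=logits.detach()).sample()
    # L_hat(theta): mean cross-entropy against the sampled labels
    loss = F.cross_entropy(logits, y_hat)
    params = [p for p in model.parameters() if p.requires_grad]
    grads = torch.autograd.grad(loss, params)
    # Element-wise square scaled by B, so every entry is non-negative
    return [batch_size * g * g for g in grads]
```

Because each entry is a scaled square, the estimate can never be negative, which is what allows the simplification in the next paragraph.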

Note that this Hessian estimate is always non-negative, so we can replace $\max\{h_t, \epsilon\}$ with $h_t + \epsilon$.

Sophia with the Gauss-Newton-Bartlett (GNB) estimator is Sophia-G.

Here is an experiment that uses Sophia-G to train a transformer.
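The experiment itself is not reproduced here. As a rough, self-contained sketch of how the pieces fit together, the hypothetical loop below follows the update rules above for a classifier: every step it updates $m_t$ from the true-label loss; every $k$ steps it refreshes $h_t$ with the GNB estimate on the same batch; then it applies the clipped update with $\eta = \text{lr} / \rho$. The function train_sophia_g and its arguments are illustrative only, not the labml_nn experiment or API.

```python
import torch
import torch.nn.functional as F
from torch import nn


def train_sophia_g(model: nn.Module, batches, lr: float = 1e-4, rho: float = 0.03,
                   betas=(0.9, 0.95), eps: float = 1e-12, k: int = 10):
    """Hypothetical Sophia-G loop on (x, y) classification batches; returns the losses."""
    beta1, beta2 = betas
    params = [p for p in model.parameters() if p.requires_grad]
    m = [torch.zeros_like(p) for p in params]  # EMA of gradients, m_t
    h = [torch.zeros_like(p) for p in params]  # EMA of the Hessian diagonal, h_t
    eta = lr / rho  # so the largest per-coordinate step is eta * rho = lr
    losses = []
    for t, (x, y) in enumerate(batches, start=1):
        # g_t: gradient of the ordinary loss on the true labels
        loss = F.cross_entropy(model(x), y)
        grads = torch.autograd.grad(loss, params)
        losses.append(loss.item())
        with torch.no_grad():
            for g, m_i in zip(grads, m):
                m_i.mul_(beta1).add_(g, alpha=1 - beta1)
        # Refresh h_t when t mod k = 1 (written as (t - 1) % k == 0 so k = 1 also works)
        if (t - 1) % k == 0:
            logits = model(x)
            # GNB: labels are sampled from the model's own predictions
            y_hat = torch.distributions.Categorical(logits=logits.detach()).sample()
            g_hat = torch.autograd.grad(F.cross_entropy(logits, y_hat), params)
            with torch.no_grad():
                for gh, h_i in zip(g_hat, h):
                    # h_t = beta2 * h_{t-k} + (1 - beta2) * B * g_hat ⊙ g_hat
                    h_i.mul_(beta2).addcmul_(gh, gh, value=(1 - beta2) * x.shape[0])
        # theta <- theta - eta * clip(m_t / (h_t + eps), rho)
        with torch.no_grad():
            for p, m_i, h_i in zip(params, m, h):
                ratio = (m_i / (h_i + eps)).clamp(-rho, rho)
                p.add_(ratio, alpha=-eta)
    return losses
```

Refreshing $h_t$ before the parameter update on the first step means the very first update already uses a curvature estimate instead of $h = 0$.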

from typing import Dict, Any, Tuple, Optional

import torch
from torch import nn

from labml_nn.optimizers import GenericAdaptiveOptimizer, WeightDecay

Sophia-G Optimizer

We extend the class GenericAdaptiveOptimizer defined in __init__.py to implement the Sophia optimizer.

class Sophia(GenericAdaptiveOptimizer):

Initialize the optimizer

  • params is the list of parameters
  • lr is the maximum learning rate $\eta \rho$
  • betas is a tuple of ($\beta_1$, $\beta_2$)
  • eps is $\epsilon$
  • rho is $\rho$
  • weight_decay is an instance of class WeightDecay defined in __init__.py
  • defaults is a dictionary of defaults for group values. This is useful when you want to extend the class Sophia.
    def __init__(self, params,
                 lr: float = 1e-4, betas: Tuple[float, float] = (0.9, 0.95), eps: float = 1e-12,
                 rho: float = 0.03,
                 weight_decay: WeightDecay = WeightDecay(),
                 defaults: Optional[Dict[str, Any]] = None):
        defaults = {} if defaults is None else defaults
        defaults.update(weight_decay.defaults())
        defaults.update(dict(rho=rho))
        super().__init__(params, defaults, lr, betas, eps)

        self.weight_decay = weight_decay

Initialize a parameter state

  • state is the optimizer state of the parameter (tensor)
  • group stores optimizer attributes of the parameter group
  • param is the parameter tensor
    def init_state(self, state: Dict[str, Any], group: Dict[str, Any], param: nn.Parameter):

This is the number of optimizer steps taken on the parameter, $t$

        state['step'] = 0

Exponential moving average of gradients, $m_t$

        state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)

Exponential moving average of the Hessian diagonal, $h_t$

        state['hessian'] = torch.zeros_like(param, memory_format=torch.preserve_format)

Update the EMA of Hessian diagonal

  • n_tokens_training_batch is the number of tokens/inputs in the batch, $B$
    def update_hessian(self, n_tokens_training_batch):

Iterate through parameter groups

        for group in self.param_groups:

            _, beta2 = group['betas']

Iterate through parameters

            for p in group['params']:

Skip parameters without gradients

                if p.grad is None:
                    continue

Get optimizer state

                state = self.state[p]

Initialize state if empty

                if len(state) == 0:
                    self.init_state(state, group, p)

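Update the EMA of the Hessian diagonal, $h_t = \beta_2 h_{t-k} + (1 - \beta_2) \hat{h}_t$ with $\hat{h}_t = B \cdot \nabla_\theta \hat{L}(\theta) \odot \nabla_\theta \hat{L}(\theta)$. Here p.grad is expected to hold $\nabla_\theta \hat{L}(\theta)$, the gradient of the loss on the sampled labels.

                state['hessian'].mul_(beta2).addcmul_(p.grad, p.grad, value=(1 - beta2) * n_tokens_training_batch)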
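The single fused in-place call packs the whole EMA into one line. This hypothetical standalone helper, ema_hessian_update, repeats it on plain tensors so it can be checked against the written-out formula $h \leftarrow \beta_2 h + (1 - \beta_2) \cdot B \cdot g \odot g$, where $g$ stands in for p.grad.

```python
import torch


def ema_hessian_update(h: torch.Tensor, grad: torch.Tensor, beta2: float, batch_size: int) -> torch.Tensor:
    """In place: h <- beta2 * h + (1 - beta2) * B * grad ⊙ grad; returns h."""
    # addcmul_(t1, t2, value=v) adds v * t1 * t2 element-wise
    return h.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2) * batch_size)
```

Folding $B$ into `value` avoids materializing the intermediate $\hat{h}_t$ tensor.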

Take an update step for a given parameter tensor

  • state is the optimizer state of the parameter (tensor)
  • group stores optimizer attributes of the parameter group
  • grad is the current gradient tensor for the parameter
  • param is the parameter tensor

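We do the following parameter update,

$$\theta_{t+1} \leftarrow \theta_t - \eta \cdot \operatorname{clip}\left(\frac{m_t}{h_t + \epsilon}, \rho\right)$$

    def step_param(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor, param: torch.nn.Parameter):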

Calculate weight decay

        grad = self.weight_decay(param, grad, group)

Get $\beta_1$ and $\beta_2$

        beta1, beta2 = group['betas']

Get $\rho$

        rho = group['rho']

Get $m_{t-1}$ and $h_t$

        m, hessian = state['exp_avg'], state['hessian']

In-place calculation of $m_t$

$$m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t$$

        m.mul_(beta1).add_(grad, alpha=1 - beta1)

Increment the number of optimizer steps

        state['step'] += 1

Get the maximum learning rate $\eta \rho$

        lr = group['lr']

$\eta$, recovered from the maximum learning rate $\eta \rho$

        eta = lr / rho

$\operatorname{clip}\left(\frac{m_t}{h_t + \epsilon}, \rho\right)$

        ratio = (m / (hessian + group['eps'])).clamp(-rho, rho)

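$\theta_{t+1} \leftarrow \theta_t - \eta \cdot \operatorname{clip}\left(\frac{m_t}{h_t + \epsilon}, \rho\right)$

        param.data.add_(ratio, alpha=-eta)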
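The following hypothetical function, sophia_step, sketches step_param on plain tensors with weight decay left out. It makes one consequence of the $\eta = \text{lr} / \rho$ parametrization easy to check. Since the clipped ratio lies in $[-\rho, \rho]$, no coordinate ever moves by more than lr in a single step. With a fresh state ($h_t = 0$) the update degenerates to a sign step of exactly that size.

```python
import torch


def sophia_step(param: torch.Tensor, grad: torch.Tensor, m: torch.Tensor, hessian: torch.Tensor,
                lr: float = 1e-4, beta1: float = 0.9, rho: float = 0.03, eps: float = 1e-12) -> torch.Tensor:
    """One Sophia update (no weight decay), mutating param and m in place."""
    # m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
    m.mul_(beta1).add_(grad, alpha=1 - beta1)
    # eta such that eta * rho == lr, the largest possible per-coordinate step
    eta = lr / rho
    # clip(m_t / (h_t + eps), rho)
    ratio = (m / (hessian + eps)).clamp(-rho, rho)
    # theta_{t+1} = theta_t - eta * ratio
    param.add_(ratio, alpha=-eta)
    return param
```

With a large Hessian entry the ratio stays inside $[-\rho, \rho]$, and the step becomes the Newton-like $\eta \, m_t / h_t$ instead.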