This is a PyTorch implementation of Sophia-G from the paper Sophia: A Scalable Stochastic Second-order Optimizer for Language Model Pre-training. The official implementation is available at Liuhong99/Sophia.
Sophia is more adaptive to heterogeneous curvatures than Adam, more resistant to non-convexity and rapid change of Hessian than Newton’s method, and also uses a low-cost pre-conditioner.
Sophia keeps diagonal Hessian estimates with EMA across iterations. The diagonal Hessian $\hat{h}_t$ is calculated every $k$ steps.

\begin{align}
h_t \leftarrow
\begin{cases}
\beta_2 h_{t-k} + (1 - \beta_2) \hat{h}_t & \text{if } t \bmod k = 1 \\
h_{t-1} & \text{otherwise}
\end{cases}
\end{align}

Sophia uses an EMA of gradients $m_t$, only considers positive entries of the diagonal Hessian and does per-coordinate clipping to the update.

\begin{align}
m_t &\leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
\theta_{t+1} &\leftarrow \theta_t - \eta \cdot \operatorname{clip} \bigg( \frac{m_t}{\max \{ h_t, \epsilon \}}, \rho \bigg)
\end{align}

where $\operatorname{clip}(z, \rho)$ clamps each element of $z$ to $[-\rho, \rho]$ and $\epsilon$ is a very small value to prevent division by $0$.
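As a minimal sketch of this update rule for a single parameter tensor (the tensors and hyper-parameter values here are illustrative placeholders, not the class defined below):

```python
import torch

# Illustrative tensors and hyper-parameters; this is the update rule in isolation,
# not the optimizer class implemented below.
beta1, rho, eta, eps = 0.9, 0.03, 3e-3, 1e-12

param = torch.randn(10)      # theta_t
grad = torch.randn(10)       # g_t
m = torch.zeros(10)          # EMA of gradients, m_{t-1}
hessian = torch.rand(10)     # EMA of the diagonal Hessian estimate, h_t (non-negative)

m = beta1 * m + (1 - beta1) * grad               # m_t
update = (m / (hessian + eps)).clamp(-rho, rho)  # per-coordinate clipping to [-rho, rho]
param = param - eta * update                     # theta_{t+1}
```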
The Gauss-Newton-Bartlett (GNB) estimator computes the diagonal Hessian estimate from labels sampled from the model itself,

$$\hat{L}(\theta) = \frac{1}{B} \sum_{b=1}^{B} \ell_{CE} \big( f(\theta, x_b), \hat{y}_b \big)$$

where $x_b$ are the inputs, $B$ is the batch size (number of inputs/tokens), $\ell_{CE}$ is the cross entropy loss, and $\hat{y}_b$ are sampled from the logits $f(\theta, x_b)$.

$$\hat{h}_t = B \cdot \nabla_\theta \hat{L}(\theta) \odot \nabla_\theta \hat{L}(\theta)$$

Note that this Hessian estimate is always positive and therefore we can replace $\max \{ h_t, \epsilon \}$ with $h_t + \epsilon$.
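A minimal sketch of the GNB estimate, assuming a small classifier; the model, batch and shapes are made up for illustration:

```python
import torch
import torch.nn.functional as F
from torch import nn

# A toy classifier stands in for f(theta, .); shapes are illustrative.
model = nn.Linear(16, 8)
x = torch.randn(32, 16)      # batch of B = 32 inputs
B = x.shape[0]

logits = model(x)                                                          # f(theta, x_b)
y_hat = torch.distributions.Categorical(logits=logits.detach()).sample()  # sample labels from the logits
loss = F.cross_entropy(logits, y_hat)                                      # (1/B) sum of cross entropy losses
loss.backward()

# GNB diagonal Hessian estimate: B * grad ⊙ grad, which is always non-negative
h_hat = {name: B * p.grad.detach() ** 2 for name, p in model.named_parameters()}
```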
Sophia with the Gauss-Newton-Bartlett (GNB) estimator is Sophia-G.
Here is an experiment that uses Sophia-G to train a transformer.
```python
from typing import Dict, Any, Tuple, Optional

import torch
from torch import nn

from labml_nn.optimizers import GenericAdaptiveOptimizer, WeightDecay
```
We extend the class `GenericAdaptiveOptimizer` defined in `__init__.py` to implement the Sophia optimizer.

```python
class Sophia(GenericAdaptiveOptimizer):
```
Initialize the optimizer:

* `params` is the list of parameters
* `lr` is the maximum learning rate
* `betas` is a tuple of ($\beta_1$, $\beta_2$)
* `eps` is $\epsilon$
* `rho` is $\rho$
* `weight_decay` is an instance of class `WeightDecay` defined in `__init__.py`
* `defaults` is a dictionary of default values for parameter groups. This is useful when you want to extend the class `Sophia`.
```python
    def __init__(self, params,
                 lr: float = 1e-4, betas: Tuple[float, float] = (0.9, 0.95), eps: float = 1e-12,
                 rho: float = 0.03,
                 weight_decay: WeightDecay = WeightDecay(),
                 defaults: Optional[Dict[str, Any]] = None):
        defaults = {} if defaults is None else defaults
        defaults.update(weight_decay.defaults())
        defaults.update(dict(rho=rho))
        super().__init__(params, defaults, lr, betas, eps)

        self.weight_decay = weight_decay
```
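As a usage sketch (the model here is a placeholder; the hyper-parameter values are the defaults above):

```python
# Hypothetical usage; `model` can be any `nn.Module`.
model = nn.Linear(16, 8)
optimizer = Sophia(model.parameters(), lr=1e-4, betas=(0.9, 0.95), eps=1e-12, rho=0.03)
```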
Initialize the optimizer state for a parameter:

* `state` is the optimizer state of the parameter (tensor)
* `group` stores optimizer attributes of the parameter group
* `param` is the parameter tensor

```python
    def init_state(self, state: Dict[str, Any], group: Dict[str, Any], param: nn.Parameter):
```
This is the number of optimizer steps taken on the parameter, $t$.

```python
        state['step'] = 0
```

Exponential moving average of gradients, $m_t$.

```python
        state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)
```

Exponential moving average of the Hessian diagonal, $h_t$.

```python
        state['hessian'] = torch.zeros_like(param, memory_format=torch.preserve_format)
```
Update the EMA of the Hessian diagonal. This should be called every $k$ steps.

* `n_tokens_training_batch` is the number of tokens/inputs in the batch, $B$

```python
    def update_hessian(self, n_tokens_training_batch):
```
Iterate through parameter groups

```python
        for group in self.param_groups:
```

Get $\beta_2$

```python
            _, beta2 = group['betas']
```

Iterate through parameters

```python
            for p in group['params']:
```

Skip parameters without gradients

```python
                if p.grad is None:
                    continue
```

Get optimizer state

```python
                state = self.state[p]
```

Initialize state if empty

```python
                if len(state) == 0:
                    self.init_state(state, group, p)
```

Update the EMA of the Hessian diagonal with the GNB estimate, $h_t \leftarrow \beta_2 h_{t-k} + (1 - \beta_2) \cdot B \, \nabla_\theta \hat{L}(\theta) \odot \nabla_\theta \hat{L}(\theta)$

```python
                state['hessian'].mul_(beta2).addcmul_(p.grad, p.grad, value=(1 - beta2) * n_tokens_training_batch)
```
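A hedged sketch of how a training loop might drive this method (the model, data and $k$ are assumptions; only `step()`, `zero_grad()` and `update_hessian()` come from the optimizer): take a normal step on the real labels every iteration, and every $k$ steps back-propagate a cross entropy loss on labels sampled from the model's own logits before calling `update_hessian`.

```python
import torch
import torch.nn.functional as F
from torch import nn

# Illustrative setup: a toy classifier and random data stand in for a real model and loader.
model = nn.Linear(16, 8)
optimizer = Sophia(model.parameters(), lr=1e-4, betas=(0.9, 0.95), rho=0.03)
k = 10  # refresh the Hessian estimate every k steps

for i in range(100):
    x = torch.randn(32, 16)
    y = torch.randint(0, 8, (32,))

    # Normal training step on the real labels
    optimizer.zero_grad()
    F.cross_entropy(model(x), y).backward()
    optimizer.step()

    # Every k steps, back-propagate the GNB loss (labels sampled from the model's own
    # logits) and update the EMA of the Hessian diagonal
    if i % k == k - 1:
        optimizer.zero_grad()
        logits = model(x)
        y_hat = torch.distributions.Categorical(logits=logits.detach()).sample()
        F.cross_entropy(logits, y_hat).backward()
        # For a language model this would be the number of tokens in the batch
        optimizer.update_hessian(x.shape[0])
```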
Take an update step for a given parameter tensor:

* `state` is the optimizer state of the parameter (tensor)
* `group` stores optimizer attributes of the parameter group
* `grad` is the current gradient tensor $g_t$ for the parameter
* `param` is the parameter tensor $\theta_{t-1}$

We do the following parameter update,

$$\theta_{t+1} \leftarrow \theta_t - \eta \cdot \operatorname{clip} \bigg( \frac{m_t}{h_t + \epsilon}, \rho \bigg)$$

```python
    def step_param(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor, param: torch.nn.Parameter):
```
Calculate weight decay

```python
        grad = self.weight_decay(param, grad, group)
```

Get $\beta_1$ and $\beta_2$

```python
        beta1, beta2 = group['betas']
```

Get $\rho$

```python
        rho = group['rho']
```

Get $m_{t-1}$ and $h_t$

```python
        m, hessian = state['exp_avg'], state['hessian']
```

In-place calculation of $m_t$, $m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t$

```python
        m.mul_(beta1).add_(grad, alpha=1 - beta1)
```

Increment the number of optimizer steps

```python
        state['step'] += 1
```

Get maximum learning rate

```python
        lr = group['lr']
```

Calculate the step size $\eta = \frac{\text{lr}}{\rho}$

```python
        eta = lr / rho
```

Calculate the clipped update $\operatorname{clip} \big( \frac{m_t}{h_t + \epsilon}, \rho \big)$

```python
        ratio = (m / (hessian + group['eps'])).clamp(-rho, rho)
```

Update the parameters, $\theta_{t+1} \leftarrow \theta_t - \eta \cdot \operatorname{clip} \big( \frac{m_t}{h_t + \epsilon}, \rho \big)$

```python
        param.data.add_(ratio, alpha=-eta)
```
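Since $\eta = \frac{\text{lr}}{\rho}$ and the clipped ratio is bounded by $\rho$ in magnitude, no coordinate moves by more than lr in a single step. A quick numeric check with the default hyper-parameters (plain Python, values chosen for illustration):

```python
lr, rho, eps = 1e-4, 0.03, 1e-12
eta = lr / rho

# A coordinate with a large momentum-to-Hessian ratio gets clipped to rho,
# so its step is exactly eta * rho = lr
m, h = 0.5, 1.0
ratio = max(min(m / (h + eps), rho), -rho)   # clipped to 0.03
assert abs(eta * ratio - lr) < 1e-12

# A well-conditioned coordinate is not clipped and takes a curvature-scaled step
m, h = 0.01, 1.0
ratio = max(min(m / (h + eps), rho), -rho)   # 0.01, not clipped
print(eta * ratio)                           # ≈ 3.33e-5, smaller than lr
```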