Adam Optimizer for Half Precision Training

10from typing import Dict, Tuple, Optional, Any
11
12import torch
13from torch import nn
14from torch.optim import Optimizer
15from torch.cuda.amp import grad_scaler
16from collections import defaultdict, abc
17
18from labml_nn.optimizers import WeightDecay
19from labml_nn.optimizers.adam import Adam

Adam Optimizer for Half Precision Training

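We extend the Adam optimizer but keep the gradients, the moment estimates, and a master copy of the parameters in FP32, even when the model parameters themselves are FP16.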

class AdamFP16(Adam):
    """
    ## Adam Optimizer for Half Precision Training

    We extend the Adam optimizer but use FP32 to store gradients, moments,
    and a master copy of the parameters.
    """

    def __init__(self, params, lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-16,
                 weight_decay: WeightDecay = WeightDecay(), optimized_update: bool = True,
                 defaults: Optional[Dict[str, Any]] = None):
        # Parameter to store 32 bit gradients, keyed by parameter.
        # This gets populated by the `GradScalerFP16` defined below.
        self.grad_fp32: Dict[nn.Parameter, torch.Tensor] = {}
        # Call the Adam optimizer initializer
        super().__init__(params, lr, betas, eps, weight_decay, optimized_update, defaults)

    def init_state(self, state: Dict[str, Any], group: Dict[str, Any], param: nn.Parameter):
        """
        ### Initialize a parameter state

        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `param` is the parameter tensor

        All the state tensors use FP32.
        """
        # This is the number of optimizer steps taken on the parameter, $t$
        state['step'] = 0
        # Exponential moving average of gradients, $m_t$
        state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format, dtype=torch.float)
        # Exponential moving average of squared gradient values, $v_t$
        state['exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format, dtype=torch.float)
        # Maintain a FP32 copy of the parameters.
        # `detach()` keeps the master copy out of the autograd graph.
        state['fp32_copy'] = param.detach().to(torch.float)

    def step_param(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor, param: torch.nn.Parameter):
        """
        ### Take an update step for a given parameter tensor

        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `grad` is the current gradient tensor for the parameter
        * `param` is the parameter tensor

        The update is computed on the FP32 copy stored in `state['fp32_copy']`,
        and the result is then written back into `param` in `param`'s dtype.

        NOTE(review): when an FP32 gradient for `param` is present in
        `self.grad_fp32`, it is used instead of `grad`. Any in-place change made
        to `param.grad` after it was stored (e.g. gradient clipping after
        `unscale_`) therefore does not affect the update. Confirm this is intended.
        """
        # Get the FP32 parameters
        param_fp32 = state['fp32_copy']
        # Get the FP32 gradients if available.
        # `pop` also removes them, so each stored gradient is used for one update only.
        grad_fp32 = self.grad_fp32.pop(param, None)
        if grad_fp32 is not None:
            grad = grad_fp32
        else:
            # Otherwise, convert the gradients to FP32
            grad = grad.to(torch.float)

        # Calculate weight decay
        grad = self.weight_decay(param_fp32, grad, group)
        # Get $m_t$ and $v_t$
        m, v = self.get_mv(state, group, grad)
        # Increment the number of optimizer steps, $t$
        state['step'] += 1
        # Perform Adam update on the FP32 copy
        self.adam_update(state, group, param_fp32, m, v)
        # Set the parameters: cast-copy the FP32 values into the existing
        # parameter storage instead of allocating a new tensor every step
        param.data.copy_(param_fp32)

Gradient Scaler with half precision gradients

We extend PyTorch's gradient scaler so that, for the AdamFP16 optimizer, gradients are unscaled as FP32 copies and handed to the optimizer.

class GradScalerFP16(grad_scaler.GradScaler):
    """
    ## Gradient Scaler with half precision gradients

    We extend PyTorch's gradient scaler so that, for the `AdamFP16` optimizer,
    gradients are unscaled in FP32 copies, which are handed to the optimizer
    through `optimizer.grad_fp32`.
    """

    def _unscale_grads_(self, optimizer: Optimizer, inv_scale: torch.Tensor, found_inf: torch.Tensor,
                        allow_fp16: bool) -> Dict[torch.device, torch.Tensor]:
        """
        Unscale the gradients of `optimizer` and record whether any are non-finite.

        * `optimizer` is the optimizer whose parameter gradients are unscaled
        * `inv_scale` is the reciprocal of the current loss scale
        * `found_inf` receives the non-finite-gradient flag (per device)
        * `allow_fp16` whether FP16 gradients may be unscaled in place. It is only
          consulted for optimizers other than `AdamFP16`, which always receives FP32 copies.

        Returns the per-device `found_inf` tensors.

        Raises `NotImplementedError` for sparse gradients, and `ValueError` for
        FP16 gradients of non-`AdamFP16` optimizers when `allow_fp16` is false.
        """
        per_device_inv_scale = grad_scaler._MultiDeviceReplicator(inv_scale)
        per_device_found_inf = grad_scaler._MultiDeviceReplicator(found_inf)

        # Gradients grouped by device, then by dtype, so each fused unscale call
        # below receives a list of tensors sharing one device and one dtype
        per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))  # type: ignore[var-annotated]

        # Whether this optimizer consumes FP32 gradients (decided once, not per parameter)
        use_fp32_grads = isinstance(optimizer, AdamFP16)
        if use_fp32_grads:
            # Drop FP32 gradients left over from an earlier call whose step was
            # skipped (e.g. because of inf/NaN), so they cannot leak into a later update
            optimizer.grad_fp32.clear()

        with torch.no_grad():
            # Loop through parameters
            for group in optimizer.param_groups:
                for param in group["params"]:
                    # Skip parameters without a gradient
                    if param.grad is None:
                        continue
                    # Not implemented for sparse tensors
                    if param.grad.is_sparse:
                        raise NotImplementedError("GradScalerFP16 does not support sparse gradients")

                    if use_fp32_grads:
                        # If we are using the `AdamFP16` optimizer, set
                        # `optimizer.grad_fp32[param]` to the FP32 gradients
                        grad = param.grad.to(torch.float)
                        optimizer.grad_fp32[param] = grad
                    else:
                        # Otherwise, do not convert the gradients to FP32; they are
                        # unscaled in place, so honor `allow_fp16` the same way
                        # PyTorch's `GradScaler._unscale_grads_` does
                        grad = param.grad
                        if not allow_fp16 and grad.dtype == torch.float16:
                            raise ValueError("Attempting to unscale FP16 gradients.")

                    per_device_and_dtype_grads[grad.device][grad.dtype].append(grad)

            # Unscale all the gradients
            for device, per_dtype_grads in per_device_and_dtype_grads.items():
                for grads in per_dtype_grads.values():
                    torch._amp_foreach_non_finite_check_and_unscale_(grads,
                                                                     per_device_found_inf.get(device),
                                                                     per_device_inv_scale.get(device))

        return per_device_found_inf._per_device_tensors