from typing import Dict, Tuple, Optional, Any

import torch
from torch import nn
from torch.optim import Optimizer
from torch.cuda.amp import grad_scaler
from collections import defaultdict, abc

from labml_nn.optimizers import WeightDecay
from labml_nn.optimizers.adam import Adam
We extend the Adam optimizer, but use FP32 to store gradients and moments.
class AdamFP16(Adam):
    def __init__(self, params, lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-16,
                 weight_decay: WeightDecay = WeightDecay(), optimized_update: bool = True,
                 defaults: Optional[Dict[str, Any]] = None):
Parameter to store 32-bit gradients. This gets populated by the GradScalerFP16 defined below.
        self.grad_fp32 = {}
Call the Adam Optimizer initializer
        super().__init__(params, lr, betas, eps, weight_decay, optimized_update, defaults)
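For example, constructing this optimizer for a model whose parameters are kept in FP16 could look like the following. This is a minimal sketch; the model and learning rate are illustrative, not from the source.

# Minimal usage sketch; the model and learning rate are illustrative.
import torch
from torch import nn

model = nn.Linear(512, 512).to(torch.float16)      # parameters stored in FP16
optimizer = AdamFP16(model.parameters(), lr=3e-4)  # moments and the parameter master copy are kept in FP32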
Initialize the optimizer state for a parameter:

- state is the optimizer state of the parameter (tensor)
- group stores optimizer attributes of the parameter group
- param is the parameter tensor

All the state tensors use FP32.
    def init_state(self, state: Dict[str, Any], group: Dict[str, Any], param: nn.Parameter):
This is the number of optimizer steps taken on the parameter, $t$
        state['step'] = 0
Exponential moving average of gradients, $m_t$
        state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format, dtype=torch.float)
Exponential moving average of squared gradient values, $v_t$
        state['exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format, dtype=torch.float)
Maintain an FP32 copy of the parameters
        state['fp32_copy'] = param.to(torch.float)
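The FP32 copy matters because an update that is much smaller than the spacing between adjacent FP16 values is rounded away when accumulated directly into an FP16 weight. A quick illustration, not from the source:

# Near 1.0 the spacing between adjacent FP16 values is about 1e-3,
# so a 1e-4 update vanishes in FP16 but survives in FP32.
import torch

w16 = torch.tensor(1.0, dtype=torch.float16)
w32 = torch.tensor(1.0, dtype=torch.float32)
print(w16 + 1e-4)  # tensor(1., dtype=torch.float16) -- the update is lost
print(w32 + 1e-4)  # tensor(1.0001)                  -- the update is kept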
Take an optimizer step on a parameter:

- state is the optimizer state of the parameter (tensor)
- group stores optimizer attributes of the parameter group
- grad is the current gradient tensor for the parameter
- param is the parameter tensor

    def step_param(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor, param: torch.nn.Parameter):
Get the FP32 parameters
        param_fp32 = state['fp32_copy']
Get the FP32 gradients if available
        grad_fp32 = self.grad_fp32.get(param, None)
        if grad_fp32 is not None:
            del self.grad_fp32[param]
            grad = grad_fp32
        else:
Otherwise, convert the gradients to FP32
            grad = grad.to(torch.float)
Calculate weight decay
        grad = self.weight_decay(param_fp32, grad, group)
Get $m_t$ and $v_t$
        m, v = self.get_mv(state, group, grad)
Increment the number of optimizer steps
        state['step'] += 1
Perform Adam update
        self.adam_update(state, group, param_fp32, m, v)
Set the parameters
        param.data = param_fp32.to(param.dtype)
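For reference, the update that adam_update applies to the FP32 copy is the standard Adam rule (the base Adam implementation may fold the bias corrections into the learning rate and $\epsilon$ when optimized_update is enabled):

$$m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t$$
$$v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2$$
$$\theta_t = \theta_{t-1} - \alpha \cdot \frac{m_t / (1 - \beta_1^t)}{\sqrt{v_t / (1 - \beta_2^t)} + \epsilon}$$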
We extend the PyTorch gradient scaler so that gradients are unscaled in FP32 and handed to the AdamFP16 optimizer.
class GradScalerFP16(grad_scaler.GradScaler):
    def _unscale_grads_(self, optimizer: Optimizer, inv_scale: torch.Tensor, found_inf: torch.Tensor,
                        allow_fp16: bool) -> Dict[torch.device, torch.Tensor]:
        per_device_inv_scale = grad_scaler._MultiDeviceReplicator(inv_scale)
        per_device_found_inf = grad_scaler._MultiDeviceReplicator(found_inf)

        per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))  # type: ignore[var-annotated]

        with torch.no_grad():
Loop through parameters
            for group in optimizer.param_groups:
                for param in group["params"]:
Skip parameters that have no gradient
                    if param.grad is None:
                        continue
Not implemented for sparse tensors
                    if param.grad.is_sparse:
                        raise NotImplementedError
If we are using the AdamFP16 optimizer, set optimizer.grad_fp32[param] to the FP32 gradients
                    if isinstance(optimizer, AdamFP16):
                        grad = param.grad.to(torch.float)
                        optimizer.grad_fp32[param] = grad
Otherwise, do not convert the gradients to FP32
                    else:
                        grad = param.grad

                    per_device_and_dtype_grads[grad.device][grad.dtype].append(grad)
Unscale all the gradients
            for device, per_dtype_grads in per_device_and_dtype_grads.items():
                for grads in per_dtype_grads.values():
                    torch._amp_foreach_non_finite_check_and_unscale_(grads,
                                                                     per_device_found_inf.get(device),
                                                                     per_device_inv_scale.get(device))
        return per_device_found_inf._per_device_tensors
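Putting the two pieces together, a training step would look roughly like the following. This is a hedged sketch: model, loss_fn and loader are placeholders, while AdamFP16 and GradScalerFP16 are the classes defined above.

# Rough training-loop sketch; model, loss_fn and loader are placeholders.
scaler = GradScalerFP16()
optimizer = AdamFP16(model.parameters(), lr=3e-4)

for data, target in loader:
    optimizer.zero_grad()
    loss = loss_fn(model(data), target)
    scaler.scale(loss).backward()  # scale the loss so FP16 gradients do not underflow
    scaler.step(optimizer)         # unscales gradients (filling optimizer.grad_fp32) and calls optimizer.step()
    scaler.update()                # adjust the scale factor for the next iteration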