这是《零:训练一万亿个参数模型的内存优化》一文中介绍的零 DP 的实现,
它将优化器状态、梯度和参数的分片保存到多个设备/节点中。它减少了原始模型的内存消耗,其中是参数的数量,是分片的数量,是每个参数的优化器字节数。是假设精度为 16 位的参数和梯度存储器;即每个参数和梯度为 2 个字节。对于 Adam 优化器,因为它维护参数的副本,在 fp32 中每个参数两个时刻。
零 DP 的通信量为。比较而言,数据并行训练的通信量为。
尽管它被命名了Zero3
,但我们只实现了其中的零 DP 部分,没有实现针对剩余内存消耗的 Zero-R 内存优化。Out 实现仅支持训练一部分参数。
此实施的灵感来自公平规模的财务安全发展计划。
32import functools
33from typing import List, Optional, Tuple
34
35import torch
36import torch.distributed as dist
37from torch import nn
40class Zero3Layer(nn.Module):
每个分片都将参数保存在chunk
列表中。用于chunk[0]
可训练的参数,chunk[1]
用于固定参数。
49 chunk: List[nn.Parameter]
这是chunk
列表中区块的大小。
51 chunk_size: List[int]
第一个区块用于可训练的参数。
53 TRAINING_PARAMS_IDX = 0
这是分为可训练参数和固定参数的列表的参数列表。
56 param_refs: List[List[nn.Parameter]]
CUDA 流到精选参数
59 fetch_stream: Optional[torch.cuda.Stream]
用于备份/累积梯度的 CUDA 流
61 backup_stream: Optional[torch.cuda.Stream]
此图层之前的图层列表
63 prev_layer: List['Zero3Layer']
紧随此图层之后的图层列表
65 next_layer: List['Zero3Layer']
当前层的位置;用于调试日志
67 layer_idx: int
参数是否已获取
70 is_fetched: bool
该层的设备
73 device: torch.device
图层的数据类型
75 dtype: torch.dtype
要封装的模块
77 module: nn.Module
分片数据的节点/设备数量
79 world_size: int
module
要封装的模块。rank
当前节点的等级。world_size
分片数据的节点/设备数量。device
层的设备。dtype
图层的数据类型。81 def __init__(self, module: nn.Module, rank: int, world_size: int, device: torch.device, dtype: torch.dtype):
89 super().__init__()
初始化属性
92 self.device = device
93 self.dtype = dtype
94 self.module = module
95 self.prev_layer = []
96 self.next_layer = []
97 self.is_fetched = False
98 self.world_size = world_size
99 self.layer_idx = -1
100 self.fetch_stream = None
101 self.backup_stream = None
102
103 with torch.no_grad():
收集图层的所有参数
105 all_param_refs = [p for p in self.parameters()]
存储参数的形状,因为我们稍后需要它来重建它们
108 for p in all_param_refs:
109 p._orig_shape = p.shape
所有参数都应具有相同的类型
112 for p in all_param_refs:
113 assert p.dtype == dtype, "All parameters should have same dtype"
将参数分为可训练和固定
116 self.param_refs = [[p for p in all_param_refs if p.requires_grad],
117 [p for p in all_param_refs if not p.requires_grad]]
118 del all_param_refs
该rank = 0
节点将计算每个设备/节点应存储的大小,并相应地分配参数。
122 if rank == 0:
合并和填充可训练 (merged_params[0]
) 和 fixed (merged_params[1]
) 参数
124 merged_params = [self._merge_and_pad_params(ps) for ps in self.param_refs]
计算可训练参数和固定参数的区块大小
126 self.chunk_size = [(len(p) // world_size if p is not None else 0) for p in merged_params]
广播尺寸
128 dist.broadcast(torch.tensor(self.chunk_size, device=device), src=0)
129 else:
创建一个空张量来接收大小
131 chunk_size = torch.tensor([0, 0], device=device)
收到尺码
133 dist.broadcast(chunk_size, src=0)
134 self.chunk_size = chunk_size.tolist()
为要存储在当前设备/节点中的可训练 (self.chunk[0]
self.chunk[1]
) 和 fixed () 参数创建参数
138 self.chunk = [nn.Parameter(self._empty((s,)), requires_grad=i == self.TRAINING_PARAMS_IDX)
139 for i, s in enumerate(self.chunk_size)]
一个空张量,用于接收可训练参数和固定参数的组合
142 chunk = self._empty((sum(self.chunk_size),))
143
144 if rank == 0:
连接可训练参数和固定参数
146 all_params = torch.cat([p.view(world_size, -1) for p in merged_params], dim=-1).view(-1)
147 del merged_params
将它们分散到所有节点/设备
150 dist.scatter(chunk, list(all_params.split(sum(self.chunk_size))))
151 del all_params
152 else:
接收参数
154 dist.scatter(chunk)
收集区块数据
157 chunk = chunk.split(self.chunk_size)
158 for i, c in enumerate(chunk):
159 self.chunk[i].data[:] = c
160 del chunk
清理普通参数
163 self._cleanup_params()
添加一个向后钩子。当计算相对于模块的梯度时,会调用该函数。
166 self._backward_hook_ref = self.register_full_backward_hook(self._backward_hook) # type: ignore
world_size
。168 def _merge_and_pad_params(self, params: List[nn.Parameter]) -> torch.Tensor:
参数总数
173 size = sum(p.shape.numel() for p in params)
如果它不能被整除world_size
,请填充它
176 if size % self.world_size != 0:
177 padding_fixed = self.world_size - (size % self.world_size)
否则,无需填充
179 else:
180 padding_fixed = 0
创建一个空的填充张量
182 padding = self._empty((padding_fixed,))
连接所有参数并填充它
184 return torch.cat([p.view(-1) for p in params] + [padding], dim=0)
186 def get_trainable_chunk(self) -> List[nn.Parameter]:
如果没有可训练的参数,则返回空列表
193 if len(self.chunk[self.TRAINING_PARAMS_IDX]) == 0:
194 return []
将可训练区块作为列表返回
197 return [self.chunk[self.TRAINING_PARAMS_IDX]]
199 def _empty(self, shape: Tuple[int, ...]) -> torch.Tensor:
203 return torch.empty(shape, device=self.device, dtype=self.dtype)
205 @torch.no_grad()
206 def _cleanup_params(self):
设置标志以指示未读取参数
214 self.is_fetched = False
遍历所有参数
217 for ps in self.param_refs:
218 for p in ps:
在进行任何新操作之前,请等待对参数的操作完成
220 p.data.record_stream(torch.cuda.current_stream())
检查以确保该参数不与其他任何内容共享存储
222 assert p.data.storage_offset() == 0, "The tensor is not the sole occupant of the storage."
226 p.data.storage().resize_(0) # This is what actually clears the memory
确保参数没有梯度数据
228 assert p.grad is None, 'Gradients should be None'
230 @torch.no_grad()
231 def fetch_params(self):
已获取 Skip
239 if self.is_fetched:
240 return
设置旗帜
243 self.is_fetched = True
如果没有要获取或共享的内容,请跳过。
246 if sum(self.chunk_size) == 0:
247 return
fetch_stream
使用从所有分片中获取参数
250 with torch.cuda.stream(self.fetch_stream):
创建一个空张量来接收参数
252 buffer = self._empty((self.world_size * sum(self.chunk_size),))
将连续缓冲区拆分为节点数。这些拆分是 “缓冲区” 的视图。
254 buffers = list(buffer.split(sum(self.chunk_size)))
连接可训练和固定区块
257 chunk = torch.cat(self.chunk, dim=0)
从所有节点/设备收集参数
260 dist.all_gather(buffers, chunk)
将收集的参数拆分为可训练的和固定的区块
263 params = buffer.view(-1, sum(self.chunk_size)).split(self.chunk_size, dim=1)
等待收集操作完成,然后清除对缓冲区的引用
265 buffer.record_stream(self.fetch_stream)
266 for b in buffers:
267 b.record_stream(self.fetch_stream)
268 buffer.record_stream(self.fetch_stream)
269 del buffer
270 del buffers
将可训练和固定参数重塑为连续张量
273 params = [p.reshape(-1) for p in params]
收集单个参数张量
276 for cont, ps in zip(params, self.param_refs):
如果没有参数,请跳过
278 if not ps:
279 continue
连续张量的偏移量
282 offset = 0
遍历模型参数并分配来自连续张量的值
284 for p in ps:
原始参数形状
286 shape = p._orig_shape # type: ignore[attr-defined]
更改参数的存储大小。这是我们清理参数时设置的。
288 p.data.storage().resize_(shape.numel())
从连续张量中分配值
290 p.data[:] = cont[offset: offset + shape.numel()].reshape(shape)
等待操作完成后才能执行其他操作
292 p.data.record_stream(self.fetch_stream)
更新偏移量
294 offset += shape.numel()
等待操作完成后才能执行其他操作
297 cont.record_stream(self.fetch_stream)
300 del params
302 def forward(self, *args, **kwargs):
获取当前节点的所有参数。这被前一层调用,所以这个调用只是为了确保参数被抓取。
309 self.fetch_params()
等待参数提取完成。
312 torch.cuda.current_stream().wait_stream(self.fetch_stream)
开始获取后续层的参数,以便它们将获取当前层进行计算的参数。
316 for layer in self.next_layer:
317 layer.fetch_params()
启用了 autograd,则向当前层的参数添加向后挂钩。
320 if torch.is_grad_enabled():
321 self._add_backward_hooks()
计算当前图层的输出
324 res = self.module(*args, **kwargs)
330 if not torch.is_grad_enabled() or self.next_layer:
331 self._cleanup_params()
332
333 return res
335 def _add_backward_hooks(self):
添加的向后钩子数量
341 self._backward_hook_handles = 0
循环浏览当前图层的可训练参数
344 for p in self.param_refs[self.TRAINING_PARAMS_IDX]:
确保尚未添加挂钩
346 assert not hasattr(p, "_hook_handle"), 'Parameter has already been hooked'
expand_as
用于创建我们可以拦截的 autograd 步骤
348 p_tmp = p.expand_as(p)
获取一个手柄来添加向后钩。这篇博客讨论grad_acc
了.
351 grad_acc = p_tmp.grad_fn.next_functions[0][0]
添加向后挂钩
353 handle = grad_acc.register_hook(
354 functools.partial(self._post_backward_hook, p))
保留对手柄的引用
356 p._hook_handle = handle
增加添加的钩子数量
358 self._backward_hook_handles += 1
360 def _backward_event(self):
减少钩子计数器
368 self._backward_hook_handles -= 1
如果所有的钩子(包括模块钩子)都被调用了,那么我们可以备份渐变并清理参数。
372 if self._backward_hook_handles == -1:
373 self._backup_grads()
374 self._cleanup_params()
开始获取前一图层的参数,因为 autograd 接下来将处理该图层的渐变。
377 for layer in self.prev_layer:
378 layer.fetch_params()
380 def _post_backward_hook(self, p: nn.Parameter, *args):
从参数中移除句柄
385 p._hook_handle.remove() # type: ignore[attr-defined]
386 delattr(p, "_hook_handle")
处理向后事件
389 self._backward_event()
391 def _backward_hook(self, *args, **kwargs):
处理向后事件
396 self._backward_event()
上一层将开始计算梯度。我们需要确保它已经完成了参数的获取。
399 torch.cuda.current_stream().wait_stream(self.fetch_stream)
402 return None
404 @torch.no_grad()
405 def _backup_grads(self):
如果没有可训练的参数,则跳过
410 if self.chunk_size[self.TRAINING_PARAMS_IDX] == 0:
411 return
使用备份流备份渐变
414 with torch.cuda.stream(self.backup_stream):
用于存储渐变的缓冲区
416 buffer = self._empty((self.world_size * self.chunk_size[self.TRAINING_PARAMS_IDX],))
将连续缓冲区拆分为多个节点。这些拆分是 “缓冲区” 的视图。
418 buffers = list(buffer.split(self.chunk_size[self.TRAINING_PARAMS_IDX]))
连续缓冲区的偏移量
421 offset = 0
遍历可训练的参数
423 for p in self.param_refs[self.TRAINING_PARAMS_IDX]:
收集渐变
425 shape = p._orig_shape # type: ignore[attr-defined]
426 buffer[offset: offset + shape.numel()] = p.grad.view(-1)
更新偏移量
428 offset += shape.numel()
清理渐变
430 p.grad = None
空张量累积当前分片的梯度
433 grad = self._empty((self.chunk_size[self.TRAINING_PARAMS_IDX],))
累积每个分片的梯度。它将缓冲区分散到节点上,每个节点累积(减少)它收到的张量。
436 dist.reduce_scatter(grad, buffers)
等待操作完成,然后清除对缓冲区的引用
439 for b in buffers:
440 b.record_stream(self.fetch_stream)
441 buffer.record_stream(self.fetch_stream)
442 del buffer
443 del buffers
设置分块渐变。这就是优化器所看到的。
446 self.chunk[self.TRAINING_PARAMS_IDX].grad = grad
447 del grad
Zero3Layer
层的顺序模块450class Zero3Sequential(nn.Module):
modules
Zero3Layer
图层列表454 def __init__(self, modules: List[Zero3Layer]):
458 super().__init__()
用于获取参数的 CUDA 流
461 self.fetch_stream = torch.cuda.Stream()
用于备份(累积)梯度的 CUDA 流
463 self.backup_stream = torch.cuda.Stream()
为每个层设置流以及前面和后面的Zero3Layer
层
466 for i in range(len(modules)):
设置图层索引
468 modules[i].layer_idx = i
设置直播
470 modules[i].fetch_stream = self.fetch_stream
471 modules[i].backup_stream = self.backup_stream
设置后续图层
473 if i + 1 < len(modules):
474 modules[i].next_layer.append(modules[i + 1])
设置前面的图层
476 if i - 1 >= 0:
477 modules[i].prev_layer.append(modules[i - 1])
存储模块清单
480 self.module_list = nn.ModuleList(modules)
482 def get_trainable_chunk(self):
返回每层可训练区块的列表
484 return sum([m.get_trainable_chunk() for m in self.module_list], [])
486 def forward(self, x: torch.Tensor):
确保渐变备份已完成
488 torch.cuda.current_stream().wait_stream(self.backup_stream)
向前传球
491 for m in self.module_list:
492 x = m(x)
495 return x