零 DP 内存优化

这是《零:训练一万亿个参数模型的内存优化》一文中介绍的零 DP 的实现

它将优化器状态、梯度和参数的分片保存到多个设备/节点中。它减少了原始模型的内存消耗,其中是参数的数量,是分片的数量,是每个参数的优化器字节数。是假设精度为 16 位的参数和梯度存储器;即每个参数和梯度为 2 个字节。对于 Adam 优化器,因为它维护参数的副本,在 fp32 中每个参数两个时刻。

零 DP 的通信量为。比较而言,数据并行训练的通信量为

尽管它被命名了Zero3 ,但我们只实现了其中的零 DP 部分,没有实现针对剩余内存消耗的 Zero-R 内存优化。Out 实现仅支持训练一部分参数。

此实施的灵感来自公平规模的财务安全发展计划

以下是使用零 DP 内存优化微调 GPT NeoX 的脚本

32import functools
33from typing import List, Optional, Tuple
34
35import torch
36import torch.distributed as dist
37from torch import nn

Zero3 层

模型的每一层(或几个连续层的组合)都应该包裹在这个模块中。

40class Zero3Layer(nn.Module):

每个分片都将参数保存在chunk 列表中。用于chunk[0] 可训练的参数,chunk[1] 用于固定参数。

49    chunk: List[nn.Parameter]

这是chunk 列表中区块的大小。

51    chunk_size: List[int]

第一个区块用于可训练的参数。

53    TRAINING_PARAMS_IDX = 0

这是分为可训练参数和固定参数的列表的参数列表。

56    param_refs: List[List[nn.Parameter]]

CUDA 流到精选参数

59    fetch_stream: Optional[torch.cuda.Stream]

用于备份/累积梯度的 CUDA 流

61    backup_stream: Optional[torch.cuda.Stream]

此图层之前的图层列表

63    prev_layer: List['Zero3Layer']

紧随此图层之后的图层列表

65    next_layer: List['Zero3Layer']

当前层的位置;用于调试日志

67    layer_idx: int

参数是否已获取

70    is_fetched: bool

该层的设备

73    device: torch.device

图层的数据类型

75    dtype: torch.dtype

要封装的模块

77    module: nn.Module

分片数据的节点/设备数量

79    world_size: int
  • module 要封装的模块。
  • rank 当前节点的等级。
  • world_size 分片数据的节点/设备数量。
  • device 层的设备。
  • dtype 图层的数据类型。
81    def __init__(self, module: nn.Module, rank: int, world_size: int, device: torch.device, dtype: torch.dtype):
89        super().__init__()

初始化属性

92        self.device = device
93        self.dtype = dtype
94        self.module = module
95        self.prev_layer = []
96        self.next_layer = []
97        self.is_fetched = False
98        self.world_size = world_size
99        self.layer_idx = -1
100        self.fetch_stream = None
101        self.backup_stream = None
102
103        with torch.no_grad():

收集图层的所有参数

105            all_param_refs = [p for p in self.parameters()]

存储参数的形状,因为我们稍后需要它来重建它们

108            for p in all_param_refs:
109                p._orig_shape = p.shape

所有参数都应具有相同的类型

112            for p in all_param_refs:
113                assert p.dtype == dtype, "All parameters should have same dtype"

将参数分为可训练和固定

116            self.param_refs = [[p for p in all_param_refs if p.requires_grad],
117                               [p for p in all_param_refs if not p.requires_grad]]
118            del all_param_refs

rank = 0 节点将计算每个设备/节点应存储的大小,并相应地分配参数。

122            if rank == 0:

合并和填充可训练 (merged_params[0] ) 和 fixed (merged_params[1] ) 参数

124                merged_params = [self._merge_and_pad_params(ps) for ps in self.param_refs]

计算可训练参数和固定参数的区块大小

126                self.chunk_size = [(len(p) // world_size if p is not None else 0) for p in merged_params]

广播尺寸

128                dist.broadcast(torch.tensor(self.chunk_size, device=device), src=0)
129            else:

创建一个空张量来接收大小

131                chunk_size = torch.tensor([0, 0], device=device)

收到尺码

133                dist.broadcast(chunk_size, src=0)
134                self.chunk_size = chunk_size.tolist()

为要存储在当前设备/节点中的可训练 (self.chunk[0] self.chunk[1] ) 和 fixed () 参数创建参数

138            self.chunk = [nn.Parameter(self._empty((s,)), requires_grad=i == self.TRAINING_PARAMS_IDX)
139                          for i, s in enumerate(self.chunk_size)]

一个空张量,用于接收可训练参数和固定参数的组合

142            chunk = self._empty((sum(self.chunk_size),))
143
144            if rank == 0:

连接可训练参数和固定参数

146                all_params = torch.cat([p.view(world_size, -1) for p in merged_params], dim=-1).view(-1)
147                del merged_params

将它们分散到所有节点/设备

150                dist.scatter(chunk, list(all_params.split(sum(self.chunk_size))))
151                del all_params
152            else:

接收参数

154                dist.scatter(chunk)

收集区块数据

157            chunk = chunk.split(self.chunk_size)
158            for i, c in enumerate(chunk):
159                self.chunk[i].data[:] = c
160            del chunk

清理普通参数

163            self._cleanup_params()

添加一个向后钩子。当计算相对于模块的梯度时,会调用该函数。

166            self._backward_hook_ref = self.register_full_backward_hook(self._backward_hook)  # type: ignore

合并所有参数并填充它,使其可被整除world_size

168    def _merge_and_pad_params(self, params: List[nn.Parameter]) -> torch.Tensor:

参数总数

173        size = sum(p.shape.numel() for p in params)

如果它不能被整除world_size ,请填充它

176        if size % self.world_size != 0:
177            padding_fixed = self.world_size - (size % self.world_size)

否则,无需填充

179        else:
180            padding_fixed = 0

创建一个空的填充张量

182        padding = self._empty((padding_fixed,))

连接所有参数并填充它

184        return torch.cat([p.view(-1) for p in params] + [padding], dim=0)

获取可训练的参数块/分片。

这就是我们传递给当前节点上的优化器的内容。

186    def get_trainable_chunk(self) -> List[nn.Parameter]:

如果没有可训练的参数,则返回空列表

193        if len(self.chunk[self.TRAINING_PARAMS_IDX]) == 0:
194            return []

将可训练区块作为列表返回

197        return [self.chunk[self.TRAINING_PARAMS_IDX]]

创建给定形状的空张量。

199    def _empty(self, shape: Tuple[int, ...]) -> torch.Tensor:
203        return torch.empty(shape, device=self.device, dtype=self.dtype)

清理参数数据

这将释放层参数使用的所有内存。

205    @torch.no_grad()
206    def _cleanup_params(self):

设置标志以指示未读取参数

214        self.is_fetched = False

遍历所有参数

217        for ps in self.param_refs:
218            for p in ps:

在进行任何新操作之前,请等待对参数的操作完成

220                p.data.record_stream(torch.cuda.current_stream())

检查以确保该参数不与其他任何内容共享存储

222                assert p.data.storage_offset() == 0, "The tensor is not the sole occupant of the storage."

将存储空间调整为。这将释放参数使用的内存。

设置p.data 不会释放内存,因为 autograd 图形会保留对它的引用。

226                p.data.storage().resize_(0)  # This is what actually clears the memory

确保参数没有梯度数据

228                assert p.grad is None, 'Gradients should be None'

从所有分片中获取参数

这将从所有节点获取所有参数数据,并在每个节点上重建参数。

230    @torch.no_grad()
231    def fetch_params(self):

已获取 Skip

239        if self.is_fetched:
240            return

设置旗帜

243        self.is_fetched = True

如果没有要获取或共享的内容,请跳过。

246        if sum(self.chunk_size) == 0:
247            return

fetch_stream 使用从所有分片中获取参数

250        with torch.cuda.stream(self.fetch_stream):

创建一个空张量来接收参数

252            buffer = self._empty((self.world_size * sum(self.chunk_size),))

将连续缓冲区拆分为节点数。这些拆分是 “缓冲区” 的视图。

254            buffers = list(buffer.split(sum(self.chunk_size)))

连接可训练和固定区块

257            chunk = torch.cat(self.chunk, dim=0)

从所有节点/设备收集参数

260            dist.all_gather(buffers, chunk)

将收集的参数拆分为可训练的和固定的区块

263            params = buffer.view(-1, sum(self.chunk_size)).split(self.chunk_size, dim=1)

等待收集操作完成,然后清除对缓冲区的引用

265            buffer.record_stream(self.fetch_stream)
266            for b in buffers:
267                b.record_stream(self.fetch_stream)
268            buffer.record_stream(self.fetch_stream)
269            del buffer
270            del buffers

将可训练和固定参数重塑为连续张量

273            params = [p.reshape(-1) for p in params]

收集单个参数张量

276            for cont, ps in zip(params, self.param_refs):

如果没有参数,请跳过

278                if not ps:
279                    continue

连续张量的偏移量

282                offset = 0

遍历模型参数并分配来自连续张量的值

284                for p in ps:

原始参数形状

286                    shape = p._orig_shape  # type: ignore[attr-defined]

更改参数的存储大小。这是我们清理参数时设置的。

288                    p.data.storage().resize_(shape.numel())

从连续张量中分配值

290                    p.data[:] = cont[offset: offset + shape.numel()].reshape(shape)

等待操作完成后才能执行其他操作

292                    p.data.record_stream(self.fetch_stream)

更新偏移量

294                    offset += shape.numel()

等待操作完成后才能执行其他操作

297                cont.record_stream(self.fetch_stream)

300            del params

向前传球

302    def forward(self, *args, **kwargs):

获取当前节点的所有参数。这被前一层调用,所以这个调用只是为了确保参数被抓取。

309        self.fetch_params()

等待参数提取完成。

312        torch.cuda.current_stream().wait_stream(self.fetch_stream)

开始获取后续层的参数,以便它们将获取当前层进行计算的参数。

316        for layer in self.next_layer:
317            layer.fetch_params()
如果@@

启用了 autograd,则向当前层的参数添加向后挂钩。

320        if torch.is_grad_enabled():
321            self._add_backward_hooks()

计算当前图层的输出

324        res = self.module(*args, **kwargs)

清理图层的参数。

如果启用了 autograd,并且这是网络中的最后一层,则跳过清理,因为我们需要再次获取参数才能进行反向传递。

330        if not torch.is_grad_enabled() or self.next_layer:
331            self._cleanup_params()
332
333        return res

向当前图层的参数添加向后挂钩。

335    def _add_backward_hooks(self):

添加的向后钩子数量

341        self._backward_hook_handles = 0

循环浏览当前图层的可训练参数

344        for p in self.param_refs[self.TRAINING_PARAMS_IDX]:

确保尚未添加挂钩

346            assert not hasattr(p, "_hook_handle"), 'Parameter has already been hooked'

expand_as 用于创建我们可以拦截的 autograd 步骤

348            p_tmp = p.expand_as(p)

获取一个手柄来添加向后钩。这篇博客讨论grad_acc 了.

351            grad_acc = p_tmp.grad_fn.next_functions[0][0]

添加向后挂钩

353            handle = grad_acc.register_hook(
354                functools.partial(self._post_backward_hook, p))

保留对手柄的引用

356            p._hook_handle = handle

增加添加的钩子数量

358            self._backward_hook_handles += 1

处理向后事件

这被参数反向钩子和模块后向钩子调用。

360    def _backward_event(self):

减少钩子计数器

368        self._backward_hook_handles -= 1

如果所有的钩子(包括模块钩子)都被调用了,那么我们可以备份渐变并清理参数。

372        if self._backward_hook_handles == -1:
373            self._backup_grads()
374            self._cleanup_params()

开始获取前一图层的参数,因为 autograd 接下来将处理该图层的渐变。

377        for layer in self.prev_layer:
378            layer.fetch_params()

参数向后挂钩

380    def _post_backward_hook(self, p: nn.Parameter, *args):

从参数中移除句柄

385        p._hook_handle.remove()  # type: ignore[attr-defined]
386        delattr(p, "_hook_handle")

处理向后事件

389        self._backward_event()

模块向后挂钩

391    def _backward_hook(self, *args, **kwargs):

处理向后事件

396        self._backward_event()

上一层将开始计算梯度。我们需要确保它已经完成了参数的获取。

399        torch.cuda.current_stream().wait_stream(self.fetch_stream)

402        return None

备份当前图层的渐变

404    @torch.no_grad()
405    def _backup_grads(self):

如果没有可训练的参数,则跳过

410        if self.chunk_size[self.TRAINING_PARAMS_IDX] == 0:
411            return

使用备份流备份渐变

414        with torch.cuda.stream(self.backup_stream):

用于存储渐变的缓冲区

416            buffer = self._empty((self.world_size * self.chunk_size[self.TRAINING_PARAMS_IDX],))

将连续缓冲区拆分为多个节点。这些拆分是 “缓冲区” 的视图。

418            buffers = list(buffer.split(self.chunk_size[self.TRAINING_PARAMS_IDX]))

连续缓冲区的偏移量

421            offset = 0

遍历可训练的参数

423            for p in self.param_refs[self.TRAINING_PARAMS_IDX]:

收集渐变

425                shape = p._orig_shape  # type: ignore[attr-defined]
426                buffer[offset: offset + shape.numel()] = p.grad.view(-1)

更新偏移量

428                offset += shape.numel()

清理渐变

430                p.grad = None

空张量累积当前分片的梯度

433            grad = self._empty((self.chunk_size[self.TRAINING_PARAMS_IDX],))

累积每个分片的梯度。它将缓冲区分散到节点上,每个节点累积(减少)它收到的张量。

436            dist.reduce_scatter(grad, buffers)

等待操作完成,然后清除对缓冲区的引用

439            for b in buffers:
440                b.record_stream(self.fetch_stream)
441            buffer.record_stream(self.fetch_stream)
442            del buffer
443            del buffers

设置分块渐变。这就是优化器所看到的。

446            self.chunk[self.TRAINING_PARAMS_IDX].grad = grad
447            del grad

Zero3Layer 层的顺序模块

450class Zero3Sequential(nn.Module):
  • modules Zero3Layer 图层列表
454    def __init__(self, modules: List[Zero3Layer]):
458        super().__init__()

用于获取参数的 CUDA 流

461        self.fetch_stream = torch.cuda.Stream()

用于备份(累积)梯度的 CUDA 流

463        self.backup_stream = torch.cuda.Stream()

为每个层设置流以及前面和后面的Zero3Layer

466        for i in range(len(modules)):

设置图层索引

468            modules[i].layer_idx = i

设置直播

470            modules[i].fetch_stream = self.fetch_stream
471            modules[i].backup_stream = self.backup_stream

设置后续图层

473            if i + 1 < len(modules):
474                modules[i].next_layer.append(modules[i + 1])

设置前面的图层

476            if i - 1 >= 0:
477                modules[i].prev_layer.append(modules[i - 1])

存储模块清单

480        self.module_list = nn.ModuleList(modules)
482    def get_trainable_chunk(self):

返回每层可训练区块的列表

484        return sum([m.get_trainable_chunk() for m in self.module_list], [])
486    def forward(self, x: torch.Tensor):

确保渐变备份已完成

488        torch.cuda.current_stream().wait_stream(self.backup_stream)

向前传球

491        for m in self.module_list:
492            x = m(x)

495        return x