ゼロ DP メモリ最適化

これは、論文「Zero: 1兆のパラメーターモデルのトレーニングに向けたメモリ最適化」で紹介されているゼロDPの実装です

オプティマイザの状態、グラデーション、パラメータの断片を複数のデバイス/ノードに保持します。これにより、メモリ消費量が元のモデルと同じになります。ここで、はパラメーターの数、はシャードの数、パラメーターごとのオプティマイザーのバイト数です。は、16 ビットの精度を前提としたパラメーターとグラデーションのメモリです。つまり、パラメーターとグラデーションごとに 2 バイトです。Adam オプティマイザー用です。これは、パラメーターのコピーと fp32 のパラメーターごとに 2 つのモーメントを保持しているためです

ゼロDPの通信量は。比較のためにデータ並行トレーニングの通信量は

.

これは名前が付けられていますがZero3 、残留メモリ消費を対象とするゼロRメモリ最適化は実装しておらず、DPがゼロの部分のみを実装しています。この実装では、パラメータのサブセットのみのトレーニングをサポートしています

この実装はフェアスケールFSDPに触発されています

ゼロDPメモリ最適化を使用してGPT NeoXを微調整するスクリプトを次に示します

32import functools
33from typing import List, Optional, Tuple
34
35import torch
36import torch.distributed as dist
37from torch import nn

ゼロ 3 レイヤー

モデルの各レイヤー(またはいくつかの連続したレイヤーの組み合わせ)をこのモジュールでラップする必要があります。

40class Zero3Layer(nn.Module):

chunk 各シャードはパラメータをリストに保持します。chunk[0] はトレーニング可能なパラメーター用で、chunk[1] 固定パラメーター用です

49    chunk: List[nn.Parameter]

chunk これはリスト内のチャンクのサイズです。

51    chunk_size: List[int]

最初のチャンクはトレーニング可能なパラメーター用です。

53    TRAINING_PARAMS_IDX = 0

これは、トレーニング可能なパラメーターと固定パラメーターとしてリストに分割されたパラメーターのリストです。

56    param_refs: List[List[nn.Parameter]]

パラメータを取得する CUDA ストリーム

59    fetch_stream: Optional[torch.cuda.Stream]

勾配をバックアップ/蓄積するためのCUDAストリーム

61    backup_stream: Optional[torch.cuda.Stream]

このレイヤーの直前のレイヤーのリスト

63    prev_layer: List['Zero3Layer']

このレイヤーの直後のレイヤーのリスト

65    next_layer: List['Zero3Layer']

現在のレイヤーの位置。これをログのデバッグに使用しました

67    layer_idx: int

パラメータが取得されているかどうか

70    is_fetched: bool

レイヤーのデバイス

73    device: torch.device

レイヤーのデータタイプ

75    dtype: torch.dtype

ラップするモジュール

77    module: nn.Module

データがシャーディングされるノード/デバイスの数

79    world_size: int
  • module ラップするモジュール。
  • rank 現在のノードのランク。
  • world_size データがシャーディングされるノード/デバイスの数。
  • device レイヤーのデバイス。
  • dtype レイヤーのデータタイプ。
81    def __init__(self, module: nn.Module, rank: int, world_size: int, device: torch.device, dtype: torch.dtype):
89        super().__init__()

プロパティを初期化

92        self.device = device
93        self.dtype = dtype
94        self.module = module
95        self.prev_layer = []
96        self.next_layer = []
97        self.is_fetched = False
98        self.world_size = world_size
99        self.layer_idx = -1
100        self.fetch_stream = None
101        self.backup_stream = None
102
103        with torch.no_grad():

レイヤーのすべてのパラメーターを収集します。

105            all_param_refs = [p for p in self.parameters()]

パラメータの形状を保存しておきます。後で再構築する必要があるからです。

108            for p in all_param_refs:
109                p._orig_shape = p.shape

すべてのパラメータは同じタイプでなければなりません

112            for p in all_param_refs:
113                assert p.dtype == dtype, "All parameters should have same dtype"

トレーニング可能なパラメータと固定パラメータを分離

116            self.param_refs = [[p for p in all_param_refs if p.requires_grad],
117                               [p for p in all_param_refs if not p.requires_grad]]
118            del all_param_refs

rank = 0 ノードは、各デバイス/ノードが保存するサイズを計算し、それに応じてパラメータを分散します。

122            if rank == 0:

トレーニング可能 (merged_params[0] ) パラメーターと固定 (merged_params[1] ) パラメーターをマージしてパディングする

124                merged_params = [self._merge_and_pad_params(ps) for ps in self.param_refs]

トレーニング可能なパラメータと固定パラメータのチャンクサイズを計算します

126                self.chunk_size = [(len(p) // world_size if p is not None else 0) for p in merged_params]

サイズをブロードキャスト

128                dist.broadcast(torch.tensor(self.chunk_size, device=device), src=0)
129            else:

空のテンソルを作成してサイズを受け取る

131                chunk_size = torch.tensor([0, 0], device=device)

サイズを受け取る

133                dist.broadcast(chunk_size, src=0)
134                self.chunk_size = chunk_size.tolist()

trainable (self.chunk[0] ) パラメーターと fixed () パラメーターのパラメーターを作成して、現在のデバイス/ノードに保存します self.chunk[1]

138            self.chunk = [nn.Parameter(self._empty((s,)), requires_grad=i == self.TRAINING_PARAMS_IDX)
139                          for i, s in enumerate(self.chunk_size)]

トレーニング可能なパラメーターと固定パラメーターを組み合わせて受け取る空のテンソル

142            chunk = self._empty((sum(self.chunk_size),))
143
144            if rank == 0:

トレーニング可能なパラメータと固定パラメータの両方を連結する

146                all_params = torch.cat([p.view(world_size, -1) for p in merged_params], dim=-1).view(-1)
147                del merged_params

それらをすべてのノード/デバイスに散らす

150                dist.scatter(chunk, list(all_params.split(sum(self.chunk_size))))
151                del all_params
152            else:

パラメータを受け取る

154                dist.scatter(chunk)

チャンクデータを収集する

157            chunk = chunk.split(self.chunk_size)
158            for i, c in enumerate(chunk):
159                self.chunk[i].data[:] = c
160            del chunk

標準パラメータをクリーンアップ

163            self._cleanup_params()

後方フックを追加。これは、モジュールを基準としたグラデーションが計算されるときに呼び出されます

166            self._backward_hook_ref = self.register_full_backward_hook(self._backward_hook)  # type: ignore

すべてのパラメータを結合し、world_size 次の数で割り切れるようにパディングします。

168    def _merge_and_pad_params(self, params: List[nn.Parameter]) -> torch.Tensor:

パラメータの総数

173        size = sum(p.shape.numel() for p in params)

world_size 割り切れない場合はパディングしてください

176        if size % self.world_size != 0:
177            padding_fixed = self.world_size - (size % self.world_size)

それ以外の場合は、パッドする必要はありません

179        else:
180            padding_fixed = 0

空のパディングテンソルの作成

182        padding = self._empty((padding_fixed,))

すべてのパラメータを連結してパディングします。

184        return torch.cat([p.view(-1) for p in params] + [padding], dim=0)

トレーニング可能なパラメータのチャンク/シャードを取得します。

これを現在のノードのオプティマイザーに渡します。

186    def get_trainable_chunk(self) -> List[nn.Parameter]:

トレーニング可能なパラメータがない場合はリストを返して空にする

193        if len(self.chunk[self.TRAINING_PARAMS_IDX]) == 0:
194            return []

トレーニング可能なチャンクをリストとして返す

197        return [self.chunk[self.TRAINING_PARAMS_IDX]]

与えられた形状の空のテンソルを作成します。

199    def _empty(self, shape: Tuple[int, ...]) -> torch.Tensor:
203        return torch.empty(shape, device=self.device, dtype=self.dtype)

パラメータデータをクリーンアップする

これにより、レイヤーパラメーターで使用されていたすべてのメモリが解放されます。

205    @torch.no_grad()
206    def _cleanup_params(self):

パラメータがフェッチされないことを示すフラグを設定します。

214        self.is_fetched = False

すべてのパラメータを反復処理

217        for ps in self.param_refs:
218            for p in ps:

パラメータの操作が完了するまで待ってから、新しい操作を行います。

220                p.data.record_stream(torch.cuda.current_stream())

パラメータが他のものとストレージを共有していないことを確認してください

222                assert p.data.storage_offset() == 0, "The tensor is not the sole occupant of the storage."

ストレージのサイズをに変更します。これにより、パラメータが使用していたメモリが解放されます。

autograd p.data グラフはメモリへの参照を保持するので、設定してもメモリは解放されません。

226                p.data.storage().resize_(0)  # This is what actually clears the memory

パラメータに勾配データがないことを確認してください

228                assert p.grad is None, 'Gradients should be None'

すべてのシャードからパラメータを取得

これにより、すべてのノードからすべてのパラメータデータが取得され、各ノードのパラメータが再構築されます。

230    @torch.no_grad()
231    def fetch_params(self):

スキップはすでに取得されています

239        if self.is_fetched:
240            return

フラグを設定

243        self.is_fetched = True

取得または共有するものがない場合はスキップしてください。

246        if sum(self.chunk_size) == 0:
247            return

fetch_stream を使用してすべてのシャードからパラメータを取得します。

250        with torch.cuda.stream(self.fetch_stream):

空のテンソルを作成してパラメーターを受け取る

252            buffer = self._empty((self.world_size * sum(self.chunk_size),))

連続バッファをノード数に分割します。これらの分割は「buffer」のビューです

254            buffers = list(buffer.split(sum(self.chunk_size)))

トレーニング可能なチャンクと固定チャンクの両方を連結する

257            chunk = torch.cat(self.chunk, dim=0)

すべてのノード/デバイスからパラメータを収集します

260            dist.all_gather(buffers, chunk)

収集したパラメーターをトレーニング可能なチャンクと固定チャンクに分割します

263            params = buffer.view(-1, sum(self.chunk_size)).split(self.chunk_size, dim=1)

収集操作が完了するのを待ってから、バッファへの参照をクリアします。

265            buffer.record_stream(self.fetch_stream)
266            for b in buffers:
267                b.record_stream(self.fetch_stream)
268            buffer.record_stream(self.fetch_stream)
269            del buffer
270            del buffers

トレーニング可能なパラメーターと固定パラメーターを連続テンソルにリシェイプ

273            params = [p.reshape(-1) for p in params]

個々のパラメータテンソルを収集する

276            for cont, ps in zip(params, self.param_refs):

パラメータがない場合はスキップしてください

278                if not ps:
279                    continue

連続テンソルのオフセット

282                offset = 0

モデルパラメーターを繰り返し処理し、連続テンソルから値を割り当てます

284                for p in ps:

オリジナルのパラメータシェイプ

286                    shape = p._orig_shape  # type: ignore[attr-defined]

パラメータのストレージサイズを変更します。これは、パラメータをクリーンアップしたときに設定されました。

288                    p.data.storage().resize_(shape.numel())

連続テンソルから値を割り当てます

290                    p.data[:] = cont[offset: offset + shape.numel()].reshape(shape)

操作が完了するまで待ってから、他の操作を実行してください

292                    p.data.record_stream(self.fetch_stream)

オフセットの更新

294                    offset += shape.numel()

操作が完了するまで待ってから、他の操作を実行してください

297                cont.record_stream(self.fetch_stream)

300            del params

フォワードパス

302    def forward(self, *args, **kwargs):

現在のノードのすべてのパラメータを取得します。これは前のレイヤーから呼び出されるので、この呼び出しはパラメーターがフェッチされていることを確認するためだけのものです

309        self.fetch_params()

パラメータの取得が完了するまでお待ちください。

312        torch.cuda.current_stream().wait_stream(self.fetch_stream)

処理中のレイヤーのパラメーターの取得を開始します。そうすれば、現在のレイヤーが計算を行うパラメーターが取得されます。

316        for layer in self.next_layer:
317            layer.fetch_params()

autograd が有効になっている場合は、現在のレイヤーのパラメーターに逆方向フックを追加します。

320        if torch.is_grad_enabled():
321            self._add_backward_hooks()

現在のレイヤーの出力を計算します

324        res = self.module(*args, **kwargs)

レイヤーのパラメーターをクリーンアップします。

autograd が有効になっていて、これがネットワークの最後のレイヤーである場合は、後方パスのパラメーターを再度取得する必要があるため、クリーンアップをスキップしてください。

330        if not torch.is_grad_enabled() or self.next_layer:
331            self._cleanup_params()
332
333        return res

現在のレイヤーのパラメーターに逆方向フックを追加します。

335    def _add_backward_hooks(self):

追加された後方フックの数

341        self._backward_hook_handles = 0

現在のレイヤーのトレーニング可能なパラメーターをループスルーする

344        for p in self.param_refs[self.TRAINING_PARAMS_IDX]:

フックがまだ追加されていないことを確認してください

346            assert not hasattr(p, "_hook_handle"), 'Parameter has already been hooked'

expand_as インターセプトできるオートグラードのステップを作るのに使う

348            p_tmp = p.expand_as(p)

後ろ向きフックを取り付けるためのハンドルを用意してください。このブログでは、. について説明しますgrad_acc

351            grad_acc = p_tmp.grad_fn.next_functions[0][0]

後方フックを追加

353            handle = grad_acc.register_hook(
354                functools.partial(self._post_backward_hook, p))

ハンドルへの参照を忘れずに

356            p._hook_handle = handle

追加するフックの数を増やしてください

358            self._backward_hook_handles += 1

逆方向イベントの処理

これは、パラメーターの逆方向フックとモジュールの逆方向フックによって呼び出されます。

360    def _backward_event(self):

フックカウンターをデクリメントしてください

368        self._backward_hook_handles -= 1

すべてのフック (モジュールフックを含む) が呼び出されたら、グラデーションをバックアップしてパラメータをクリーンアップできます。

372        if self._backward_hook_handles == -1:
373            self._backup_grads()
374            self._cleanup_params()

autograd が次にそのレイヤーのグラデーションを処理するので、前のレイヤーのパラメーターの取得を開始します。

377        for layer in self.prev_layer:
378            layer.fetch_params()

パラメータバックワードフック

380    def _post_backward_hook(self, p: nn.Parameter, *args):

パラメータからハンドルを削除します

385        p._hook_handle.remove()  # type: ignore[attr-defined]
386        delattr(p, "_hook_handle")

逆方向イベントの処理

389        self._backward_event()

モジュール後方フック

391    def _backward_hook(self, *args, **kwargs):

逆方向イベントの処理

396        self._backward_event()

前のレイヤーがグラデーションの計算を開始します。パラメータの取得が完了したことを確認する必要があります

399        torch.cuda.current_stream().wait_stream(self.fetch_stream)

402        return None

現在のレイヤーのグラデーションをバックアップします

404    @torch.no_grad()
405    def _backup_grads(self):

トレーニング可能なパラメータがない場合はスキップ

410        if self.chunk_size[self.TRAINING_PARAMS_IDX] == 0:
411            return

バックアップストリームを使用してグラデーションをバックアップします

414        with torch.cuda.stream(self.backup_stream):

グラデーションを保存するバッファ

416            buffer = self._empty((self.world_size * self.chunk_size[self.TRAINING_PARAMS_IDX],))

連続バッファを複数のノードに分割します。これらの分割は「buffer」のビューです

418            buffers = list(buffer.split(self.chunk_size[self.TRAINING_PARAMS_IDX]))

連続バッファのオフセット

421            offset = 0

トレーニング可能なパラメーターを繰り返し処理

423            for p in self.param_refs[self.TRAINING_PARAMS_IDX]:

グラデーションを集める

425                shape = p._orig_shape  # type: ignore[attr-defined]
426                buffer[offset: offset + shape.numel()] = p.grad.view(-1)

オフセットの更新

428                offset += shape.numel()

グラデーションをきれいにする

430                p.grad = None

テンソルを空にすると、現在のシャードの勾配が蓄積されます

433            grad = self._empty((self.chunk_size[self.TRAINING_PARAMS_IDX],))

各シャードのグラデーションを累積します。バッファをノード全体に分散させ、各ノードは受け取ったテンソルを蓄積(削減)します

436            dist.reduce_scatter(grad, buffers)

操作が完了するのを待ってから、バッファへの参照をクリアします。

439            for b in buffers:
440                b.record_stream(self.fetch_stream)
441            buffer.record_stream(self.fetch_stream)
442            del buffer
443            del buffers

チャンクのグラデーションを設定します。これがオプティマイザーの見解です

446            self.chunk[self.TRAINING_PARAMS_IDX].grad = grad
447            del grad

Zero3Layer レイヤー用シーケンシャルモジュール

450class Zero3Sequential(nn.Module):
  • modules Zero3Layer レイヤーのリスト
454    def __init__(self, modules: List[Zero3Layer]):
458        super().__init__()

パラメータを取得するための CUDA ストリーム

461        self.fetch_stream = torch.cuda.Stream()

グラデーションをバックアップ(蓄積)するCUDAストリーム

463        self.backup_stream = torch.cuda.Stream()

各レイヤーのストリームと先行レイヤーと後続レイヤーを設定します Zero3Layer

466        for i in range(len(modules)):

レイヤーインデックスを設定

468            modules[i].layer_idx = i

ストリームを設定

470            modules[i].fetch_stream = self.fetch_stream
471            modules[i].backup_stream = self.backup_stream

進行中のレイヤーを設定

473            if i + 1 < len(modules):
474                modules[i].next_layer.append(modules[i + 1])

前のレイヤーを設定

476            if i - 1 >= 0:
477                modules[i].prev_layer.append(modules[i - 1])

モジュール一覧を保存

480        self.module_list = nn.ModuleList(modules)
482    def get_trainable_chunk(self):

各レイヤーからトレーニング可能なチャンクのリストを返します

484        return sum([m.get_trainable_chunk() for m in self.module_list], [])
486    def forward(self, x: torch.Tensor):

グラデーションのバックアップが完了していることを確認してください

488        torch.cuda.current_stream().wait_stream(self.backup_stream)

フォワードパス

491        for m in self.module_list:
492            x = m(x)

495        return x