これは、論文「Zero: 1兆のパラメーターモデルのトレーニングに向けたメモリ最適化」で紹介されているゼロDPの実装です。
オプティマイザの状態、グラデーション、パラメータの断片を複数のデバイス/ノードに保持します。これにより、メモリ消費量が元のモデルと同じになります。ここで、はパラメーターの数、はシャードの数、パラメーターごとのオプティマイザーのバイト数です。は、16 ビットの精度を前提としたパラメーターとグラデーションのメモリです。つまり、パラメーターとグラデーションごとに 2 バイトです。Adam オプティマイザー用です。これは、パラメーターのコピーと fp32 のパラメーターごとに 2 つのモーメントを保持しているためです
。ゼロDPの通信量は。比較のためにデータ並行トレーニングの通信量は
.これは名前が付けられていますがZero3
、残留メモリ消費を対象とするゼロRメモリ最適化は実装しておらず、DPがゼロの部分のみを実装しています。この実装では、パラメータのサブセットのみのトレーニングをサポートしています
32import functools
33from typing import List, Optional, Tuple
34
35import torch
36import torch.distributed as dist
37from torch import nn
40class Zero3Layer(nn.Module):
chunk
各シャードはパラメータをリストに保持します。chunk[0]
はトレーニング可能なパラメーター用で、chunk[1]
固定パラメーター用です
49 chunk: List[nn.Parameter]
chunk
これはリスト内のチャンクのサイズです。
51 chunk_size: List[int]
最初のチャンクはトレーニング可能なパラメーター用です。
53 TRAINING_PARAMS_IDX = 0
これは、トレーニング可能なパラメーターと固定パラメーターとしてリストに分割されたパラメーターのリストです。
56 param_refs: List[List[nn.Parameter]]
パラメータを取得する CUDA ストリーム
59 fetch_stream: Optional[torch.cuda.Stream]
勾配をバックアップ/蓄積するためのCUDAストリーム
61 backup_stream: Optional[torch.cuda.Stream]
このレイヤーの直前のレイヤーのリスト
63 prev_layer: List['Zero3Layer']
このレイヤーの直後のレイヤーのリスト
65 next_layer: List['Zero3Layer']
現在のレイヤーの位置。これをログのデバッグに使用しました
67 layer_idx: int
パラメータが取得されているかどうか
70 is_fetched: bool
レイヤーのデバイス
73 device: torch.device
レイヤーのデータタイプ
75 dtype: torch.dtype
ラップするモジュール
77 module: nn.Module
データがシャーディングされるノード/デバイスの数
79 world_size: int
module
ラップするモジュール。rank
現在のノードのランク。world_size
データがシャーディングされるノード/デバイスの数。device
レイヤーのデバイス。dtype
レイヤーのデータタイプ。81 def __init__(self, module: nn.Module, rank: int, world_size: int, device: torch.device, dtype: torch.dtype):
89 super().__init__()
プロパティを初期化
92 self.device = device
93 self.dtype = dtype
94 self.module = module
95 self.prev_layer = []
96 self.next_layer = []
97 self.is_fetched = False
98 self.world_size = world_size
99 self.layer_idx = -1
100 self.fetch_stream = None
101 self.backup_stream = None
102
103 with torch.no_grad():
レイヤーのすべてのパラメーターを収集します。
105 all_param_refs = [p for p in self.parameters()]
パラメータの形状を保存しておきます。後で再構築する必要があるからです。
108 for p in all_param_refs:
109 p._orig_shape = p.shape
すべてのパラメータは同じタイプでなければなりません
112 for p in all_param_refs:
113 assert p.dtype == dtype, "All parameters should have same dtype"
トレーニング可能なパラメータと固定パラメータを分離
116 self.param_refs = [[p for p in all_param_refs if p.requires_grad],
117 [p for p in all_param_refs if not p.requires_grad]]
118 del all_param_refs
rank = 0
ノードは、各デバイス/ノードが保存するサイズを計算し、それに応じてパラメータを分散します。
122 if rank == 0:
トレーニング可能 (merged_params[0]
) パラメーターと固定 (merged_params[1]
) パラメーターをマージしてパディングする
124 merged_params = [self._merge_and_pad_params(ps) for ps in self.param_refs]
トレーニング可能なパラメータと固定パラメータのチャンクサイズを計算します
126 self.chunk_size = [(len(p) // world_size if p is not None else 0) for p in merged_params]
サイズをブロードキャスト
128 dist.broadcast(torch.tensor(self.chunk_size, device=device), src=0)
129 else:
空のテンソルを作成してサイズを受け取る
131 chunk_size = torch.tensor([0, 0], device=device)
サイズを受け取る
133 dist.broadcast(chunk_size, src=0)
134 self.chunk_size = chunk_size.tolist()
trainable (self.chunk[0]
) パラメーターと fixed () パラメーターのパラメーターを作成して、現在のデバイス/ノードに保存します self.chunk[1]
138 self.chunk = [nn.Parameter(self._empty((s,)), requires_grad=i == self.TRAINING_PARAMS_IDX)
139 for i, s in enumerate(self.chunk_size)]
トレーニング可能なパラメーターと固定パラメーターを組み合わせて受け取る空のテンソル
142 chunk = self._empty((sum(self.chunk_size),))
143
144 if rank == 0:
トレーニング可能なパラメータと固定パラメータの両方を連結する
146 all_params = torch.cat([p.view(world_size, -1) for p in merged_params], dim=-1).view(-1)
147 del merged_params
それらをすべてのノード/デバイスに散らす
150 dist.scatter(chunk, list(all_params.split(sum(self.chunk_size))))
151 del all_params
152 else:
パラメータを受け取る
154 dist.scatter(chunk)
チャンクデータを収集する
157 chunk = chunk.split(self.chunk_size)
158 for i, c in enumerate(chunk):
159 self.chunk[i].data[:] = c
160 del chunk
標準パラメータをクリーンアップ
163 self._cleanup_params()
後方フックを追加。これは、モジュールを基準としたグラデーションが計算されるときに呼び出されます
。166 self._backward_hook_ref = self.register_full_backward_hook(self._backward_hook) # type: ignore
world_size
次の数で割り切れるようにパディングします。168 def _merge_and_pad_params(self, params: List[nn.Parameter]) -> torch.Tensor:
パラメータの総数
173 size = sum(p.shape.numel() for p in params)
world_size
割り切れない場合はパディングしてください
176 if size % self.world_size != 0:
177 padding_fixed = self.world_size - (size % self.world_size)
それ以外の場合は、パッドする必要はありません
179 else:
180 padding_fixed = 0
空のパディングテンソルの作成
182 padding = self._empty((padding_fixed,))
すべてのパラメータを連結してパディングします。
184 return torch.cat([p.view(-1) for p in params] + [padding], dim=0)
186 def get_trainable_chunk(self) -> List[nn.Parameter]:
トレーニング可能なパラメータがない場合はリストを返して空にする
193 if len(self.chunk[self.TRAINING_PARAMS_IDX]) == 0:
194 return []
トレーニング可能なチャンクをリストとして返す
197 return [self.chunk[self.TRAINING_PARAMS_IDX]]
199 def _empty(self, shape: Tuple[int, ...]) -> torch.Tensor:
203 return torch.empty(shape, device=self.device, dtype=self.dtype)
205 @torch.no_grad()
206 def _cleanup_params(self):
パラメータがフェッチされないことを示すフラグを設定します。
214 self.is_fetched = False
すべてのパラメータを反復処理
217 for ps in self.param_refs:
218 for p in ps:
パラメータの操作が完了するまで待ってから、新しい操作を行います。
220 p.data.record_stream(torch.cuda.current_stream())
パラメータが他のものとストレージを共有していないことを確認してください
222 assert p.data.storage_offset() == 0, "The tensor is not the sole occupant of the storage."
ストレージのサイズをに変更します。これにより、パラメータが使用していたメモリが解放されます。
autograd p.data
グラフはメモリへの参照を保持するので、設定してもメモリは解放されません。
226 p.data.storage().resize_(0) # This is what actually clears the memory
パラメータに勾配データがないことを確認してください
228 assert p.grad is None, 'Gradients should be None'
230 @torch.no_grad()
231 def fetch_params(self):
スキップはすでに取得されています
239 if self.is_fetched:
240 return
フラグを設定
243 self.is_fetched = True
取得または共有するものがない場合はスキップしてください。
246 if sum(self.chunk_size) == 0:
247 return
fetch_stream
を使用してすべてのシャードからパラメータを取得します。
250 with torch.cuda.stream(self.fetch_stream):
空のテンソルを作成してパラメーターを受け取る
252 buffer = self._empty((self.world_size * sum(self.chunk_size),))
連続バッファをノード数に分割します。これらの分割は「buffer」のビューです
。254 buffers = list(buffer.split(sum(self.chunk_size)))
トレーニング可能なチャンクと固定チャンクの両方を連結する
257 chunk = torch.cat(self.chunk, dim=0)
すべてのノード/デバイスからパラメータを収集します
260 dist.all_gather(buffers, chunk)
収集したパラメーターをトレーニング可能なチャンクと固定チャンクに分割します
263 params = buffer.view(-1, sum(self.chunk_size)).split(self.chunk_size, dim=1)
収集操作が完了するのを待ってから、バッファへの参照をクリアします。
265 buffer.record_stream(self.fetch_stream)
266 for b in buffers:
267 b.record_stream(self.fetch_stream)
268 buffer.record_stream(self.fetch_stream)
269 del buffer
270 del buffers
トレーニング可能なパラメーターと固定パラメーターを連続テンソルにリシェイプ
273 params = [p.reshape(-1) for p in params]
個々のパラメータテンソルを収集する
276 for cont, ps in zip(params, self.param_refs):
パラメータがない場合はスキップしてください
278 if not ps:
279 continue
連続テンソルのオフセット
282 offset = 0
モデルパラメーターを繰り返し処理し、連続テンソルから値を割り当てます
284 for p in ps:
オリジナルのパラメータシェイプ
286 shape = p._orig_shape # type: ignore[attr-defined]
パラメータのストレージサイズを変更します。これは、パラメータをクリーンアップしたときに設定されました。
288 p.data.storage().resize_(shape.numel())
連続テンソルから値を割り当てます
290 p.data[:] = cont[offset: offset + shape.numel()].reshape(shape)
操作が完了するまで待ってから、他の操作を実行してください
292 p.data.record_stream(self.fetch_stream)
オフセットの更新
294 offset += shape.numel()
操作が完了するまで待ってから、他の操作を実行してください
297 cont.record_stream(self.fetch_stream)
300 del params
302 def forward(self, *args, **kwargs):
現在のノードのすべてのパラメータを取得します。これは前のレイヤーから呼び出されるので、この呼び出しはパラメーターがフェッチされていることを確認するためだけのものです
。309 self.fetch_params()
パラメータの取得が完了するまでお待ちください。
312 torch.cuda.current_stream().wait_stream(self.fetch_stream)
処理中のレイヤーのパラメーターの取得を開始します。そうすれば、現在のレイヤーが計算を行うパラメーターが取得されます。
316 for layer in self.next_layer:
317 layer.fetch_params()
autograd が有効になっている場合は、現在のレイヤーのパラメーターに逆方向フックを追加します。
320 if torch.is_grad_enabled():
321 self._add_backward_hooks()
現在のレイヤーの出力を計算します
324 res = self.module(*args, **kwargs)
レイヤーのパラメーターをクリーンアップします。
autograd が有効になっていて、これがネットワークの最後のレイヤーである場合は、後方パスのパラメーターを再度取得する必要があるため、クリーンアップをスキップしてください。
330 if not torch.is_grad_enabled() or self.next_layer:
331 self._cleanup_params()
332
333 return res
335 def _add_backward_hooks(self):
追加された後方フックの数
341 self._backward_hook_handles = 0
現在のレイヤーのトレーニング可能なパラメーターをループスルーする
344 for p in self.param_refs[self.TRAINING_PARAMS_IDX]:
フックがまだ追加されていないことを確認してください
346 assert not hasattr(p, "_hook_handle"), 'Parameter has already been hooked'
expand_as
インターセプトできるオートグラードのステップを作るのに使う
348 p_tmp = p.expand_as(p)
後ろ向きフックを取り付けるためのハンドルを用意してください。このブログでは、. について説明しますgrad_acc
351 grad_acc = p_tmp.grad_fn.next_functions[0][0]
後方フックを追加
353 handle = grad_acc.register_hook(
354 functools.partial(self._post_backward_hook, p))
ハンドルへの参照を忘れずに
356 p._hook_handle = handle
追加するフックの数を増やしてください
358 self._backward_hook_handles += 1
360 def _backward_event(self):
フックカウンターをデクリメントしてください
368 self._backward_hook_handles -= 1
すべてのフック (モジュールフックを含む) が呼び出されたら、グラデーションをバックアップしてパラメータをクリーンアップできます。
372 if self._backward_hook_handles == -1:
373 self._backup_grads()
374 self._cleanup_params()
autograd が次にそのレイヤーのグラデーションを処理するので、前のレイヤーのパラメーターの取得を開始します。
377 for layer in self.prev_layer:
378 layer.fetch_params()
380 def _post_backward_hook(self, p: nn.Parameter, *args):
パラメータからハンドルを削除します
385 p._hook_handle.remove() # type: ignore[attr-defined]
386 delattr(p, "_hook_handle")
逆方向イベントの処理
389 self._backward_event()
391 def _backward_hook(self, *args, **kwargs):
逆方向イベントの処理
396 self._backward_event()
前のレイヤーがグラデーションの計算を開始します。パラメータの取得が完了したことを確認する必要があります
。399 torch.cuda.current_stream().wait_stream(self.fetch_stream)
402 return None
404 @torch.no_grad()
405 def _backup_grads(self):
トレーニング可能なパラメータがない場合はスキップ
410 if self.chunk_size[self.TRAINING_PARAMS_IDX] == 0:
411 return
バックアップストリームを使用してグラデーションをバックアップします
414 with torch.cuda.stream(self.backup_stream):
グラデーションを保存するバッファ
416 buffer = self._empty((self.world_size * self.chunk_size[self.TRAINING_PARAMS_IDX],))
連続バッファを複数のノードに分割します。これらの分割は「buffer」のビューです
。418 buffers = list(buffer.split(self.chunk_size[self.TRAINING_PARAMS_IDX]))
連続バッファのオフセット
421 offset = 0
トレーニング可能なパラメーターを繰り返し処理
423 for p in self.param_refs[self.TRAINING_PARAMS_IDX]:
グラデーションを集める
425 shape = p._orig_shape # type: ignore[attr-defined]
426 buffer[offset: offset + shape.numel()] = p.grad.view(-1)
オフセットの更新
428 offset += shape.numel()
グラデーションをきれいにする
430 p.grad = None
テンソルを空にすると、現在のシャードの勾配が蓄積されます
433 grad = self._empty((self.chunk_size[self.TRAINING_PARAMS_IDX],))
各シャードのグラデーションを累積します。バッファをノード全体に分散させ、各ノードは受け取ったテンソルを蓄積(削減)します
。436 dist.reduce_scatter(grad, buffers)
操作が完了するのを待ってから、バッファへの参照をクリアします。
439 for b in buffers:
440 b.record_stream(self.fetch_stream)
441 buffer.record_stream(self.fetch_stream)
442 del buffer
443 del buffers
チャンクのグラデーションを設定します。これがオプティマイザーの見解です
。446 self.chunk[self.TRAINING_PARAMS_IDX].grad = grad
447 del grad
Zero3Layer
レイヤー用シーケンシャルモジュール450class Zero3Sequential(nn.Module):
modules
Zero3Layer
レイヤーのリスト454 def __init__(self, modules: List[Zero3Layer]):
458 super().__init__()
パラメータを取得するための CUDA ストリーム
461 self.fetch_stream = torch.cuda.Stream()
グラデーションをバックアップ(蓄積)するCUDAストリーム
463 self.backup_stream = torch.cuda.Stream()
各レイヤーのストリームと先行レイヤーと後続レイヤーを設定します Zero3Layer
466 for i in range(len(modules)):
レイヤーインデックスを設定
468 modules[i].layer_idx = i
ストリームを設定
470 modules[i].fetch_stream = self.fetch_stream
471 modules[i].backup_stream = self.backup_stream
進行中のレイヤーを設定
473 if i + 1 < len(modules):
474 modules[i].next_layer.append(modules[i + 1])
前のレイヤーを設定
476 if i - 1 >= 0:
477 modules[i].prev_layer.append(modules[i - 1])
モジュール一覧を保存
480 self.module_list = nn.ModuleList(modules)
482 def get_trainable_chunk(self):
各レイヤーからトレーニング可能なチャンクのリストを返します
484 return sum([m.get_trainable_chunk() for m in self.module_list], [])
486 def forward(self, x: torch.Tensor):
グラデーションのバックアップが完了していることを確認してください
488 torch.cuda.current_stream().wait_stream(self.backup_stream)
フォワードパス
491 for m in self.module_list:
492 x = m(x)
495 return x