11from pathlib import Path
12from typing import Dict, Union, Tuple, Optional
13
14import torch
15from torch import nn
16
17from labml import monit, lab, logger
18from labml.logger import Text, inspect
19from labml.utils.download import download_file
保護者の URL
22CHECKPOINTS_URL = 'https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/'
23
24_CHECKPOINTS_DOWNLOAD_PATH: Optional[Path] = None
ダウンロードパス
28def get_checkpoints_download_path():
29 global _CHECKPOINTS_DOWNLOAD_PATH
30
31 if _CHECKPOINTS_DOWNLOAD_PATH is not None:
32 return _CHECKPOINTS_DOWNLOAD_PATH
33
34 _CHECKPOINTS_DOWNLOAD_PATH = lab.get_data_path() / 'neox_fast' / 'slim_weights'
35 if not _CHECKPOINTS_DOWNLOAD_PATH.exists():
36 _CHECKPOINTS_DOWNLOAD_PATH = lab.get_data_path() / 'neox' / 'slim_weights'
37 inspect(neox_checkpoint_path=_CHECKPOINTS_DOWNLOAD_PATH)
38
39 return _CHECKPOINTS_DOWNLOAD_PATH
42def get_files_to_download(n_layers: int = 44):
48 layers = (
埋め込みレイヤー
50 [0] +
トランスフォーマー層
52 list(range(2, 2 + n_layers)) +
最終正規化層と読み出し層
54 [47, 48]
55 )
56
57 return (
語彙と構成
59 ['20B_tokenizer.json', 'configs/20B.yml', 'latest'] +
レイヤーチェックポイント
61 [f'global_step150000/layer_{i :02d}-model_{p :02d}-model_states.pt' for i in layers for p in range(2)] +
空の状態 (未使用)
63 [f'global_step150000/mp_rank_{i :02d}_model_states.pt' for i in range(8)]
64 )
67def download(n_layers: int = 44):
ダウンロードするファイルを取得
73 files = get_files_to_download(n_layers)
繰り返し
76 for i, f in monit.enum('Download All', files):
ログ
78 logger.log(['Downloading ', (f'{i + 1 :3d}/{len(files)}', Text.meta), ': ', (f, Text.value)])
[ダウンロード]
80 download_file(CHECKPOINTS_URL + f, get_checkpoints_download_path() / f)
83def load_checkpoint_files(files: Tuple[str, str]):
90 checkpoint_path = get_checkpoints_download_path() / 'global_step150000'
91 with monit.section('Load checkpoint'):
92 data = [torch.load(checkpoint_path / f) for f in files]
93
94 return data
param
はパラメータですkey
はパラメータの名前ですp1
第 1 パーティション辞書p2
2 番目のパーティション辞書97def merge_params_dim_0(param: Union[nn.Parameter, torch.Tensor], key: str, p1: Dict[str, torch.Tensor],
98 p2: Dict[str, torch.Tensor]):
107 w1, w2 = p1[key], p2[key]
108 param.data[:w1.shape[0]] = w1
109 param.data[w1.shape[0]:] = w2
param
はパラメータですkey
はパラメータの名前ですp1
第 1 パーティション辞書p2
2 番目のパーティション辞書112def merge_params_dim_1(param: Union[nn.Parameter, torch.Tensor], key: str, p1: Dict[str, torch.Tensor],
113 p2: Dict[str, torch.Tensor]):
122 w1, w2 = p1[key], p2[key]
123 param.data[:, :w1.shape[1]] = w1
124 param.data[:, w1.shape[1]:] = w2
これにより、両方のパーティションが同じであることを確認するためのサニティチェックが行われます。
param
はパラメータですkey
はパラメータの名前ですp1
第 1 パーティション辞書p2
2 番目のパーティション辞書127def merge_params_duplicate(param: Union[nn.Parameter, torch.Tensor], key: str, p1: Dict[str, torch.Tensor],
128 p2: Dict[str, torch.Tensor]):
139 w1, w2 = p1[key], p2[key]
140
141 diff = sum((w1 - w2) ** 2).item()
142 assert diff < 1e-4, f'The partitions do not match: {key}'
143
144 param.data[:] = (w1 + w2) / 2.
147def merge_params_sum(param: Union[nn.Parameter, torch.Tensor], key: str, p1: Dict[str, torch.Tensor],
148 p2: Dict[str, torch.Tensor]):
157 w1, w2 = p1[key], p2[key]
158
159 param.data[:] = w1 + w2
163if __name__ == '__main__':
164 download()