GPT-ネオックスチェックポイント

11from pathlib import Path
12from typing import Dict, Union, Tuple, Optional
13
14import torch
15from torch import nn
16
17from labml import monit, lab, logger
18from labml.logger import Text, inspect
19from labml.utils.download import download_file

保護者の URL

22CHECKPOINTS_URL = 'https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/'
23
24_CHECKPOINTS_DOWNLOAD_PATH: Optional[Path] = None

ダウンロードパス

28def get_checkpoints_download_path():
29    global _CHECKPOINTS_DOWNLOAD_PATH
30
31    if _CHECKPOINTS_DOWNLOAD_PATH is not None:
32        return _CHECKPOINTS_DOWNLOAD_PATH
33
34    _CHECKPOINTS_DOWNLOAD_PATH = lab.get_data_path() / 'neox_fast' / 'slim_weights'
35    if not _CHECKPOINTS_DOWNLOAD_PATH.exists():
36        _CHECKPOINTS_DOWNLOAD_PATH = lab.get_data_path() / 'neox' / 'slim_weights'
37    inspect(neox_checkpoint_path=_CHECKPOINTS_DOWNLOAD_PATH)
38
39    return _CHECKPOINTS_DOWNLOAD_PATH

ダウンロードするファイルを取得

    ダウンロードするファイルのリストを返します

42def get_files_to_download(n_layers: int = 44):
48    layers = (

埋め込みレイヤー

50            [0] +

トランスフォーマー層

52            list(range(2, 2 + n_layers)) +

最終正規化層と読み出し層

54            [47, 48]
55    )
56
57    return (

語彙と構成

59            ['20B_tokenizer.json', 'configs/20B.yml', 'latest'] +

レイヤーチェックポイント

61            [f'global_step150000/layer_{i :02d}-model_{p :02d}-model_states.pt' for i in layers for p in range(2)] +

空の状態 (未使用)

63            [f'global_step150000/mp_rank_{i :02d}_model_states.pt' for i in range(8)]
64    )

すべてのチェックポイントファイルをダウンロード

67def download(n_layers: int = 44):

ダウンロードするファイルを取得

73    files = get_files_to_download(n_layers)

繰り返し

76    for i, f in monit.enum('Download All', files):

ログ

78        logger.log(['Downloading ', (f'{i + 1 :3d}/{len(files)}', Text.meta), ': ', (f, Text.value)])

[ダウンロード]

80        download_file(CHECKPOINTS_URL + f, get_checkpoints_download_path() / f)

1 組のチェックポイントファイルをロードする

  • files ロードするファイルのペア

ロードされたパラメータテンソルを返します

83def load_checkpoint_files(files: Tuple[str, str]):
90    checkpoint_path = get_checkpoints_download_path() / 'global_step150000'
91    with monit.section('Load checkpoint'):
92        data = [torch.load(checkpoint_path / f) for f in files]
93
94    return data

1 番目の次元に沿ってパーティションをマージしてパラメータをロードします

  • param はパラメータです
  • key はパラメータの名前です
  • p1 第 1 パーティション辞書
  • p2 2 番目のパーティション辞書
97def merge_params_dim_0(param: Union[nn.Parameter, torch.Tensor], key: str, p1: Dict[str, torch.Tensor],
98                       p2: Dict[str, torch.Tensor]):
107    w1, w2 = p1[key], p2[key]
108    param.data[:w1.shape[0]] = w1
109    param.data[w1.shape[0]:] = w2

2 番目の次元に沿ってパーティションをマージしてパラメータをロードします

  • param はパラメータです
  • key はパラメータの名前です
  • p1 第 1 パーティション辞書
  • p2 2 番目のパーティション辞書
112def merge_params_dim_1(param: Union[nn.Parameter, torch.Tensor], key: str, p1: Dict[str, torch.Tensor],
113                       p2: Dict[str, torch.Tensor]):
122    w1, w2 = p1[key], p2[key]
123    param.data[:, :w1.shape[1]] = w1
124    param.data[:, w1.shape[1]:] = w2

パーティション化されていないパラメータをロードする

これにより、両方のパーティションが同じであることを確認するためのサニティチェックが行われます。

  • param はパラメータです
  • key はパラメータの名前です
  • p1 第 1 パーティション辞書
  • p2 2 番目のパーティション辞書
127def merge_params_duplicate(param: Union[nn.Parameter, torch.Tensor], key: str, p1: Dict[str, torch.Tensor],
128                           p2: Dict[str, torch.Tensor]):
139    w1, w2 = p1[key], p2[key]
140
141    diff = sum((w1 - w2) ** 2).item()
142    assert diff < 1e-4, f'The partitions do not match: {key}'
143
144    param.data[:] = (w1 + w2) / 2.

分割された負荷バイアスがリデュース時に追加される

  • param はパラメータです
  • key はパラメータの名前です
  • p1 第 1 パーティション辞書
  • p2 2 番目のパーティション辞書
147def merge_params_sum(param: Union[nn.Parameter, torch.Tensor], key: str, p1: Dict[str, torch.Tensor],
148                     p2: Dict[str, torch.Tensor]):
157    w1, w2 = p1[key], p2[key]
158
159    param.data[:] = w1 + w2

163if __name__ == '__main__':
164    download()