GPT-neox 检查点

11from pathlib import Path
12from typing import Dict, Union, Tuple, Optional
13
14import torch
15from torch import nn
16
17from labml import monit, lab, logger
18from labml.logger import Text, inspect
19from labml.utils.download import download_file

家长网址

22CHECKPOINTS_URL = 'https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/'
23
24_CHECKPOINTS_DOWNLOAD_PATH: Optional[Path] = None

下载路径

28def get_checkpoints_download_path():
29    global _CHECKPOINTS_DOWNLOAD_PATH
30
31    if _CHECKPOINTS_DOWNLOAD_PATH is not None:
32        return _CHECKPOINTS_DOWNLOAD_PATH
33
34    _CHECKPOINTS_DOWNLOAD_PATH = lab.get_data_path() / 'neox_fast' / 'slim_weights'
35    if not _CHECKPOINTS_DOWNLOAD_PATH.exists():
36        _CHECKPOINTS_DOWNLOAD_PATH = lab.get_data_path() / 'neox' / 'slim_weights'
37    inspect(neox_checkpoint_path=_CHECKPOINTS_DOWNLOAD_PATH)
38
39    return _CHECKPOINTS_DOWNLOAD_PATH

获取要下载的文件

    返回要下载的文件列表

42def get_files_to_download(n_layers: int = 44):
48    layers = (

嵌入层

50            [0] +

变压器层

52            list(range(2, 2 + n_layers)) +

最终归一化层和读出层

54            [47, 48]
55    )
56
57    return (

词汇和配置

59            ['20B_tokenizer.json', 'configs/20B.yml', 'latest'] +

图层检查点

61            [f'global_step150000/layer_{i :02d}-model_{p :02d}-model_states.pt' for i in layers for p in range(2)] +

空状态(未使用)

63            [f'global_step150000/mp_rank_{i :02d}_model_states.pt' for i in range(8)]
64    )

下载所有检查点文件

67def download(n_layers: int = 44):

获取要下载的文件

73    files = get_files_to_download(n_layers)

迭代

76    for i, f in monit.enum('Download All', files):

日志

78        logger.log(['Downloading ', (f'{i + 1 :3d}/{len(files)}', Text.meta), ': ', (f, Text.value)])

下载

80        download_file(CHECKPOINTS_URL + f, get_checkpoints_download_path() / f)

加载一对检查点文件

  • files 一对要加载的文件
  • 返回加载的参数张量

83def load_checkpoint_files(files: Tuple[str, str]):
90    checkpoint_path = get_checkpoints_download_path() / 'global_step150000'
91    with monit.section('Load checkpoint'):
92        data = [torch.load(checkpoint_path / f) for f in files]
93
94    return data

通过合并沿第一维度的分区来加载参数

  • param 是参数
  • key 是参数的名称
  • p1 第一个分区字典
  • p2 第二个分区字典
97def merge_params_dim_0(param: Union[nn.Parameter, torch.Tensor], key: str, p1: Dict[str, torch.Tensor],
98                       p2: Dict[str, torch.Tensor]):
107    w1, w2 = p1[key], p2[key]
108    param.data[:w1.shape[0]] = w1
109    param.data[w1.shape[0]:] = w2

通过合并第二维度的分区来加载参数

  • param 是参数
  • key 是参数的名称
  • p1 第一个分区字典
  • p2 第二个分区字典
112def merge_params_dim_1(param: Union[nn.Parameter, torch.Tensor], key: str, p1: Dict[str, torch.Tensor],
113                       p2: Dict[str, torch.Tensor]):
122    w1, w2 = p1[key], p2[key]
123    param.data[:, :w1.shape[1]] = w1
124    param.data[:, w1.shape[1]:] = w2

加载未分区的参数

这会进行健全性检查,以使用两个分区是相同的

  • param 是参数
  • key 是参数的名称
  • p1 第一个分区字典
  • p2 第二个分区字典
127def merge_params_duplicate(param: Union[nn.Parameter, torch.Tensor], key: str, p1: Dict[str, torch.Tensor],
128                           p2: Dict[str, torch.Tensor]):
139    w1, w2 = p1[key], p2[key]
140
141    diff = sum((w1 - w2) ** 2).item()
142    assert diff < 1e-4, f'The partitions do not match: {key}'
143
144    param.data[:] = (w1 + w2) / 2.

分区的负载偏差在 reduce 时被添加

  • param 是参数
  • key 是参数的名称
  • p1 第一个分区字典
  • p2 第二个分区字典
147def merge_params_sum(param: Union[nn.Parameter, torch.Tensor], key: str, p1: Dict[str, torch.Tensor],
148                     p2: Dict[str, torch.Tensor]):
157    w1, w2 = p1[key], p2[key]
158
159    param.data[:] = w1 + w2

163if __name__ == '__main__':
164    download()