GPT-NeoX Checkpoints

11from pathlib import Path
12from typing import Dict, Union, Tuple, Optional
13
14import torch
15from torch import nn
16
17from labml import monit, lab, logger
18from labml.logger import Text, inspect
19from labml.utils.download import download_file

Parent URL of the checkpoint files

# Base URL the slim GPT-NeoX-20B weight files are downloaded from;
# file paths returned by `get_files_to_download` are appended to it.
CHECKPOINTS_URL = 'https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/'

# Local checkpoint directory, cached by `get_checkpoints_download_path`;
# stays `None` until that function resolves it for the first time.
_CHECKPOINTS_DOWNLOAD_PATH: Optional[Path] = None

Download path

def get_checkpoints_download_path():
    """
    ### Download path

    Resolves (only once) and returns the local directory that holds the checkpoint files.
    `neox_fast/slim_weights` under the labml data path is preferred when it exists;
    otherwise `neox/slim_weights` under the same data path is used.
    The resolved path is cached in the module-level `_CHECKPOINTS_DOWNLOAD_PATH`.
    """
    global _CHECKPOINTS_DOWNLOAD_PATH

    if _CHECKPOINTS_DOWNLOAD_PATH is None:
        resolved = lab.get_data_path() / 'neox_fast' / 'slim_weights'
        if not resolved.exists():
            # The "fast" copy is absent, so use the regular location
            resolved = lab.get_data_path() / 'neox' / 'slim_weights'
        _CHECKPOINTS_DOWNLOAD_PATH = resolved
        # Log the chosen location; this only happens on the first resolution
        inspect(neox_checkpoint_path=_CHECKPOINTS_DOWNLOAD_PATH)

    return _CHECKPOINTS_DOWNLOAD_PATH

Get files to download

    Returns a list of files to be downloaded

def get_files_to_download(n_layers: int = 44):
    """
    ### Get files to download

    Returns a list of files to be downloaded, as paths relative to the checkpoint root.

    * `n_layers` is the number of transformer layers to include, from `0` to `44`.
      The embedding layer and the final normalization / readout layers are always included.

    Raises `ValueError` if `n_layers` is outside `0..44`. Transformer layers are stored at
    checkpoint indices `2 .. 2 + n_layers - 1`, and indices `47` and `48` are the final layers,
    so a larger value would produce a non-existent `layer_46` entry and duplicate the final layers.
    """
    # Transformer layers occupy checkpoint indices 2..45, so at most 44 of them exist
    max_layers = 44
    if not 0 <= n_layers <= max_layers:
        raise ValueError(f'n_layers must be between 0 and {max_layers}, got {n_layers}')

    layers = (
            # Embedding layer
            [0] +
            # Transformer layers
            list(range(2, 2 + n_layers)) +
            # Final normalization layer and readout layer
            [47, 48]
    )

    return (
            # Vocabulary and configs
            ['20B_tokenizer.json', 'configs/20B.yml', 'latest'] +
            # Layer checkpoints; every layer is stored as two partition files (`model_00`, `model_01`)
            [f'global_step150000/layer_{i :02d}-model_{p :02d}-model_states.pt' for i in layers for p in range(2)] +
            # Empty states (not used)
            [f'global_step150000/mp_rank_{i :02d}_model_states.pt' for i in range(8)]
    )

Download all checkpoint files

def download(n_layers: int = 44):
    """
    ## Download all checkpoint files

    * `n_layers` is the number of transformer layers whose checkpoints are fetched;
      it is passed straight to `get_files_to_download`.

    Every file is fetched from `CHECKPOINTS_URL` into the directory returned by
    `get_checkpoints_download_path`, keeping its relative path.
    """
    wanted = get_files_to_download(n_layers)
    total = len(wanted)

    for idx, rel_path in monit.enum('Download All', wanted):
        # Progress line, e.g. "Downloading   3/105: latest"
        progress = f'{idx + 1 :3d}/{total}'
        logger.log(['Downloading ', (progress, Text.meta), ': ', (rel_path, Text.value)])
        # Fetch the remote file into the same relative location locally
        destination = get_checkpoints_download_path() / rel_path
        download_file(CHECKPOINTS_URL + rel_path, destination)

Load a pair of checkpoint files

  • files pair of files to load
  • Returns the loaded parameter tensors

def load_checkpoint_files(files: Tuple[str, str]):
    """
    ### Load a pair of checkpoint files

    * `files` pair of file names, relative to the `global_step150000` checkpoint directory

    Returns a list with the object `torch.load` produced for each file, in the same order
    as `files`. These are presumably the partitions' parameter dictionaries
    (NOTE(review): confirm against callers).
    """
    step_dir = get_checkpoints_download_path() / 'global_step150000'

    with monit.section('Load checkpoint'):
        loaded = []
        for name in files:
            loaded.append(torch.load(step_dir / name))

    return loaded

Load a parameter by merging the partitions along the first dimension

  • param is the parameter
  • key is the name of the parameter
  • p1 first partition dictionary
  • p2 second partition dictionary
def merge_params_dim_0(param: Union[nn.Parameter, torch.Tensor], key: str, p1: Dict[str, torch.Tensor],
                       p2: Dict[str, torch.Tensor]):
    """
    ### Load a parameter by merging the partitions along the first dimension

    * `param` is the parameter; it is filled in place
    * `key` is the name of the parameter
    * `p1` first partition dictionary
    * `p2` second partition dictionary

    The rows of `p1[key]` are written first, followed by the rows of `p2[key]`.
    """
    first, second = p1[key], p2[key]
    # Row index where the second partition starts
    split = first.shape[0]
    target = param.data
    target[:split] = first
    target[split:] = second

Load a parameter by merging the partitions along the second dimension

  • param is the parameter
  • key is the name of the parameter
  • p1 first partition dictionary
  • p2 second partition dictionary
def merge_params_dim_1(param: Union[nn.Parameter, torch.Tensor], key: str, p1: Dict[str, torch.Tensor],
                       p2: Dict[str, torch.Tensor]):
    """
    ### Load a parameter by merging the partitions along the second dimension

    * `param` is the parameter; it is filled in place
    * `key` is the name of the parameter
    * `p1` first partition dictionary
    * `p2` second partition dictionary

    The columns of `p1[key]` are written first, followed by the columns of `p2[key]`.
    """
    first, second = p1[key], p2[key]
    # Column index where the second partition starts
    split = first.shape[1]
    target = param.data
    target[:, :split] = first
    target[:, split:] = second

Load an un-partitioned parameter

This does a sanity check to make sure both partitions are the same

  • param is the parameter
  • key is the name of the parameter
  • p1 first partition dictionary
  • p2 second partition dictionary
def merge_params_duplicate(param: Union[nn.Parameter, torch.Tensor], key: str, p1: Dict[str, torch.Tensor],
                           p2: Dict[str, torch.Tensor]):
    """
    ### Load an un-partitioned parameter

    This does a sanity check to make sure both partitions are the same,
    then loads their average into `param` in place.

    * `param` is the parameter
    * `key` is the name of the parameter
    * `p1` first partition dictionary
    * `p2` second partition dictionary

    Raises `AssertionError` if the sum of squared differences between the partitions is `1e-4` or more.
    """
    w1, w2 = p1[key], p2[key]

    # Reduce over *all* elements with torch. The builtin `sum` only reduces the first
    # dimension, so `.item()` failed for any parameter with rank >= 2. Compute in
    # float32 so half-precision weights don't lose precision in the accumulation.
    diff = (w1.float() - w2.float()).pow(2).sum().item()
    assert diff < 1e-4, f'The partitions do not match: {key}'

    param.data[:] = (w1 + w2) / 2.

Load biases that are partitioned and whose partitions get added together on reduce

  • param is the parameter
  • key is the name of the parameter
  • p1 first partition dictionary
  • p2 second partition dictionary
def merge_params_sum(param: Union[nn.Parameter, torch.Tensor], key: str, p1: Dict[str, torch.Tensor],
                     p2: Dict[str, torch.Tensor]):
    """
    ### Load biases that are partitioned and whose partitions get added together on reduce

    * `param` is the parameter; it is overwritten in place with `p1[key] + p2[key]`
    * `key` is the name of the parameter
    * `p1` first partition dictionary
    * `p2` second partition dictionary
    """
    combined = p1[key] + p2[key]
    param.data[:] = combined

if __name__ == '__main__':
    # Running this module as a script downloads the checkpoint files
    # using `download`'s default of 44 transformer layers
    download(n_layers=44)