配备 Zero3 内存优化器的 Finetune GPT-NEO X

该脚本使用零 DP 内存优化功能在多个器件上训练 GPT-NEOX 模型的偏置参数。

14import datetime
15
16import torch
17import torch.distributed
18
19from labml import experiment, monit, tracker
20from labml.configs import option
21from labml.logger import inspect
22from labml_nn.neox.samples.finetune import PipelineParallelTrainerConf

使用 Pi peline Parallel Trainer 配置并将其调整为 Zero3 内存优化器。

27class Configs(PipelineParallelTrainerConf):
28    rank: int
29    world_size: int

设置模型的优化器

请注意,我们从传递分片参数get_trainable_chunk

32@option(Configs.optimizer, 'Zero3Adam')
33def _optimizer(c: Configs):
39    from labml_nn.optimizers.adam_fp16 import AdamFP16
40    return AdamFP16(c.model.get_trainable_chunk(), lr=c.learning_rate)

使用 Zero3 内存优化器创建模型

43@option(Configs.model, 'Zero3')
44def _model(c: Configs):
48    from labml_nn.scaling.zero3 import Zero3Layer, Zero3Sequential

确保精细调谐器设置了可训练的参数

51    _ = c.fine_tuner

将图层包裹起来Zero3Layer

54    modules = []
55    for m in monit.iterate('Zero3', c.layers):
56        modules.append(Zero3Layer(m.to(c.device),
57                                  c.rank, c.world_size, c.device, c.dtype))

创建顺序模型

60    model = Zero3Sequential(modules)

63    return model

在带有 rank 的节点上运行训练rank

66def main(rank: int, world_size: int, init_method: str = 'tcp://localhost:23456'):

初始化 PyTorch 分布式进程组

71    with monit.section('Distributed'):
72        torch.distributed.init_process_group('nccl',
73                                             timeout=datetime.timedelta(seconds=30),
74                                             init_method=init_method,
75                                             rank=rank,
76                                             world_size=world_size)

设置当前设备

79    device = torch.device(f'cuda:{rank}')
80    torch.cuda.set_device(device)

创建实验

83    experiment.create(name='zero3_neox', writers={'screen', 'labml'})
84    experiment.distributed(rank, world_size)

创建配置

87    conf = Configs()

装载配置

90    experiment.configs(conf, {
91        'model': 'Zero3',
92        'optimizer': 'Zero3Adam',
93
94        'device': device,
95        'rank': rank,
96        'world_size': world_size,
97
98        'learning_rate': 3e-4,
99        'max_seq_len': 128,
100        'batch_size': 16,
101    })

开始实验

104    with experiment.start():

初始化模型。在循环之前执行此操作以获得更清晰的日志。

106        _ = conf.model

训练模型

109        for epoch in monit.loop(conf.epochs):
110            conf.train_epoch()
111            tracker.new_line()

115if __name__ == '__main__':

记录计算机配置

117    inspect([torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())])
118    inspect(
119        n_gpus=torch.cuda.device_count(),
120        mpi=torch.distributed.is_mpi_available(),
121        nccl=torch.distributed.is_nccl_available(),
122    )
123
124    n_gpu = torch.cuda.device_count()

为每个 GPU 启动一个进程。如果您使用多台计算机,则需要单独的启动器。

127    torch.multiprocessing.spawn(main, args=(n_gpu,), nprocs=n_gpu, join=True)