import datetime

import torch
import torch.distributed

from labml import experiment, monit, tracker
from labml.configs import option
from labml.logger import inspect
from labml_nn.neox.samples.finetune import PipelineParallelTrainerConf
Use the Pipeline Parallel Trainer configurations and adapt them to use the Zero3 memory optimizer.
class Configs(PipelineParallelTrainerConf):
    rank: int
    world_size: int
Set the optimizer. This uses the AdamFP16 optimizer from labml_nn on the trainable parameter chunks of the model.
@option(Configs.optimizer, 'Zero3Adam')
def _optimizer(c: Configs):
    from labml_nn.optimizers.adam_fp16 import AdamFP16
    return AdamFP16(c.model.get_trainable_chunk(), lr=c.learning_rate)
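The @option decorator registers the function above as the calculator for Configs.optimizer under the name 'Zero3Adam'; experiment.configs later selects it by that string and evaluates it lazily. Purely to illustrate the registration pattern, a hypothetical second choice could be added like this (PlainAdam is not part of the original, and a plain optimizer would not understand the Zero3 parameter sharding):

@option(Configs.optimizer, 'PlainAdam')
def _plain_adam(c: Configs):
    # Hypothetical alternative, shown only to illustrate how named options are registered
    import torch.optim
    return torch.optim.Adam(c.model.parameters(), lr=c.learning_rate)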
Create the model with the Zero3 memory optimizer.
@option(Configs.model, 'Zero3')
def _model(c: Configs):
    from labml_nn.scaling.zero3 import Zero3Layer, Zero3Sequential
Make sure the fine-tuner has set the trainable parameters.
    _ = c.fine_tuner
Wrap the layers with Zero3Layer.
    modules = []
    for m in monit.iterate('Zero3', c.layers):
        modules.append(Zero3Layer(m.to(c.device),
                                  c.rank, c.world_size, c.device, c.dtype))
Create the sequential model.
    model = Zero3Sequential(modules)

    return model
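For context, a minimal sketch (not the labml_nn implementation) of the idea behind Zero3Layer: each rank keeps only a 1/world_size shard of a layer's flattened parameters, and the full tensor is rebuilt with an all-gather just before it is needed, then freed again.

def gather_full_param(shard: torch.Tensor, world_size: int) -> torch.Tensor:
    # Each rank holds `shard`, one slice of the flattened parameter.
    # All-gather the slices and concatenate them into the full parameter.
    parts = [torch.empty_like(shard) for _ in range(world_size)]
    torch.distributed.all_gather(parts, shard)
    return torch.cat(parts)

# Roughly, the wrapped layer materializes its parameters for the duration of
# the forward/backward pass and drops the full copy afterwards:
#   full = gather_full_param(my_shard, world_size)
#   ... run the layer with `full` ...
#   del full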
Run the training on the node with the given rank.
def main(rank: int, world_size: int, init_method: str = 'tcp://localhost:23456'):
Initialize the PyTorch distributed process group.
    with monit.section('Distributed'):
        torch.distributed.init_process_group('nccl',
                                             timeout=datetime.timedelta(seconds=30),
                                             init_method=init_method,
                                             rank=rank,
                                             world_size=world_size)
Set the current device.
    device = torch.device(f'cuda:{rank}')
    torch.cuda.set_device(device)
Create the experiment.
    experiment.create(name='zero3_neox', writers={'screen', 'labml'},
                      distributed_world_size=world_size,
                      distributed_rank=rank)
Create the configurations.
    conf = Configs()
Load the configurations.
    experiment.configs(conf, {
        'model': 'Zero3',
        'optimizer': 'Zero3Adam',

        'device': device,
        'rank': rank,
        'world_size': world_size,

        'learning_rate': 3e-4,
        'max_seq_len': 128,
        'batch_size': 16,
    })
Start the experiment.
    with experiment.start():
Initialize the model. Doing this before the training loop gives cleaner logs.
        _ = conf.model
Train the model.
        for epoch in monit.loop(conf.epochs):
            conf.train_epoch()
            tracker.new_line()
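A side note, not part of this example: main initializes the process group with an explicit tcp://localhost address. An equivalent approach is PyTorch's standard env:// rendezvous, where the address and ranks come from environment variables (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE), as exported for instance by torchrun. A minimal sketch, assuming the launcher has exported those variables for every process:

import os

torch.distributed.init_process_group('nccl',
                                     init_method='env://',
                                     rank=int(os.environ['RANK']),
                                     world_size=int(os.environ['WORLD_SIZE']))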
if __name__ == '__main__':
Log the machine configuration.
    inspect([torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())])
    inspect(
        n_gpus=torch.cuda.device_count(),
        mpi=torch.distributed.is_mpi_available(),
        nccl=torch.distributed.is_nccl_available(),
    )

    n_gpu = torch.cuda.device_count()
Start a process for each GPU. You will need a separate launcher if you are training across multiple machines.
    torch.multiprocessing.spawn(main, args=(n_gpu,), nprocs=n_gpu, join=True)
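If you do train across multiple machines, a separate launcher has to assign global ranks. A rough sketch of such a launcher, assuming NODE_RANK, N_NODES and MASTER_ADDR are provided by your scheduler (hypothetical names), and that the device selection in main is changed to use the local rank instead of the global rank:

def launch_node(local_rank: int, node_rank: int, n_nodes: int, gpus_per_node: int, master_addr: str):
    # Global rank across all machines and the total number of processes
    global_rank = node_rank * gpus_per_node + local_rank
    world_size = n_nodes * gpus_per_node
    # Rendezvous at the master node instead of localhost
    main(global_rank, world_size, init_method=f'tcp://{master_addr}:23456')

# Each machine would then run something like:
#   node_rank = int(os.environ['NODE_RANK'])   # hypothetical: set by the scheduler
#   n_nodes = int(os.environ['N_NODES'])       # hypothetical: total number of machines
#   master_addr = os.environ['MASTER_ADDR']    # address of the rank-0 machine
#   gpus_per_node = torch.cuda.device_count()
#   torch.multiprocessing.spawn(launch_node,
#                               args=(node_rank, n_nodes, gpus_per_node, master_addr),
#                               nprocs=gpus_per_node, join=True)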