import fairscale
import torch
import torch.nn as nn
import torch.utils.data
import typing
from torch.utils.data import DataLoader, RandomSampler

from labml import experiment, monit, tracker, lab
from labml.configs import option
from labml.logger import inspect
from labml_nn.neox.utils.text_dataset import get_training_data
from labml_nn.neox.utils.finetune import FineTuneBiases
from labml_nn.neox.model import LayerGenerator, NeoXModule
from labml_nn.neox.utils import balance_layers_simple
from labml_nn.neox.utils.trainer import PipelineParallelTrainerConf


@option(PipelineParallelTrainerConf.layers, 'PipelineBiases')
def neox_layers(c: PipelineParallelTrainerConf):
    """
    ### Load GPT-NeoX layers
    """
    return list(LayerGenerator(is_clone_layers=c.is_clone_layers,
                               filter_layers=c.filter_layers,
                               dtype=c.dtype,
                               ).load())
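
# `LayerGenerator` loads the pre-trained GPT-NeoX layers one at a time.
# A hypothetical sketch of loading only a few layers for a quick smoke test
# (assuming `filter_layers` takes the indices of the layers to keep, matching
# how it is forwarded above):
#
#     layers = list(LayerGenerator(is_clone_layers=True,
#                                  filter_layers=set(range(4)),
#                                  dtype=torch.float16).load())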


@option(PipelineParallelTrainerConf.fine_tuner, 'PipelineBiases')
def fine_tune_biases(c: PipelineParallelTrainerConf):
    """
    ### Create the fine-tuner that trains only the biases
    """
    fine_tuner = FineTuneBiases(typing.cast(typing.List[NeoXModule], c.layers))
    # Mark the biases as trainable
    fine_tuner.set_trainable_params()

    return fine_tuner
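
# Bias-only fine-tuning keeps the large weight matrices frozen and updates
# only the bias vectors (in the spirit of BitFit). Conceptually,
# `set_trainable_params` amounts to something like this hypothetical sketch
# (not the actual labml implementation):
#
#     for layer in layers:
#         for name, param in layer.named_parameters():
#             param.requires_grad = name.endswith('bias')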


@option(PipelineParallelTrainerConf.model, 'PipelineBiases')
def pipe_model(c: PipelineParallelTrainerConf):
    """
    ### Create the pipeline parallel model
    """
    if c.is_checkpointing:
        raise NotImplementedError()
    else:
        layers = c.layers

    # Make sure the fine-tuner is initialized
    _ = c.fine_tuner

    # Create the Pipe module
    with monit.section('Pipe'):
        # Get the layer distribution across the GPUs
        balance = balance_layers_simple(len(layers), c.n_gpus)
        inspect(balance=balance)
        # The device for each GPU
        devices = [torch.device(f'cuda:{i}') for i in range(c.n_gpus)]
        # Create the fairscale `Pipe` module
        pipe_model = fairscale.nn.Pipe(nn.Sequential(*layers),
                                       balance=balance,
                                       devices=devices,
                                       chunks=c.chunks)

    return pipe_model
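
# For example, splitting 48 layers across 4 GPUs would give a balance of
# `[12, 12, 12, 12]`, while 50 layers could yield something like
# `[13, 13, 12, 12]` (the exact split is up to `balance_layers_simple`).
# `chunks` is the number of micro-batches `Pipe` cuts each mini-batch into,
# so the GPUs can work on different micro-batches concurrently instead of
# idling while earlier pipeline stages finish.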


@option(PipelineParallelTrainerConf.train_loader)
def tiny_shakespeare(c: PipelineParallelTrainerConf):
    """
    ### Tiny Shakespeare dataset
    """
    dataset = get_training_data(c.max_seq_len)

    return DataLoader(dataset,
                      batch_size=c.batch_size,
                      sampler=RandomSampler(dataset, replacement=True))
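
# `RandomSampler(dataset, replacement=True)` draws `len(dataset)` samples per
# epoch with replacement, so an "epoch" here is a fixed budget of random
# sequences rather than one exact pass over the data. A quick sanity check
# could look like this hypothetical snippet (assuming the dataset yields
# `(input, target)` pairs of token ids):
#
#     x, y = next(iter(conf.train_loader))  # shapes: [batch_size, max_seq_len]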


def main():
    # Create the experiment
    experiment.create(name='pipe_neox_biases',
                      writers={'screen', 'web_api'})

    # Initialize the configurations
    conf = PipelineParallelTrainerConf()
    experiment.configs(conf, {
        'learning_rate': 3e-4,
        'is_checkpointing': False,
        'max_seq_len': 128,
        'batch_size': 64,
        'chunks': 8,
    })

    # Start the experiment
    with experiment.start():
        # Initialize the model. Do this before the training loop for cleaner logs.
        _ = conf.model

        # Train
        for epoch in monit.loop(conf.epochs):
            conf.train_epoch()
            tracker.new_line()
            # Save the fine-tuned biases after each epoch
            torch.save(conf.fine_tuner.state_dict(), str(lab.get_data_path() / 'fine_tune.pt'))
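
# The saved biases can later be restored into a freshly loaded model.
# A hypothetical sketch (assuming `FineTuneBiases` also exposes a matching
# `load_state_dict`, mirroring the `state_dict` call above):
#
#     fine_tuner = FineTuneBiases(typing.cast(typing.List[NeoXModule], layers))
#     fine_tuner.load_state_dict(torch.load(str(lab.get_data_path() / 'fine_tune.pt')))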


if __name__ == '__main__':
    main()