Fine Tune GPT-NeoX

This shows how to fine tune GPT-NeoX, training only the biases, with Fairscale pipeline parallelism across multiple GPUs.
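Pipeline parallelism splits the stack of layers across GPUs and feeds micro-batches through the stages. Here is a toy sketch of the idea with Fairscale's Pipe; the module names, sizes and two-GPU split are made up for illustration and are not part of this script.

import torch
import torch.nn as nn
import fairscale

# Three modules split over two GPUs: the first two stay on cuda:0, the last on cuda:1
toy = nn.Sequential(nn.Linear(256, 1024), nn.ReLU(), nn.Linear(1024, 256))
pipe = fairscale.nn.Pipe(toy,
                         balance=[2, 1],
                         devices=[torch.device('cuda:0'), torch.device('cuda:1')],
                         chunks=4)  # each batch is split into 4 micro-batches
# Input goes to the first device; the output comes back on the last device
out = pipe(torch.randn(32, 256, device='cuda:0'))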

import fairscale
import torch
import torch.nn as nn
import torch.utils.data
import typing
from torch.utils.data import DataLoader, RandomSampler

from labml import experiment, monit, tracker, lab
from labml.configs import option
from labml.logger import inspect
from labml_nn.neox.utils.text_dataset import get_training_data
from labml_nn.neox.utils.finetune import FineTuneBiases
from labml_nn.neox.model import LayerGenerator, NeoXModule
from labml_nn.neox.utils import balance_layers_simple
from labml_nn.neox.utils.trainer import PipelineParallelTrainerConf

Load GPT-NeoX layers

@option(PipelineParallelTrainerConf.layers, 'PipelineBiases')
def neox_layers(c: PipelineParallelTrainerConf):
    return list(LayerGenerator(is_clone_layers=c.is_clone_layers,
                               filter_layers=c.filter_layers,
                               dtype=c.dtype,
                               ).load())
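As a quick, illustrative check (not part of the script), the generator can be loaded directly to see how many layers and parameters it produces; the argument values below are assumptions and the counts depend on the checkpoint and the filter_layers setting.

layers = list(LayerGenerator(is_clone_layers=True,
                             filter_layers=None,
                             dtype=torch.float16,
                             ).load())
inspect(n_layers=len(layers),
        n_params=sum(p.numel() for layer in layers for p in layer.parameters()))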

Create fine tuner for biases

@option(PipelineParallelTrainerConf.fine_tuner, 'PipelineBiases')
def fine_tune_biases(c: PipelineParallelTrainerConf):
    fine_tuner = FineTuneBiases(typing.cast(typing.List[NeoXModule], c.layers))

Mark biases as trainable

    fine_tuner.set_trainable_params()

    return fine_tuner
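The essential idea is bias-only fine-tuning (as in BitFit): freeze all weights and leave only the bias terms trainable. A minimal sketch of that idea, not the actual FineTuneBiases implementation:

def mark_only_biases_trainable(layers: typing.List[nn.Module]):
    # Freeze everything, then re-enable gradients only for parameters named '*bias'
    for layer in layers:
        for name, param in layer.named_parameters():
            param.requires_grad = name.endswith('bias')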

Create pipeline parallel model

@option(PipelineParallelTrainerConf.model, 'PipelineBiases')
def pipe_model(c: PipelineParallelTrainerConf):
    if c.is_checkpointing:
        raise NotImplementedError()
    else:
        layers = c.layers

Make sure the fine tuner is initialized, so the biases are marked as trainable before the Pipe model is built

    _ = c.fine_tuner

Create the Pipe module

    with monit.section('Pipe'):

Get the layer distribution across GPUs

        balance = balance_layers_simple(len(layers), c.n_gpus)
        inspect(balance=balance)

Devices for each GPU

        devices = [torch.device(f'cuda:{i}') for i in range(c.n_gpus)]

Create Fairscale Pipe module

        pipe_model = fairscale.nn.Pipe(nn.Sequential(*layers),
                                       balance=balance,
                                       devices=devices,
                                       chunks=c.chunks)

    return pipe_model
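balance_layers_simple returns how many layers each GPU should hold. A plausible even split looks like the following; this is an assumption for illustration, and the actual helper may distribute remainders differently.

def even_balance(n_layers: int, n_gpus: int):
    base, extra = divmod(n_layers, n_gpus)
    # The first `extra` GPUs take one additional layer each
    return [base + 1 if i < extra else base for i in range(n_gpus)]

# e.g. even_balance(47, 4) == [12, 12, 12, 11]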

Tiny Shakespeare dataset

@option(PipelineParallelTrainerConf.train_loader)
def tiny_shakespeare(c: PipelineParallelTrainerConf):
    dataset = get_training_data(c.max_seq_len)

    return DataLoader(dataset,
                      batch_size=c.batch_size,
                      sampler=RandomSampler(dataset, replacement=True))
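As a sanity check (assuming get_training_data returns pairs of input and target token ids, each of length max_seq_len), a batch can be pulled from the loader directly:

dataset = get_training_data(128)
loader = DataLoader(dataset, batch_size=4, sampler=RandomSampler(dataset, replacement=True))
inputs, targets = next(iter(loader))
inspect(inputs=inputs.shape, targets=targets.shape)  # expected: [4, 128] each

Sampling with replacement means batches are drawn independently at random rather than as one shuffled pass over the dataset.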
def main():

Create experiment

    experiment.create(name='pipe_neox_biases',
                      writers={'screen', 'web_api'})

Initialize configs

    conf = PipelineParallelTrainerConf()
    experiment.configs(conf, {
        'learning_rate': 3e-4,
        'is_checkpointing': False,
        'max_seq_len': 128,
        'batch_size': 64,
        'chunks': 8,
    })
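Note that with batch_size = 64 and chunks = 8, each pipeline micro-batch carries 64 / 8 = 8 sequences of max_seq_len = 128 tokens.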

Start the experiment

    with experiment.start():

Initialize the model. Do this before the loop for cleaner logs.

        _ = conf.model

Train

        for epoch in monit.loop(conf.epochs):
            conf.train_epoch()
            tracker.new_line()
            torch.save(conf.fine_tuner.state_dict(), str(lab.get_data_path() / 'fine_tune.pt'))

if __name__ == '__main__':
    main()
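The saved biases can later be restored into a freshly built configuration before evaluation or further training. This is a sketch only; it assumes the fine tuner exposes a load_state_dict counterpart to the state_dict() used above.

state = torch.load(str(lab.get_data_path() / 'fine_tune.pt'))
conf.fine_tuner.load_state_dict(state)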