Finetune GPT-2 with LoRA

Here's a Colab notebook for fine-tuning GPT-2 with LoRA on the Tiny Shakespeare dataset.

import torch
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset
from transformers import AutoTokenizer, AutoModelForCausalLM

from labml import lab, monit, tracker
from labml.configs import BaseConfigs, option
from labml.utils.download import download_file
from labml_helpers.device import DeviceConfigs
from labml_nn.lora.gpt2 import GPTModel

Trainer configurations and the training loop

The default configs can and will be overridden when we start the experiment.
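
For example, a launch script for this trainer might look roughly like the sketch below, using the Trainer class defined next. It assumes the usual labml experiment API (experiment.create, experiment.configs, experiment.start); the experiment name and the overridden values are illustrative, not taken from this file.

from labml import experiment

def main():
    # Create the experiment (the name is illustrative)
    experiment.create(name="lora_gpt2")
    conf = Trainer()
    # Override some of the default configs declared below (illustrative values)
    experiment.configs(conf, {'epochs': 3, 'lora_r': 8})
    # Build the model, optimizer and data loader, then train
    conf.initialize()
    with experiment.start():
        conf.run()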

class Trainer(BaseConfigs):
    device: torch.device = DeviceConfigs()

GPT-2 configs. These defaults match the 124M-parameter "gpt2" checkpoint from Hugging Face.

    layer_norm_epsilon: float = 1e-05
    d_model: int = 768
    n_layers: int = 12
    n_heads: int = 12
    n_positions: int = 1024
    vocab_size: int = 50257
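
As a quick standalone sanity check (not part of the trainer), these defaults can be compared against the configuration shipped with the checkpoint:

from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("gpt2")
# For the small "gpt2" checkpoint: cfg.n_embd == 768, cfg.n_layer == 12, cfg.n_head == 12,
# cfg.n_positions == 1024, cfg.vocab_size == 50257, cfg.layer_norm_epsilon == 1e-5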

Training configs

    epochs: int = 10
    batch_size: int = 32
    learning_rate: float = 1e-4
    context_len: int = 512

LoRA rank

    lora_r: int = 32
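
The rank controls the size of the low-rank update that LoRA adds to the frozen pre-trained weights. The real layers come from labml_nn.lora.gpt2; the snippet below is only a generic sketch of the idea, and the names, initialization, and scaling choice are illustrative.

import torch
import torch.nn as nn

class LoRALinearSketch(nn.Module):
    """Computes x @ W.T + (alpha / r) * x @ A.T @ B.T, training only A and B."""
    def __init__(self, in_features: int, out_features: int, r: int, alpha: int = 32):
        super().__init__()
        # Frozen pre-trained weight (in practice loaded from the checkpoint)
        self.weight = nn.Parameter(torch.zeros(out_features, in_features), requires_grad=False)
        # Trainable low-rank factors; B starts at zero so the update is initially a no-op
        self.lora_a = nn.Parameter(torch.randn(r, in_features) * 0.01)
        self.lora_b = nn.Parameter(torch.zeros(out_features, r))
        self.scaling = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ self.weight.T + self.scaling * ((x @ self.lora_a.T) @ self.lora_b.T)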

Dataset. The string names a config option; the tiny_shakespeare function defined at the end of this file (marked with @option) builds the actual TensorDataset.

    text: TensorDataset = "tiny_shakespeare"

Huggingface tokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")

GPT-2 model with LoRA layers

    model: GPTModel

Optimizer

    optimizer: torch.optim.Adam

Cross entropy loss

    loss_func = torch.nn.CrossEntropyLoss()

Dataloader

    data_loader: DataLoader

Load pre-trained GPT-2 weights from Hugging Face

    def _load_pretrained_weights(self):

Load the Hugging Face model and get its parameters

        hf_model = AutoModelForCausalLM.from_pretrained("gpt2")
        state_dict = hf_model.state_dict()

Parameter mapping (hf → ours) for the transformer embeddings and the prediction head

        mapping = {
            'transformer.wte.weight': 'token_embedding.weight',
            'transformer.wpe.weight': 'position_embedding.weight',
            'transformer.ln_f.weight': 'final_norm.weight',
            'transformer.ln_f.bias': 'final_norm.bias',
            'lm_head.weight': 'lm_head.weight'
        }

Mapping (hf → ours) of the decoder layers

        for i in range(12):
            mapping[f'transformer.h.{i}.ln_1.weight'] = f'blocks.{i}.attn_norm.weight'
            mapping[f'transformer.h.{i}.ln_1.bias'] = f'blocks.{i}.attn_norm.bias'
            mapping[f'transformer.h.{i}.attn.c_attn.weight'] = f'blocks.{i}.attn.qkv_projection.weight'
            mapping[f'transformer.h.{i}.attn.c_attn.bias'] = f'blocks.{i}.attn.qkv_projection.bias'
            mapping[f'transformer.h.{i}.attn.c_proj.weight'] = f'blocks.{i}.attn.output_projection.weight'
            mapping[f'transformer.h.{i}.attn.c_proj.bias'] = f'blocks.{i}.attn.output_projection.bias'
            mapping[f'transformer.h.{i}.ln_2.weight'] = f'blocks.{i}.ffn_norm.weight'
            mapping[f'transformer.h.{i}.ln_2.bias'] = f'blocks.{i}.ffn_norm.bias'
            mapping[f'transformer.h.{i}.mlp.c_fc.weight'] = f'blocks.{i}.ffn.linear_in.weight'
            mapping[f'transformer.h.{i}.mlp.c_fc.bias'] = f'blocks.{i}.ffn.linear_in.bias'
            mapping[f'transformer.h.{i}.mlp.c_proj.weight'] = f'blocks.{i}.ffn.linear_out.weight'
            mapping[f'transformer.h.{i}.mlp.c_proj.bias'] = f'blocks.{i}.ffn.linear_out.bias'

Copy the parameters over based on the mapping

        new_state_dict = {}
        for old_key, new_key in mapping.items():
            if old_key in state_dict:
                new_state_dict[new_key] = state_dict[old_key]

The Hugging Face GPT-2 implementation uses Conv1D layers, which store their weights as [in_features, out_features]. Since we use standard linear layers, we need to transpose those weights.

        convo_layers = ([f'blocks.{i}.ffn.linear_in.weight' for i in range(12)] +
                        [f'blocks.{i}.ffn.linear_out.weight' for i in range(12)] +
                        [f'blocks.{i}.attn.qkv_projection.weight' for i in range(12)] +
                        [f'blocks.{i}.attn.output_projection.weight' for i in range(12)])

        for layer in convo_layers:
            new_state_dict[layer] = torch.transpose(new_state_dict[layer], 0, 1)
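
As a standalone check (outside the trainer), the transposed layout is easy to see on the checkpoint itself; the [768, 2304] shape below is the fused QKV projection of the small "gpt2" checkpoint.

import torch
from transformers import AutoModelForCausalLM

hf_model = AutoModelForCausalLM.from_pretrained("gpt2")
w = hf_model.state_dict()['transformer.h.0.attn.c_attn.weight']
print(w.shape)                                  # torch.Size([768, 2304]) -- Conv1D stores [in, out]
print(torch.nn.Linear(768, 2304).weight.shape)  # torch.Size([2304, 768]) -- Linear stores [out, in]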

Load our model. We use strict=False because the state dict does not have the LoRA weights.

        missing_keys, unexpected_keys = self.model.load_state_dict(new_state_dict, strict=False)

Make sure that the only missing (not loaded) keys are the LoRA weights, and that there are no unexpected keys.

        assert all('lora' in key for key in missing_keys)
        assert not unexpected_keys

Initialize the model, optimizer and dataloader

    def initialize(self):

Initialize the GPT-2 model

        self.model = GPTModel(
            layer_norm_epsilon=self.layer_norm_epsilon,
            d_model=self.d_model,
            n_layers=self.n_layers,
            n_heads=self.n_heads,
            n_positions=self.n_positions,
            vocab_size=self.vocab_size,
            r=self.lora_r,
        )
        self.model.to(self.device)

Load pre-trained model weights

        self._load_pretrained_weights()

Initialize the optimizer

        self.optimizer = Adam(self.model.parameters(), lr=self.learning_rate)
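
Note that this passes all of the model's parameters to the optimizer. In the usual LoRA setup only the low-rank matrices are trained while the pre-trained weights stay frozen; a sketch of that variation, relying only on the fact (used in the assertion above) that LoRA parameter names contain 'lora', could look like this. It is not what this file does.

        # Freeze everything except the LoRA parameters and optimize only those (sketch of a variation)
        for name, param in self.model.named_parameters():
            param.requires_grad = 'lora' in name
        lora_params = [p for n, p in self.model.named_parameters() if 'lora' in n]
        self.optimizer = Adam(lora_params, lr=self.learning_rate)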

Initialize the data loader

        self.data_loader = DataLoader(self.text, batch_size=self.batch_size, shuffle=True)

Training loop

    def run(self):
        for _ in monit.loop(self.epochs):

inputs has shape [batch_size, seq_len]

            for (inputs,) in monit.iterate('Train', self.data_loader):

Move inputs to device

                inputs = inputs.to(self.device)

Call the model with all but the last token

                logits = self.model(inputs[:, :-1])

Get the cross-entropy loss; the targets are the inputs shifted by one token

                loss = self.loss_func(logits.reshape(-1, logits.shape[-1]), inputs[:, 1:].reshape(-1))
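
A standalone toy illustration of the shift: the model sees every token except the last, and the target for each position is the next token, so both the logits and the targets cover seq_len - 1 positions.

x = torch.tensor([[10, 11, 12, 13]])
model_input, targets = x[:, :-1], x[:, 1:]
# model_input: tensor([[10, 11, 12]]), targets: tensor([[11, 12, 13]])
# logits: [batch_size, seq_len - 1, vocab_size], targets: [batch_size, seq_len - 1]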

Zero out the gradients

                self.optimizer.zero_grad()

Compute gradients

                loss.backward()

Optimize

                self.optimizer.step()

Log the loss

                tracker.save({'loss': loss})
                tracker.add_global_step()

            tracker.new_line()

Tiny Shakespeare dataset

It will download the text file from the URL if it is not already present.

@option(Trainer.text)
def tiny_shakespeare(c: Trainer):
    path = lab.get_data_path() / 'tiny_shakespeare.txt'
    if not path.exists():
        download_file("https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt", path)
    with open(path, 'r', encoding='utf-8') as f:
        text = f.read()

    tokens = c.tokenizer.encode(text)
    num_batches = len(tokens) // (c.batch_size * c.context_len)
    tokens = tokens[:num_batches * c.batch_size * c.context_len]
    input_ids = torch.tensor(tokens).view(-1, c.context_len)
    return TensorDataset(input_ids)
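
The chunking keeps only whole batches of full-length contexts and drops the tail. A quick standalone check of that arithmetic with a made-up token count (the real Tiny Shakespeare token count will differ):

import torch

tokens = list(range(100_000))                              # pretend token stream
batch_size, context_len = 32, 512
num_batches = len(tokens) // (batch_size * context_len)    # 100_000 // 16_384 = 6
tokens = tokens[:num_batches * batch_size * context_len]   # keep 98_304 tokens, drop the tail
input_ids = torch.tensor(tokens).view(-1, context_len)     # shape [192, 512]: 6 batches of 32 samples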