Here's a Colab notebook for fine-tuning GPT-2 with LoRA on the Tiny Shakespeare dataset.
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset
from transformers import AutoTokenizer, AutoModelForCausalLM

from labml import lab, monit, tracker
from labml.configs import BaseConfigs, option
from labml.utils.download import download_file
from labml_helpers.device import DeviceConfigs
from labml_nn.lora.gpt2 import GPTModel
The default configs can and will be overridden when we start the experiment.
class Trainer(BaseConfigs):
    device: torch.device = DeviceConfigs()
GPT-2 configs
    layer_norm_epsilon: float = 1e-05
    d_model: int = 768
    n_layers: int = 12
    n_heads: int = 12
    n_positions: int = 1024
    vocab_size: int = 50257
Training configs
    epochs: int = 10
    batch_size: int = 32
    learning_rate: float = 1e-4
    context_len: int = 512
LoRA rank (a generic sketch of a LoRA linear layer is given after the Trainer class)
    lora_r: int = 32
Dataset
    text: TensorDataset = "tiny_shakespeare"
Hugging Face tokenizer
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
GPT-2 model
    model: GPTModel
Optimizer
    optimizer: torch.optim.Adam
Cross entropy loss
    loss_func = torch.nn.CrossEntropyLoss()
Dataloader
    data_loader: DataLoader
    def _load_pretrained_weights(self):
Load the Hugging Face model and get its parameters
        hf_model = AutoModelForCausalLM.from_pretrained("gpt2")
        state_dict = hf_model.state_dict()
Transformer embedding and prediction layer parameter mapping (hf: ours)
        mapping = {
            'transformer.wte.weight': 'token_embedding.weight',
            'transformer.wpe.weight': 'position_embedding.weight',
            'transformer.ln_f.weight': 'final_norm.weight',
            'transformer.ln_f.bias': 'final_norm.bias',
            'lm_head.weight': 'lm_head.weight'
        }
Mapping (hf: ours) of decoder layers
        for i in range(12):
            mapping[f'transformer.h.{i}.ln_1.weight'] = f'blocks.{i}.attn_norm.weight'
            mapping[f'transformer.h.{i}.ln_1.bias'] = f'blocks.{i}.attn_norm.bias'
            mapping[f'transformer.h.{i}.attn.c_attn.weight'] = f'blocks.{i}.attn.qkv_projection.weight'
            mapping[f'transformer.h.{i}.attn.c_attn.bias'] = f'blocks.{i}.attn.qkv_projection.bias'
            mapping[f'transformer.h.{i}.attn.c_proj.weight'] = f'blocks.{i}.attn.output_projection.weight'
            mapping[f'transformer.h.{i}.attn.c_proj.bias'] = f'blocks.{i}.attn.output_projection.bias'
            mapping[f'transformer.h.{i}.ln_2.weight'] = f'blocks.{i}.ffn_norm.weight'
            mapping[f'transformer.h.{i}.ln_2.bias'] = f'blocks.{i}.ffn_norm.bias'
            mapping[f'transformer.h.{i}.mlp.c_fc.weight'] = f'blocks.{i}.ffn.linear_in.weight'
            mapping[f'transformer.h.{i}.mlp.c_fc.bias'] = f'blocks.{i}.ffn.linear_in.bias'
            mapping[f'transformer.h.{i}.mlp.c_proj.weight'] = f'blocks.{i}.ffn.linear_out.weight'
            mapping[f'transformer.h.{i}.mlp.c_proj.bias'] = f'blocks.{i}.ffn.linear_out.bias'
Copy the parameters over according to the mapping
        new_state_dict = {}
        for old_key, new_key in mapping.items():
            if old_key in state_dict:
                new_state_dict[new_key] = state_dict[old_key]
Hugging Face's GPT-2 uses 1D convolution (Conv1D) layers, which store their weights as [in_features, out_features]. Since we use linear layers, which store weights as [out_features, in_features], those weights need to be transposed.
        convo_layers = ([f'blocks.{i}.ffn.linear_in.weight' for i in range(12)] +
                        [f'blocks.{i}.ffn.linear_out.weight' for i in range(12)] +
                        [f'blocks.{i}.attn.qkv_projection.weight' for i in range(12)] +
                        [f'blocks.{i}.attn.output_projection.weight' for i in range(12)])

        for layer in convo_layers:
            new_state_dict[layer] = torch.transpose(new_state_dict[layer], 0, 1)
Load our model. We use strict=False because the state dict does not contain the LoRA weights
        missing_keys, unexpected_keys = self.model.load_state_dict(new_state_dict, strict=False)
Make sure that the only missing keys are the LoRA weights, and that there are no unexpected keys
        assert all('lora' in key for key in missing_keys)
        assert not unexpected_keys
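To sanity-check the parameter mapping, one can compare the logits of the loaded model against the original Hugging Face model on a short prompt. The sketch below is only an illustration and not part of the trainer; it assumes a loaded GPTModel named model that takes a [batch, seq] tensor of token ids and returns logits of shape [batch, seq, vocab], and that the LoRA update is initialized to zero (the standard LoRA setup), so both models should produce nearly identical outputs.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
hf_model = AutoModelForCausalLM.from_pretrained("gpt2")

# A short prompt as a [1, seq] tensor of token ids
ids = torch.tensor([tokenizer.encode("To be, or not to be")])
with torch.no_grad():
    ours = model(ids)              # assumed shape: [batch, seq, vocab]
    theirs = hf_model(ids).logits  # [batch, seq, vocab]

# With zero-initialized LoRA adapters the two should agree up to numerical noise
assert torch.allclose(ours, theirs, atol=1e-4)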
    def initialize(self):
Initialize the GPT-2 model with LoRA layers
        self.model = GPTModel(
            layer_norm_epsilon=self.layer_norm_epsilon,
            d_model=self.d_model,
            n_layers=self.n_layers,
            n_heads=self.n_heads,
            n_positions=self.n_positions,
            vocab_size=self.vocab_size,
            r=self.lora_r,
        )
        self.model.to(self.device)
Load pre-trained model weights
        self._load_pretrained_weights()
Initialize the optimizer
        self.optimizer = Adam(self.model.parameters(), lr=self.learning_rate)
Initialize the data loader
        self.data_loader = DataLoader(self.text, batch_size=self.batch_size, shuffle=True)
    def run(self):
        for _ in monit.loop(self.epochs):
inputs has shape [batch_size, seq_len]
            for (inputs,) in monit.iterate('Train', self.data_loader):
Move inputs to the device
                inputs = inputs.to(self.device)
Call the model with all but the last token
                logits = self.model(inputs[:, :-1])
Compute the cross entropy loss between the predictions and the targets, which are the inputs shifted by one token
                loss = self.loss_func(logits.reshape(-1, logits.shape[-1]), inputs[:, 1:].reshape(-1))
Zero the gradients
                self.optimizer.zero_grad()
Compute gradients
                loss.backward()
Take an optimizer step
                self.optimizer.step()
Log the loss
                tracker.save({'loss': loss})
                tracker.add_global_step()
            tracker.new_line()
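For intuition about what the lora_r config controls: LoRA freezes the pre-trained weight matrix W and learns a low-rank update BA, where A is r x d_in and B is d_out x r, so only a small number of parameters carries the fine-tuning signal. The sketch below is a generic illustration under those standard LoRA assumptions; it is not the labml_nn.lora.gpt2 implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F

class LoRALinear(nn.Module):
    """Generic LoRA-adapted linear layer: y = x W^T + (alpha / r) * x A^T B^T."""

    def __init__(self, d_in: int, d_out: int, r: int, alpha: float = 1.0):
        super().__init__()
        # Frozen pre-trained weight (loaded from the base model in practice)
        self.weight = nn.Parameter(torch.empty(d_out, d_in), requires_grad=False)
        # Low-rank adapters: A is initialized with small random values, B with zeros,
        # so the adapted layer initially behaves exactly like the frozen one
        self.lora_a = nn.Parameter(torch.randn(r, d_in) * 0.01)
        self.lora_b = nn.Parameter(torch.zeros(d_out, r))
        self.scale = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        frozen = F.linear(x, self.weight)
        update = F.linear(F.linear(x, self.lora_a), self.lora_b)
        return frozen + self.scale * update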
Tiny Shakespeare dataset. It downloads the text file if it is not already present, tokenizes it, and splits it into fixed-length sequences of context_len tokens.
@option(Trainer.text)
def tiny_shakespeare(c: Trainer):
    path = lab.get_data_path() / 'tiny_shakespeare.txt'
    if not path.exists():
        download_file("https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt", path)
    with open(path, 'r', encoding='utf-8') as f:
        text = f.read()

    tokens = c.tokenizer.encode(text)
    num_batches = len(tokens) // (c.batch_size * c.context_len)
    tokens = tokens[:num_batches * c.batch_size * c.context_len]
    input_ids = torch.tensor(tokens).view(-1, c.context_len)
    return TensorDataset(input_ids)
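A minimal way to run the trainer, sketched under the assumption that the usual labml experiment API (experiment.create, experiment.configs, experiment.start) is used; the experiment name is only illustrative.

from labml import experiment


def main():
    # Create an experiment (the name is illustrative)
    experiment.create(name="lora_gpt2")
    conf = Trainer()
    # Resolve config options, e.g. `text` via the tiny_shakespeare option above
    experiment.configs(conf)
    # Build the model, load the pre-trained weights, set up the optimizer and data loader
    conf.initialize()
    # Start the experiment and train
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()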