Here's a Colab notebook for fine-tuning a pre-trained GPT-2 model with LoRA on the Tiny Shakespeare dataset.
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset
from transformers import AutoTokenizer, AutoModelForCausalLM

from labml import lab, monit, tracker
from labml.configs import BaseConfigs, option
from labml.utils.download import download_file
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.lora.gpt2 import GPTModel

The default configs can and will be overridden when we start the experiment.

class Trainer(BaseConfigs):
    device: torch.device = DeviceConfigs()

GPT-2 configs

    layer_norm_epsilon: float = 1e-05
    d_model: int = 768
    n_layers: int = 12
    n_heads: int = 12
    n_positions: int = 1024
    vocab_size: int = 50257

Training configs

    epochs: int = 10
    batch_size: int = 32
    learning_rate: float = 1e-4
    context_len: int = 512

LoRA rank (a short sketch of what this controls follows the configs)

    lora_r: int = 32

Dataset

    text: TensorDataset = "tiny_shakespeare"

Huggingface tokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    model: GPTModel

Optimizer

    optimizer: torch.optim.Adam

Cross entropy loss

    loss_func = torch.nn.CrossEntropyLoss()

Dataloader

    data_loader: DataLoader
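As an aside, here is a minimal sketch of what the rank controls in a LoRA layer. This is not the labml_nn.lora.gpt2 implementation, just the standard low-rank update y = x W^T + (alpha / r) x A^T B^T with the pre-trained weight frozen:

import torch
import torch.nn as nn


class LoRALinearSketch(nn.Module):
    # Illustrative only: a frozen weight plus a trainable low-rank update.
    def __init__(self, n_in: int, n_out: int, r: int, alpha: int = 32):
        super().__init__()
        # Stand-in for the pre-trained weight, kept frozen during fine-tuning
        self.weight = nn.Parameter(torch.randn(n_out, n_in), requires_grad=False)
        # Low-rank factors; B starts at zero so the update is initially a no-op
        self.lora_a = nn.Parameter(torch.randn(r, n_in) * 0.01)
        self.lora_b = nn.Parameter(torch.zeros(n_out, r))
        self.scaling = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # y = x W^T + (alpha / r) * x A^T B^T
        return x @ self.weight.T + self.scaling * (x @ self.lora_a.T @ self.lora_b.T)

The rank r (lora_r above) sets the width of the low-rank factors. The point to note here is simply that the pre-trained GPT-2 checkpoint has no such 'lora' parameters, which is why _load_pretrained_weights below calls load_state_dict with strict=False.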
    def _load_pretrained_weights(self):

Load the huggingface model and get the parameters.

        hf_model = AutoModelForCausalLM.from_pretrained("gpt2")
        state_dict = hf_model.state_dict()

Transformer embedding and prediction layer parameter mapping (hf: ours)

        mapping = {
            'transformer.wte.weight': 'token_embedding.weight',
            'transformer.wpe.weight': 'position_embedding.weight',
            'transformer.ln_f.weight': 'final_norm.weight',
            'transformer.ln_f.bias': 'final_norm.bias',
            'lm_head.weight': 'lm_head.weight'
        }

Mapping (hf: ours) of decoder layers

        for i in range(12):
            mapping[f'transformer.h.{i}.ln_1.weight'] = f'blocks.{i}.attn_norm.weight'
            mapping[f'transformer.h.{i}.ln_1.bias'] = f'blocks.{i}.attn_norm.bias'
            mapping[f'transformer.h.{i}.attn.c_attn.weight'] = f'blocks.{i}.attn.qkv_projection.weight'
            mapping[f'transformer.h.{i}.attn.c_attn.bias'] = f'blocks.{i}.attn.qkv_projection.bias'
            mapping[f'transformer.h.{i}.attn.c_proj.weight'] = f'blocks.{i}.attn.output_projection.weight'
            mapping[f'transformer.h.{i}.attn.c_proj.bias'] = f'blocks.{i}.attn.output_projection.bias'
            mapping[f'transformer.h.{i}.ln_2.weight'] = f'blocks.{i}.ffn_norm.weight'
            mapping[f'transformer.h.{i}.ln_2.bias'] = f'blocks.{i}.ffn_norm.bias'
            mapping[f'transformer.h.{i}.mlp.c_fc.weight'] = f'blocks.{i}.ffn.linear_in.weight'
            mapping[f'transformer.h.{i}.mlp.c_fc.bias'] = f'blocks.{i}.ffn.linear_in.bias'
            mapping[f'transformer.h.{i}.mlp.c_proj.weight'] = f'blocks.{i}.ffn.linear_out.weight'
            mapping[f'transformer.h.{i}.mlp.c_proj.bias'] = f'blocks.{i}.ffn.linear_out.bias'

Move the parameters based on the mapping.

        new_state_dict = {}
        for old_key, new_key in mapping.items():
            if old_key in state_dict:
                new_state_dict[new_key] = state_dict[old_key]

Hugging Face GPT-2 uses 1D convolution layers, so we need to transpose those weights since we use linear layers (see the short check after this method).

        convo_layers = ([f'blocks.{i}.ffn.linear_in.weight' for i in range(12)] +
                        [f'blocks.{i}.ffn.linear_out.weight' for i in range(12)] +
                        [f'blocks.{i}.attn.qkv_projection.weight' for i in range(12)] +
                        [f'blocks.{i}.attn.output_projection.weight' for i in range(12)])

        for layer in convo_layers:
            new_state_dict[layer] = torch.transpose(new_state_dict[layer], 0, 1)

Load our model. We use strict = False because the state dict does not have the LoRA weights.

        missing_keys, unexpected_keys = self.model.load_state_dict(new_state_dict, strict=False)

Make sure that only the LoRA weights are missing.

        assert all('lora' in key for key in missing_keys)
        assert not unexpected_keys
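To see why the transpose is needed: Hugging Face's GPT-2 implements these projections with its Conv1D module, which stores its weight as [n_in, n_out], while torch.nn.Linear stores [n_out, n_in]. A quick check, illustrative only and not part of the trainer (shapes are for GPT-2 small):

from transformers import AutoModelForCausalLM

hf_model = AutoModelForCausalLM.from_pretrained("gpt2")

# The fused q, k, v projection maps 768 -> 3 * 768, and Conv1D stores it as [n_in, n_out]
print(hf_model.transformer.h[0].attn.c_attn.weight.shape)  # torch.Size([768, 2304])

# An equivalent torch.nn.Linear would store its weight as [n_out, n_in], i.e. [2304, 768],
# which is exactly what the transpose above produces.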
    def initialize(self):

Initialize the GPT-2 model.

        self.model = GPTModel(
            layer_norm_epsilon=self.layer_norm_epsilon,
            d_model=self.d_model,
            n_layers=self.n_layers,
            n_heads=self.n_heads,
            n_positions=self.n_positions,
            vocab_size=self.vocab_size,
            r=self.lora_r,
        )
        self.model.to(self.device)

Load pre-trained model weights.

        self._load_pretrained_weights()

Initialize the optimizer.

        self.optimizer = Adam(self.model.parameters(), lr=self.learning_rate)

Initialize the data loader.

        self.data_loader = DataLoader(self.text, batch_size=self.batch_size, shuffle=True)

    def run(self):
        for _ in monit.loop(self.epochs):

inputs has shape [batch_size, seq_len].

            for (inputs,) in monit.iterate('Train', self.data_loader):

Move inputs to the device.

                inputs = inputs.to(self.device)

Call the model with all but the last token.

                logits = self.model(inputs[:, :-1])

Get the cross entropy loss, comparing each logit with the next token in inputs (a tiny worked example of this shift follows the training loop).

                loss = self.loss_func(logits.reshape(-1, logits.shape[-1]), inputs[:, 1:].reshape(-1))

Zero the gradients.

                self.optimizer.zero_grad()

Compute gradients.

                loss.backward()

Optimize.

                self.optimizer.step()

Log the loss.

                tracker.save({'loss': loss})
                tracker.add_global_step()

            tracker.new_line()
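To make the shift explicit, here is a tiny worked example with made-up token ids:

import torch

# One sequence of five token ids (values are made up)
inputs = torch.tensor([[10, 11, 12, 13, 14]])

model_input = inputs[:, :-1]  # [[10, 11, 12, 13]]: the model predicts the next token at each position
targets = inputs[:, 1:]       # [[11, 12, 13, 14]]: the token that actually follows each input position

So the logits at position t are scored against the token at position t + 1, which is the standard next-token prediction objective.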
Tiny Shakespeare dataset: download the text, tokenize it, and split it into sequences of context_len tokens.

@option(Trainer.text)
def tiny_shakespeare(c: Trainer):
    path = lab.get_data_path() / 'tiny_shakespeare.txt'
    if not path.exists():
        download_file("https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt", path)
    with open(path, 'r', encoding='utf-8') as f:
        text = f.read()

    tokens = c.tokenizer.encode(text)
    num_batches = len(tokens) // (c.batch_size * c.context_len)
    tokens = tokens[:num_batches * c.batch_size * c.context_len]
    input_ids = torch.tensor(tokens).view(-1, c.context_len)
    return TensorDataset(input_ids)
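Finally, a minimal sketch of how one might run this trainer with labml's experiment API. The experiment name and the config overrides are illustrative and not part of the listing above:

from labml import experiment


def main():
    # Create an experiment (the name is arbitrary)
    experiment.create(name="lora_gpt2")
    conf = Trainer()
    # Override any of the default configs here if needed, e.g. fewer epochs or a smaller LoRA rank
    experiment.configs(conf, {'epochs': 3, 'lora_r': 8})
    # Build the model, load the pre-trained weights, and set up the optimizer and data loader
    conf.initialize()
    # Start the experiment and run the training loop
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()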