This is a tutorial/implementation of the OpenAI GPT architecture in PyTorch. We got a bunch of implementation details from minGPT by @karpathy. This implementation also uses the character-level Tiny Shakespeare dataset.
The GPT model is essentially a standard transformer with a few tweaks. GPT-2, and especially GPT-3, models are quite large and won't fit on a single GPU, so they need model parallelism. This implementation doesn't even use data parallelism and is intended to be more of a tutorial.
The main differences from a simple autoregressive transformer are the parameter initialization, weight decay, and learning rate schedule. For the transformer itself we reuse the existing labml/nn transformer implementation.
Here's a notebook for training a GPT model on the Tiny Shakespeare dataset.
import torch
from torch import nn

from labml import experiment
from labml.configs import option
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.optimizers.configs import OptimizerConfigs
from labml_nn.transformers import TransformerConfigs, Encoder
from labml_nn.transformers.utils import subsequent_mask
This consists of a token embedding layer, a transformer encoder, and a final linear layer that gives the token logits.
class GPT(nn.Module):
encoder is the transformer Encoder
src_embed is the token embedding module (with positional encodings)
generator is the final fully connected layer that gives the logits
    def __init__(self, encoder: Encoder, src_embed: nn.Module, generator: nn.Module):
        super().__init__()
        self.src_embed = src_embed
        self.encoder = encoder
        self.generator = generator
The mask will be initialized on the first call
        self.mask = None
    def forward(self, x: torch.Tensor):
Create the subsequent mask if it is not initialized or if its size has changed
        if self.mask is None or self.mask.size(0) != len(x):
The subsequent mask will mask out tokens from seeing future tokens
            self.mask = subsequent_mask(len(x)).to(x.device)
Get the token embeddings with positional encodings
        x = self.src_embed(x)
Transformer encoder
        x = self.encoder(x, self.mask)
Get logits
        x = self.generator(x)
Return results (the second value is for state, since our trainer is also used with RNNs)
        return x, None
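As an aside, the causal mask built by subsequent_mask is essentially a lower-triangular matrix: position i can only attend to positions j <= i. Below is a minimal sketch of such a mask; the exact shape and dtype returned by the labml_nn helper may differ.

import torch

seq_len = 5
# Entry [i, j] is True when token i is allowed to attend to token j (i.e. j <= i)
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
print(causal_mask[1])  # tensor([ True,  True, False, False, False])
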
class Configs(NLPAutoRegressionConfigs):
GPT model
    model: GPT
Transformer
    transformer: TransformerConfigs
Weight decay
    weight_decay: float = 0.1
Number of tokens for warmup
    warmup_steps: int = 128 * 128 * 20
Custom optimizer
    optimizer = 'transformer_optimizer'
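Note that warmup_steps above counts tokens, not optimizer updates; transformer_optimizer below divides it by batch_size * seq_len. With the values used in main() (batch size 128, sequence length 128) the arithmetic works out to 20 warmup updates:

batch_size, seq_len = 128, 128   # defaults used in main() below
warmup_tokens = 128 * 128 * 20   # the warmup_steps default above

tokens_per_update = batch_size * seq_len    # 16,384 tokens per optimizer step
print(warmup_tokens // tokens_per_update)   # 20 warmup optimizer steps
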
@option(Configs.transformer, 'GPT')
def _transformer_configs(c: Configs):
We use our configurable transformer implementation
    conf = TransformerConfigs()
Set the vocabulary sizes for embeddings and generating logits
    conf.n_src_vocab = c.n_tokens
    conf.n_tgt_vocab = c.n_tokens
GPT uses GELU activation for the position-wise feedforward network
    conf.ffn.activation = 'GELU'

    return conf
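For illustration only, a position-wise feed-forward block with GELU looks roughly like the following. This is a hypothetical sketch, not the labml_nn FeedForward module that the config actually builds.

import torch
from torch import nn

class FeedForwardSketch(nn.Module):
    # Two linear layers with a GELU in between, applied independently at each position
    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.layer1 = nn.Linear(d_model, d_ff)
        self.activation = nn.GELU()   # GPT uses GELU here instead of ReLU
        self.layer2 = nn.Linear(d_ff, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layer2(self.activation(self.layer1(x)))
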
Weights of linear layers and embedding layers are initialized to N(0, 0.02) instead of the default Xavier initialization.
def _init_weights(module):
    if not isinstance(module, (nn.Linear, nn.Embedding)):
        return

    module.weight.data.normal_(mean=0.0, std=0.02)
Initialize biases to 0
    if isinstance(module, nn.Linear) and module.bias is not None:
        module.bias.data.zero_()
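A quick, hypothetical sanity check (not part of the original code): apply _init_weights to a toy module and inspect what it touches.

from torch import nn

toy = nn.Sequential(nn.Embedding(100, 32), nn.Linear(32, 32), nn.LayerNorm(32))
toy.apply(_init_weights)

# Linear and Embedding weights are re-initialized to roughly N(0, 0.02);
# Linear biases are zeroed; LayerNorm parameters are left at their defaults
print(float(toy[1].weight.std()))      # ~0.02
print(float(toy[1].bias.abs().max()))  # 0.0
print(float(toy[2].weight.mean()))     # 1.0 (untouched)
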
Create GPT model and initialize weights
@option(Configs.model)
def _model(c: Configs):
    m = GPT(c.transformer.encoder,
            c.transformer.src_embed,
            c.transformer.generator).to(c.device)
Apply custom weight initialization
    m.apply(_init_weights)

    return m
This code is taken from minGPT. This applies weight decay only to weights of linear layers.
@option(NLPAutoRegressionConfigs.optimizer)
def transformer_optimizer(c: NLPAutoRegressionConfigs):
Collect names of parameters to apply weight decay to
    decay = set()
    for mn, m in c.model.named_modules():
        for pn, p in m.named_parameters():
            fpn = f'{mn}.{pn}' if mn else pn  # full param name

            if fpn.endswith('weight') and isinstance(m, nn.Linear):
                decay.add(fpn)
Get all the parameters
    param_dict = {pn: p for pn, p in c.model.named_parameters()}
Parameters that are not decayed
    no_decay = set(param_dict.keys()) - decay
Create the PyTorch parameter groups for the optimizer
    opt_groups = [
        {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": c.weight_decay},
        {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
    ]
Create a configurable optimizer, so that we can change these simply by passing a config dictionary.
    optimizer = OptimizerConfigs()
Set parameter groups for optimization.
    optimizer.parameters = opt_groups
Use the cosine decay optimizer. This is what GPT uses.
    optimizer.optimizer = 'AdamWarmupCosineDecay'
Set the model embedding size, required if we use the Noam optimizer, which scales the learning rate based on the embedding size.
    optimizer.d_model = c.d_model
Set default weight decay. This is not required since we set the weight decay in the parameter groups.
    optimizer.weight_decay = c.weight_decay
GPT uses a maximum learning rate of 6e-4.
    optimizer.learning_rate = 6e-4
    optimizer.betas = (0.9, 0.95)
    optimizer.eps = 1e-8
Weight decay is decoupled from the gradients
    optimizer.weight_decouple = True
Total number of optimization steps for the learning rate cosine decay
    optimizer.total_steps = c.epochs * len(c.text.train) // (c.batch_size * c.seq_len)
Number of warmup optimization steps
    optimizer.warmup = c.warmup_steps // (c.batch_size * c.seq_len)

    return optimizer
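To make the grouping above concrete, here is a small, hypothetical walk-through of the same selection logic on a toy model (not part of the original code): only linear-layer weights receive weight decay, while embeddings, biases, and layer-norm parameters do not.

from torch import nn

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(100, 16)
        self.linear = nn.Linear(16, 16)
        self.norm = nn.LayerNorm(16)

toy = ToyModel()

# Same selection rule as in transformer_optimizer above
decay = set()
for mn, m in toy.named_modules():
    for pn, p in m.named_parameters():
        fpn = f'{mn}.{pn}' if mn else pn
        if fpn.endswith('weight') and isinstance(m, nn.Linear):
            decay.add(fpn)

no_decay = set(dict(toy.named_parameters()).keys()) - decay
print(sorted(decay))     # ['linear.weight']
print(sorted(no_decay))  # ['emb.weight', 'linear.bias', 'norm.bias', 'norm.weight']
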
def main():
Create experiment
    experiment.create(name="gpt")
Create configs
    conf = Configs()
Override configurations
    experiment.configs(conf, {
Use character level tokenizer
        'tokenizer': 'character',
Prompt separator is blank
        'prompt_separator': '',
Starting prompt for sampling
        'prompt': 'It is ',
Use Tiny Shakespeare dataset
        'text': 'tiny_shakespeare',
Use a context size of 128
        'seq_len': 128,
Train for 32 epochs
        'epochs': 32,
Batch size
        'batch_size': 128,
Switch between training and validation 10 times per epoch
        'inner_iterations': 10,
Transformer configurations
        'transformer.d_model': 512,
        'transformer.ffn.d_ff': 2048,
        'transformer.n_heads': 8,
        'transformer.n_layers': 6
    })
Set models for saving and loading
    experiment.add_pytorch_models({'model': conf.model})
Start the experiment
    with experiment.start():
Run training
        conf.run()


if __name__ == '__main__':
    main()