GPT

This is a tutorial/implementation of the OpenAI GPT architecture in PyTorch. We took a number of implementation details from minGPT by @karpathy. This implementation also uses the character-level Tiny Shakespeare dataset.

The GPT model is essentially a standard transformer with a few tweaks. GPT-2, and especially GPT-3, models are too large to fit on a single GPU and need model parallelism. This implementation doesn't even use data parallelism; it is intended more as a tutorial.

The main differences from a simple autoregressive transformer are the parameter initialization, weight decay, and learning rate schedule. For the transformer we reuse the existing labml/nn transformer implementation.

Here's a notebook for training a GPT model on the Tiny Shakespeare dataset.


import torch
from torch import nn

from labml import experiment
from labml.configs import option
from labml_helpers.module import Module
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.optimizers.configs import OptimizerConfigs
from labml_nn.transformers import TransformerConfigs, Encoder
from labml_nn.transformers.utils import subsequent_mask

GPT model

This consists of a token embedding layer, transformer encoder, and a final linear layer that gives token logits.

class GPT(Module):
    def __init__(self, encoder: Encoder, src_embed: Module, generator: Module):
        super().__init__()
        self.src_embed = src_embed
        self.encoder = encoder
        self.generator = generator

The mask will be initialized on the first call

        self.mask = None

    def forward(self, x: torch.Tensor):

Create the subsequent mask if it has not been initialized yet or if its size differs from the current sequence length

        if self.mask is None or self.mask.size(0) != len(x):

Subsequent mask, which prevents tokens from attending to future tokens

            self.mask = subsequent_mask(len(x)).to(x.device)

Get the token embeddings with positional encodings

        x = self.src_embed(x)

Transformer encoder

        x = self.encoder(x, self.mask)

Get logits

        x = self.generator(x)

Return the results (the second value is for the state, since our trainer is also used with RNNs)

        return x, None
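To see how the three pieces fit together without the labml dependencies, here is a minimal sketch built only from torch.nn. The class name TinyGPT, the learned positional embedding, and the use of nn.TransformerEncoder in place of the labml Encoder are all assumptions made for illustration. Inputs are sequence-first ([seq_len, batch_size]), which matches the len(x) check above. Note that torch.nn marks positions that may not be attended to with True, so the mask in this sketch is True strictly above the diagonal.

```python
import torch
from torch import nn


class TinyGPT(nn.Module):
    """Hypothetical torch.nn-only stand-in for the GPT class above."""

    def __init__(self, n_tokens: int, d_model: int = 64, n_heads: int = 4,
                 n_layers: int = 2, max_len: int = 128):
        super().__init__()
        # Token embedding plus a learned positional embedding (assumed here)
        self.tok_emb = nn.Embedding(n_tokens, d_model)
        self.pos_emb = nn.Parameter(torch.zeros(max_len, 1, d_model))
        # Stack of self-attention + GELU feed-forward layers (sequence-first inputs)
        layer = nn.TransformerEncoderLayer(d_model, n_heads,
                                           dim_feedforward=4 * d_model,
                                           activation='gelu')
        self.encoder = nn.TransformerEncoder(layer, n_layers)
        # Final linear layer producing token logits
        self.generator = nn.Linear(d_model, n_tokens)
        # The mask is created lazily on the first call, as in GPT above
        self.mask = None

    def forward(self, x: torch.Tensor):
        # x: [seq_len, batch_size] token ids
        seq_len = x.shape[0]
        if self.mask is None or self.mask.size(0) != seq_len:
            # True strictly above the diagonal: position i cannot see j > i.
            # torch.nn treats True as "not allowed to attend".
            self.mask = torch.triu(
                torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device),
                diagonal=1)
        h = self.tok_emb(x) + self.pos_emb[:seq_len]
        h = self.encoder(h, mask=self.mask)
        # Second value mirrors the (logits, state) return convention above
        return self.generator(h), None


model = TinyGPT(n_tokens=65)
model.eval()  # disable dropout
tokens = torch.randint(0, 65, (16, 3))
logits, _ = model(tokens)  # logits: [16, 3, 65]
```

Because the mask hides future positions, changing the tokens at positions 10 and later leaves the logits at positions 0 to 9 unchanged (checked in eval mode, so dropout does not interfere).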

Configurations

This inherits from NLPAutoRegressionConfigs

class Configs(NLPAutoRegressionConfigs):

GPT model

    model: GPT

Transformer

    transformer: TransformerConfigs

Weight decay

    weight_decay: float = 0.1

Number of tokens for warmup

    warmup_steps: int = 128 * 128 * 20

Custom optimizer

    optimizer = 'transformer_optimizer'
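The warmup_steps value is measured in tokens. Each optimizer step consumes batch_size × seq_len tokens, so with the settings used in main (batch size 128, context size 128) the warmup lasts 128 × 128 × 20 / (128 × 128) = 20 steps. The sketch below does that conversion and shows the general shape of a linear-warmup, cosine-decay schedule. The function names are hypothetical, and the exact formula inside labml's AdamWarmupCosineDecay (for example, whether it decays all the way to zero) is an assumption here.

```python
import math


def warmup_steps_from_tokens(warmup_tokens: int, batch_size: int, seq_len: int) -> int:
    # Each optimizer step consumes batch_size * seq_len tokens
    return warmup_tokens // (batch_size * seq_len)


def lr_at(step: int, max_lr: float, warmup: int, total_steps: int) -> float:
    """Assumed schedule: linear warmup to max_lr, then cosine decay to 0."""
    if step < warmup:
        return max_lr * step / warmup
    # Fraction of the post-warmup phase completed, clamped to [0, 1]
    progress = min(1.0, (step - warmup) / max(1, total_steps - warmup))
    return max_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


warmup = warmup_steps_from_tokens(128 * 128 * 20, batch_size=128, seq_len=128)  # 20
peak = lr_at(warmup, 6e-4, warmup, total_steps=1000)  # 6e-4
```

The total of 1000 steps is only an illustrative value; the real total comes from the dataset size and number of epochs.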

Transformer configurations

@option(Configs.transformer, 'GPT')
def _transformer_configs(c: Configs):
    conf = TransformerConfigs()

Set the vocabulary sizes for embeddings and generating logits

    conf.n_src_vocab = c.n_tokens
    conf.n_tgt_vocab = c.n_tokens

GPT uses GELU activation for the position-wise feed-forward network

    conf.ffn.activation = 'GELU'

    return conf
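A position-wise feed-forward network applies two linear layers, with an activation in between, to every position independently; GPT uses GELU as that activation. The module below is a hypothetical stand-alone sketch, not the labml FeedForward class; the sizes in the usage line match transformer.d_model and transformer.ffn.d_ff from main.

```python
import torch
from torch import nn


class PositionWiseFFN(nn.Module):
    """Hypothetical stand-alone FFN: Linear -> GELU -> Dropout -> Linear."""

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.layer1 = nn.Linear(d_model, d_ff)    # expand to the hidden size
        self.activation = nn.GELU()               # GELU instead of ReLU, as in GPT
        self.dropout = nn.Dropout(dropout)
        self.layer2 = nn.Linear(d_ff, d_model)    # project back to d_model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [seq_len, batch_size, d_model]; each position is transformed independently
        return self.layer2(self.dropout(self.activation(self.layer1(x))))


ffn = PositionWiseFFN(d_model=512, d_ff=2048)
ffn.eval()  # disable dropout for a deterministic forward pass
```

Unlike ReLU, GELU is smooth and lets small negative values through for negative inputs.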

Initialize weights

Weights of linear layers and embedding layers are initialized to $\mathcal{N}(0, 0.02)$ instead of the default Xavier initialization.

def _init_weights(module):
    if not isinstance(module, (nn.Linear, nn.Embedding)):
        return

    module.weight.data.normal_(mean=0.0, std=0.02)

Initialize biases to 0

    if isinstance(module, nn.Linear) and module.bias is not None:
        module.bias.data.zero_()
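Since Module.apply visits every submodule recursively, the initializer only has to handle one module at a time and can ignore everything that is not a linear or embedding layer. The sketch below uses a hypothetical init_gpt_style helper that follows the same rule as _init_weights, written with torch.nn.init instead of .data, on a small toy model; the LayerNorm keeps its default weights of one.

```python
import torch
from torch import nn


def init_gpt_style(module: nn.Module) -> None:
    """Hypothetical helper following the same rule as _init_weights above."""
    # Leave every other module type (e.g. LayerNorm) with its default init
    if not isinstance(module, (nn.Linear, nn.Embedding)):
        return
    # Weights of Linear and Embedding layers ~ N(0, 0.02)
    nn.init.normal_(module.weight, mean=0.0, std=0.02)
    # Linear biases start at zero
    if isinstance(module, nn.Linear) and module.bias is not None:
        nn.init.zeros_(module.bias)


torch.manual_seed(0)
model = nn.Sequential(
    nn.Embedding(1000, 256),
    nn.Linear(256, 1024),
    nn.LayerNorm(1024),
    nn.Linear(1024, 1000),
)
# apply() calls init_gpt_style on every submodule, and on model itself
model.apply(init_gpt_style)
```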

Create GPT model and initialize weights

@option(Configs.model)
def _model(c: Configs):
    m = GPT(c.transformer.encoder,
            c.transformer.src_embed,
            c.transformer.generator).to(c.device)

Apply custom weight initialization

    m.apply(_init_weights)

    return m

Create custom optimizer with weight decay

This code is taken from minGPT. This applies weight decay only to weights of linear layers.

@option(NLPAutoRegressionConfigs.optimizer)
def transformer_optimizer(c: NLPAutoRegressionConfigs):

Collect names of parameters to apply weight decay

    decay = set()
    for mn, m in c.model.named_modules():
        for pn, p in m.named_parameters():
            fpn = f'{mn}.{pn}' if mn else pn  # full param name

            if fpn.endswith('weight') and isinstance(m, nn.Linear):
                decay.add(fpn)

Get all the parameters

    param_dict = {pn: p for pn, p in c.model.named_parameters()}

Parameters that are not decayed

    no_decay = set(param_dict.keys()) - decay

Create the parameter groups for the PyTorch optimizer; only the first group has a non-zero weight decay

    opt_groups = [
        {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": c.weight_decay},
        {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
    ]

Create a configurable optimizer, so that we can change these simply by passing a config dictionary.

    optimizer = OptimizerConfigs()

Set parameter groups for optimization.

    optimizer.parameters = opt_groups

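Use the cosine decay optimizer (Adam with learning-rate warmup followed by cosine decay). This is the schedule GPT uses.

    optimizer.optimizer = 'AdamWarmupCosineDecay'

Set the model embedding size. This is only required for the Noam optimizer, which scales the learning rate by $d_{model}^{-0.5}$ and decays it with the inverse square root of the step number.

    optimizer.d_model = c.d_model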

Set default weight decay. This is not required since we set the weight decay in the parameter groups.

    optimizer.weight_decay = c.weight_decay

GPT uses a maximum learning rate of $6 \times 10^{-4}$.

    optimizer.learning_rate = 6e-4

    optimizer.betas = (0.9, 0.95)

    optimizer.eps = 1e-8

Weight decay is decoupled from gradients

    optimizer.weight_decouple = True

Total number of optimization steps for learning rate cosine decay

    optimizer.total_steps = c.epochs * len(c.text.train) // (c.batch_size * c.seq_len)

Number of warmup optimization steps

    optimizer.warmup = c.warmup_steps // (c.batch_size * c.seq_len)

    return optimizer
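The decay/no-decay split can be checked on a toy model with plain PyTorch. This sketch (make_param_groups is a hypothetical name) applies the same rule as transformer_optimizer, except that it iterates only each module's direct parameters with recurse=False, and hands the groups to torch.optim.AdamW, which implements decoupled weight decay in the spirit of weight_decouple=True. It does not reproduce labml's warmup and cosine schedule.

```python
import torch
from torch import nn


def make_param_groups(model: nn.Module, weight_decay: float):
    """Hypothetical helper: decay only nn.Linear weights, as transformer_optimizer does."""
    decay = set()
    for mn, m in model.named_modules():
        # recurse=False: look only at parameters owned directly by this module
        for pn, _ in m.named_parameters(recurse=False):
            fpn = f'{mn}.{pn}' if mn else pn  # full parameter name
            if pn == 'weight' and isinstance(m, nn.Linear):
                decay.add(fpn)
    params = dict(model.named_parameters())
    no_decay = set(params) - decay  # biases, embeddings, LayerNorm parameters
    return [
        {"params": [params[n] for n in sorted(decay)], "weight_decay": weight_decay},
        {"params": [params[n] for n in sorted(no_decay)], "weight_decay": 0.0},
    ]


model = nn.Sequential(nn.Embedding(65, 32), nn.LayerNorm(32), nn.Linear(32, 65))
groups = make_param_groups(model, weight_decay=0.1)
# AdamW applies decoupled weight decay, using each group's own weight_decay
opt = torch.optim.AdamW(groups, lr=6e-4, betas=(0.9, 0.95), eps=1e-8)
```

Here only the linear layer's weight is decayed; the embedding, the LayerNorm parameters, and the linear bias all land in the zero-decay group.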

def main():

Create experiment

    experiment.create(name="gpt")

Create configs

    conf = Configs()

Override configurations

    experiment.configs(conf, {

Use character level tokenizer

        'tokenizer': 'character',

Prompt separator is blank

        'prompt_separator': '',

Starting prompt for sampling

        'prompt': 'It is ',

Use Tiny Shakespeare dataset

        'text': 'tiny_shakespeare',

Use a context size of 128

        'seq_len': 128,

Train for 32 epochs

        'epochs': 32,

Batch size

        'batch_size': 128,

Switch between training and validation 10 times per epoch

        'inner_iterations': 10,

Transformer configurations

        'transformer.d_model': 512,
        'transformer.ffn.d_ff': 2048,
        'transformer.n_heads': 8,
        'transformer.n_layers': 6
    })

Set models for saving and loading

    experiment.add_pytorch_models({'model': conf.model})

Start the experiment

    with experiment.start():

Run training

        conf.run()


if __name__ == '__main__':
    main()
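With a character-level tokenizer each distinct character is one token, so under that assumption the vocabulary size (c.n_tokens above) is the number of distinct characters in the text. A minimal sketch of such a tokenizer follows; build_char_vocab, encode and decode are hypothetical helpers, not the labml implementation, and the sample string is only illustrative.

```python
def build_char_vocab(text: str):
    """Hypothetical helper: one token id per distinct character."""
    chars = sorted(set(text))
    stoi = {ch: i for i, ch in enumerate(chars)}  # character -> id
    itos = {i: ch for ch, i in stoi.items()}      # id -> character
    return stoi, itos


def encode(text: str, stoi: dict) -> list:
    return [stoi[ch] for ch in text]


def decode(ids: list, itos: dict) -> str:
    return ''.join(itos[i] for i in ids)


sample = "It is the east, and Juliet is the sun."
stoi, itos = build_char_vocab(sample)
n_tokens = len(stoi)                 # vocabulary size
prompt_ids = encode("It is ", stoi)  # the sampling prompt from main, as token ids
```

Encoding and then decoding returns the original string exactly, which is why a blank prompt separator is enough at this granularity.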