GPT

This is a tutorial/implementation of the OpenAI GPT architecture in PyTorch. Many implementation details are borrowed from minGPT by @karpathy. This implementation trains on the character-level Tiny Shakespeare dataset.

The GPT model is essentially a standard transformer with a few tweaks. GPT-2, and especially GPT-3, models are so large that they won't fit on a single GPU and need model parallelism. This implementation doesn't even use data parallelism; it is intended as a tutorial.

The main differences from a simple autoregressive transformer are the parameter initialization, weight decay, and learning rate schedule. For the transformer itself we reuse the existing labml/nn transformer implementation.

Here's a notebook for training a GPT model on the Tiny Shakespeare dataset.


import torch
from torch import nn

from labml import experiment
from labml.configs import option
from labml_helpers.module import Module
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.optimizers.configs import OptimizerConfigs
from labml_nn.transformers import TransformerConfigs, Encoder
from labml_nn.transformers.utils import subsequent_mask

GPT model

This consists of a token embedding layer, transformer encoder, and a final linear layer that gives token logits.

class GPT(Module):
    def __init__(self, encoder: Encoder, src_embed: Module, generator: Module):
        super().__init__()
        self.src_embed = src_embed
        self.encoder = encoder
        self.generator = generator

The mask will be initialized on the first call

        self.mask = None

    def forward(self, x: torch.Tensor):

Create the subsequent mask if it is not initialized, or if its size differs from the current sequence length

        if self.mask is None or self.mask.size(0) != len(x):

Subsequent mask, which prevents tokens from attending to future tokens

            self.mask = subsequent_mask(len(x)).to(x.device)

Get the token embeddings with positional encodings

        x = self.src_embed(x)

Transformer encoder

        x = self.encoder(x, self.mask)

Get logits

        x = self.generator(x)

Return results (second value is for state, since our trainer is used with RNNs also)

        return x, None
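For intuition, the subsequent mask is equivalent to a lower-triangular boolean matrix where position i can only attend to positions up to i. Below is a minimal sketch using torch.tril; the exact tensor shape returned by subsequent_mask may differ.

import torch

seq_len = 5
# Position i may attend to positions 0..i only.
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
print(causal_mask)
# tensor([[ True, False, False, False, False],
#         [ True,  True, False, False, False],
#         [ True,  True,  True, False, False],
#         [ True,  True,  True,  True, False],
#         [ True,  True,  True,  True,  True]])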

Configurations

This inherits from NLPAutoRegressionConfigs

class Configs(NLPAutoRegressionConfigs):

GPT model

    model: GPT

Transformer

    transformer: TransformerConfigs

Weight decay

    weight_decay: float = 0.1

Number of tokens for warmup (see the worked conversion to optimizer steps after this class)

    warmup_steps: int = 128 * 128 * 20

Custom optimizer

    optimizer = 'transformer_optimizer'
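As noted above, warmup_steps is measured in tokens and is converted to optimizer steps in transformer_optimizer further down. Here is a small worked calculation, assuming the batch size and sequence length that main() sets later (128 each):

warmup_tokens = 128 * 128 * 20                              # 327,680 tokens
tokens_per_step = 128 * 128                                 # batch_size * seq_len
warmup_optimizer_steps = warmup_tokens // tokens_per_step   # 20 optimizer steps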

Transformer configurations

@option(Configs.transformer, 'GPT')
def _transformer_configs(c: Configs):
    conf = TransformerConfigs()

Set the vocabulary sizes for embeddings and generating logits

    conf.n_src_vocab = c.n_tokens
    conf.n_tgt_vocab = c.n_tokens

GPT uses GELU activation for the position-wise feedforward layer

    conf.ffn.activation = 'GELU'

    return conf
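For reference, this is roughly what the position-wise feedforward block computes once the activation is switched to GELU. It is a plain-PyTorch sketch, not the labml FeedForward module that TransformerConfigs builds (that one has more options); d_model and d_ff default to the values set later in main().

import torch
from torch import nn

class GELUFeedForward(nn.Module):
    # Two linear layers with a GELU in between, applied independently at every position.
    def __init__(self, d_model: int = 512, d_ff: int = 2048, dropout: float = 0.1):
        super().__init__()
        self.layer1 = nn.Linear(d_model, d_ff)
        self.layer2 = nn.Linear(d_ff, d_model)
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layer2(self.dropout(self.activation(self.layer1(x))))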

Initialize weights

Weights of linear layers and embedding layers are initialized from a normal distribution with mean 0 and standard deviation 0.02, instead of the default Xavier initialization.

def _init_weights(module):
    if not isinstance(module, (nn.Linear, nn.Embedding)):
        return

    module.weight.data.normal_(mean=0.0, std=0.02)

Initialize biases to zero

    if isinstance(module, nn.Linear) and module.bias is not None:
        module.bias.data.zero_()
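A quick, illustrative check of what this initialization does to a single linear layer:

layer = nn.Linear(512, 512)
_init_weights(layer)
print(layer.weight.std())        # close to 0.02
print(layer.bias.abs().max())    # exactly 0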

Create GPT model and initialize weights

@option(Configs.model)
def _model(c: Configs):
    m = GPT(c.transformer.encoder,
            c.transformer.src_embed,
            c.transformer.generator).to(c.device)

Apply custom weight initialization

    m.apply(_init_weights)

    return m

Create custom optimizer with weight decay

This code is taken from minGPT. It applies weight decay only to the weights of linear layers.

@option(NLPAutoRegressionConfigs.optimizer)
def transformer_optimizer(c: NLPAutoRegressionConfigs):

Collect names of parameters to apply weight decay

    decay = set()
    for mn, m in c.model.named_modules():
        for pn, p in m.named_parameters():
            fpn = f'{mn}.{pn}' if mn else pn  # full param name

            if fpn.endswith('weight') and isinstance(m, nn.Linear):
                decay.add(fpn)

Get all the parameters

    param_dict = {pn: p for pn, p in c.model.named_parameters()}

Parameters that are not decayed

    no_decay = set(param_dict.keys()) - decay

Create the PyTorch optimizer parameter groups

    opt_groups = [
        {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": c.weight_decay},
        {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
    ]

Create a configurable optimizer, so that we can change these simply by passing a config dictionary.

    optimizer = OptimizerConfigs()

Set parameter groups for optimization.

    optimizer.parameters = opt_groups

Use cosine decay optimizer. This is what GPT uses.

    optimizer.optimizer = 'AdamWarmupCosineDecay'

Set the model embedding size. This is only needed if the Noam optimizer, whose learning-rate schedule depends on the embedding size, is used instead.

    optimizer.d_model = c.d_model

Set default weight decay. This is not required since we set the weight decay in the parameter groups.

    optimizer.weight_decay = c.weight_decay

GPT uses a maximum learning rate of 6e-4.

    optimizer.learning_rate = 6e-4

    optimizer.betas = (0.9, 0.95)

    optimizer.eps = 1e-8

Weight decay is decoupled from gradients

    optimizer.weight_decouple = True

Total number of optimization steps for learning rate cosine decay

    optimizer.total_steps = c.epochs * len(c.text.train) // (c.batch_size * c.seq_len)

Number of warmup optimization steps

    optimizer.warmup = c.warmup_steps // (c.batch_size * c.seq_len)

    return optimizer
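For comparison, the same decay/no-decay split could be handed to a stock torch.optim.AdamW instead of labml's configurable optimizer. This sketch uses an arbitrary stand-in model and skips the warmup and cosine decay that AdamWarmupCosineDecay adds.

import torch
from torch import nn

# Illustrative stand-in model with arbitrary sizes.
model = nn.Sequential(nn.Embedding(100, 512), nn.Linear(512, 100))

# Same idea as above: apply weight decay only to linear-layer weights.
decay, no_decay = [], []
for m in model.modules():
    for name, p in m.named_parameters(recurse=False):
        if isinstance(m, nn.Linear) and name == 'weight':
            decay.append(p)
        else:
            no_decay.append(p)

optimizer = torch.optim.AdamW(
    [{'params': decay, 'weight_decay': 0.1},
     {'params': no_decay, 'weight_decay': 0.0}],
    lr=6e-4, betas=(0.9, 0.95), eps=1e-8)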
def main():

Create experiment

    experiment.create(name="gpt")

Create configs

    conf = Configs()

Override configurations

    experiment.configs(conf, {

Use character level tokenizer

        'tokenizer': 'character',

Prompt separator is blank

        'prompt_separator': '',

Starting prompt for sampling

        'prompt': 'It is ',

Use Tiny Shakespeare dataset

        'text': 'tiny_shakespeare',

Use a context size of 128

        'seq_len': 128,

Train for 32 epochs

        'epochs': 32,

Batch size

        'batch_size': 128,

Switch between training and validation 10 times per epoch

        'inner_iterations': 10,

Transformer configurations

        'transformer.d_model': 512,
        'transformer.ffn.d_ff': 2048,
        'transformer.n_heads': 8,
        'transformer.n_layers': 6
    })

Set models for saving and loading

    experiment.add_pytorch_models({'model': conf.model})

Start the experiment

    with experiment.start():

Run training

        conf.run()

if __name__ == '__main__':
    main()