Transformer Auto-Regression Experiment with the Sophia-G Optimizer

This trains a simple transformer from Attention Is All You Need on an auto-regressive language modeling task (on the Tiny Shakespeare dataset) with the Sophia-G optimizer.
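Sophia-G keeps an exponential moving average of the gradients and of a diagonal Hessian estimate, and applies a clipped, pre-conditioned update. As a rough sketch of the rule it applies (following the Sophia paper; the exact treatment of batch size and $\epsilon$ is an implementation detail of labml_nn.optimizers.sophia):

$$m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t$$

$$h_t = \beta_2 h_{t-k} + (1 - \beta_2) \hat{h}_t \quad \text{(every } k \text{ steps)}$$

$$\theta_{t+1} = \theta_t - \eta \, \operatorname{clip}\big(m_t / \max(\rho \, h_t, \epsilon), -1, 1\big)$$

where $g_t$ is the mini-batch gradient, $\hat{h}_t$ is the Hessian diagonal estimate (Gauss-Newton-Bartlett in this experiment), and the element-wise clipping bounds the magnitude of each coordinate's step by the learning rate.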

import torch

from labml import experiment, tracker
from labml_helpers.train_valid import BatchIndex
from labml_nn.optimizers.sophia import Sophia
from labml_nn.transformers.basic.autoregressive_experiment import Configs as TransformerAutoRegressionConfigs

Configurations

This inherits from the basic transformer auto-regressive experiment configurations (TransformerAutoRegressionConfigs)

class Configs(TransformerAutoRegressionConfigs):

Number of training steps between GNB Hessian diagonal estimates

    hess_interval: int = 10

Sophia-G optimizer

    optimizer: Sophia

Training or validation step with Gauss-Newton-Bartlett (GNB) Hessian diagonal estimator
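As a quick reference, the GNB estimator (again a sketch following the Sophia paper): sample labels $\hat{y}$ from the model's own output distribution, take the cross-entropy loss $\hat{\ell}(\theta)$ of the logits against those sampled labels, and estimate the Hessian diagonal as

$$\hat{h} = B \, \nabla_\theta \hat{\ell}(\theta) \odot \nabla_\theta \hat{\ell}(\theta)$$

where $B$ is the number of tokens in the mini-batch. This is what the Hessian-estimation branch below computes, passing the scaling factor $B$ (data.numel()) to update_hessian.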

    def step(self, batch: any, batch_idx: BatchIndex):

Set training/eval mode

        self.model.train(self.mode.is_train)

Move data to the device

        data, target = batch[0].to(self.device), batch[1].to(self.device)

Estimate the Hessian diagonal every hess_interval training steps. On these steps only the Hessian estimate is updated, and the regular parameter update is skipped.

        if isinstance(self.optimizer, Sophia) and self.mode.is_train and batch_idx.idx % self.hess_interval == 0:

Get model outputs

            output, *_ = self.model(data)

Create a categorical distribution from logits

            samp_dist = torch.distributions.Categorical(logits=output)

Sample labels from this distribution

            y_sample = samp_dist.sample()

Calculate the loss against the sampled labels and log it

            loss = self.loss_func(output, y_sample)
            tracker.add("loss.hess.", loss)

Calculate gradients

            loss.backward()

Clip gradients

            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)

Update EMA Hessian diagonal

            self.optimizer.update_hessian(data.numel())

Clear the gradients

            self.optimizer.zero_grad()
        else:


Update global step (number of tokens processed) when in training mode

            if self.mode.is_train:
                tracker.add_global_step(data.shape[0] * data.shape[1])

Whether to capture model outputs

            with self.mode.update(is_log_activations=batch_idx.is_last and self.is_log_model_activations):

Get model outputs. The model returns a tuple of the output and state when using RNNs; that is not implemented yet. 😜

                output, *_ = self.model(data)

Calculate and log loss

            loss = self.loss_func(output, target)
            tracker.add("loss.", loss)

Calculate and log accuracy

            self.accuracy(output, target)
            self.accuracy.track()

            self.other_metrics(output, target)

Train the model

            if self.mode.is_train:

Calculate gradients

                loss.backward()

Clip gradients

                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)

Take optimizer step

                self.optimizer.step()

Log the model parameters and gradients on last batch of every epoch

                if batch_idx.is_last and self.is_log_model_params_grads:
                    tracker.add('model', self.model)

Clear the gradients

                self.optimizer.zero_grad()

Save the tracked metrics

            tracker.save()
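For reference, here is the Hessian-estimation branch above written as a stand-alone sketch against a generic language model whose forward pass returns logits of shape [batch_size, seq_len, n_vocab]; the function name and model interface here are illustrative, not part of labml_nn.

import torch
import torch.nn.functional as F

def gnb_hessian_pass(model: torch.nn.Module, data: torch.Tensor, optimizer):
    # Forward pass to get logits at every position
    logits = model(data)
    # Sample labels from the model's own predictive distribution
    y_sample = torch.distributions.Categorical(logits=logits).sample()
    # Cross-entropy of the logits against the sampled labels
    loss = F.cross_entropy(logits.reshape(-1, logits.shape[-1]), y_sample.reshape(-1))
    # Back-propagate to get gradients of this loss
    loss.backward()
    # Sophia refreshes its EMA Hessian-diagonal estimate from these gradients,
    # scaled by the number of tokens in the batch
    optimizer.update_hessian(data.numel())
    optimizer.zero_grad()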
def main():

Create experiment

    experiment.create(name="transformer")

Create configs

    conf = Configs()

Override configurations

    experiment.configs(conf, {

Use a character-level tokenizer

        'tokenizer': 'character',

Prompt separator is blank

        'prompt_separator': '',

Starting prompt for sampling

        'prompt': 'It is ',

Use the Tiny Shakespeare dataset

        'text': 'tiny_shakespeare',

Use a context size of 512

        'seq_len': 512,

Train for 32 epochs

        'epochs': 32,

Batch size

        'batch_size': 16,

Switch between training and validation 10 times per epoch

        'inner_iterations': 10,

Model size

        'd_model': 256,
        'transformer.n_heads': 16,
        'transformer.ffn.d_ff': 1024,

Use the Sophia-G optimizer and set its hyperparameters

        'optimizer.optimizer': 'Sophia',
        'optimizer.learning_rate': 3e-4,
        'optimizer.rho': 0.03,
    })

Set models for saving and loading

    experiment.add_pytorch_models({'model': conf.model})

Start the experiment

    with experiment.start():

Run training

        conf.run()

if __name__ == '__main__':
    main()