NLP model trainer for classification

from collections import Counter
from typing import Callable

import torch
import torchtext
from torch import nn
from torch.utils.data import DataLoader
import torchtext.vocab
from torchtext.vocab import Vocab

from labml import lab, tracker, monit
from labml.configs import option
from labml_helpers.device import DeviceConfigs
from labml_helpers.metrics.accuracy import Accuracy
from labml_helpers.module import Module
from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
from labml_nn.optimizers.configs import OptimizerConfigs

Trainer configurations

This has the basic configurations for NLP classification task training. All the properties are configurable.
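
All of these can be overridden through a configurations dictionary when starting an experiment, as is done for the tokenizer later in this file. A minimal sketch, assuming the labml experiment API (a concrete experiment must still provide a model option; see the sketch at the end of this file):

from labml import experiment

conf = NLPClassificationConfigs()  # the configs class defined below
experiment.configs(conf, {
    'tokenizer': 'basic_english',  # switch from the default character tokenizer
    'batch_size': 32,
    'seq_len': 256,
})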

class NLPClassificationConfigs(TrainValidConfigs):

Optimizer

    optimizer: torch.optim.Adam

Training device

    device: torch.device = DeviceConfigs()

Classification model

    model: Module

Batch size

    batch_size: int = 16

Length of the sequence, or context size

    seq_len: int = 512

Vocabulary

    vocab: Vocab = 'ag_news'

Number of tokens in the vocabulary

    n_tokens: int

Number of classes

    n_classes: int = 'ag_news'

Tokenizer

    tokenizer: Callable = 'character'

Whether to periodically save models

    is_save_models = True

Loss function

    loss_func = nn.CrossEntropyLoss()

Accuracy function

    accuracy = Accuracy()

Model embedding size

    d_model: int = 512

Gradient clipping

    grad_norm_clip: float = 1.0

Training data loader

    train_loader: DataLoader = 'ag_news'

Validation data loader

    valid_loader: DataLoader = 'ag_news'

Whether to log model parameters and gradients (once per epoch). These are summarized stats per layer, but they can still produce many indicators for very deep networks.

    is_log_model_params_grads: bool = False

Whether to log model activations (once per epoch). These are summarized stats per layer, but they can still produce many indicators for very deep networks.

    is_log_model_activations: bool = False

Initialization

    def init(self):

Set tracker configurations

        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)

Add a hook to log module outputs

        hook_model_outputs(self.mode, self.model, 'model')

Add accuracy as a state module. The name is probably confusing, since it's meant to store states between training and validation for RNNs. This will keep the accuracy metric stats separate for training and validation.

        self.state_modules = [self.accuracy]

Training or validation step

    def step(self, batch: any, batch_idx: BatchIndex):

Move data to the device

        data, target = batch[0].to(self.device), batch[1].to(self.device)

Update global step (number of samples processed) when in training mode

        if self.mode.is_train:
            tracker.add_global_step(data.shape[1])

Whether to capture model outputs

        with self.mode.update(is_log_activations=batch_idx.is_last and self.is_log_model_activations):

Get model outputs. The model returns a tuple, where the extra elements are states when using RNNs. That part is not implemented yet. 😜

            output, *_ = self.model(data)

Calculate and log loss

        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

Calculate and log accuracy

        self.accuracy(output, target)
        self.accuracy.track()

Train the model

        if self.mode.is_train:

Calculate gradients

            loss.backward()

Clip gradients

            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)

Take optimizer step

            self.optimizer.step()

Log the model parameters and gradients on the last batch of every epoch

            if batch_idx.is_last and self.is_log_model_params_grads:
                tracker.add('model', self.model)

Clear the gradients

            self.optimizer.zero_grad()

Save the tracked metrics

        tracker.save()
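
For reference, here is a hypothetical minimal model that would be compatible with the step function above: it takes token ids of shape [seq_len, batch_size] and returns a tuple whose first element is the class logits of shape [batch_size, n_classes] (the tuple matters because step unpacks output, *_). This is only an illustrative sketch, not part of the trainer.

import torch
from torch import nn

class SimpleClassifier(nn.Module):
    def __init__(self, n_tokens: int, d_model: int, n_classes: int):
        super().__init__()
        # Token embeddings; n_tokens includes the padding and [CLS] ids
        self.embedding = nn.Embedding(n_tokens, d_model)
        # A GRU over the sequence (sequence-first layout, matching the collate function below)
        self.rnn = nn.GRU(d_model, d_model)
        # Classification head
        self.fc = nn.Linear(d_model, n_classes)

    def forward(self, x: torch.Tensor):
        # x: [seq_len, batch_size] -> [seq_len, batch_size, d_model]
        x = self.embedding(x)
        out, _ = self.rnn(x)
        # Classify from the last position, which the collate function sets to [CLS]
        logits = self.fc(out[-1])
        # Return a tuple, since `step` unpacks `output, *_`
        return logits, None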

Default optimizer configurations

@option(NLPClassificationConfigs.optimizer)
def _optimizer(c: NLPClassificationConfigs):
    optimizer = OptimizerConfigs()
    optimizer.parameters = c.model.parameters()
    optimizer.optimizer = 'Adam'
    optimizer.d_model = c.d_model

    return optimizer
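
These defaults can also be overridden from the configurations dictionary using dotted keys; for example (hypothetical values, and the available keys depend on what OptimizerConfigs exposes in your version of labml_nn):

experiment.configs(conf, {
    'optimizer.optimizer': 'Adam',
    'optimizer.learning_rate': 2.5e-4,
})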

Basic English tokenizer

We use a character-level tokenizer in this experiment. You can switch to this tokenizer by setting,

'tokenizer': 'basic_english',

in the configurations dictionary when starting the experiment.

@option(NLPClassificationConfigs.tokenizer)
def basic_english():
    from torchtext.data import get_tokenizer
    return get_tokenizer('basic_english')
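
For illustration, torchtext's basic_english tokenizer lower-cases the text and splits on whitespace and punctuation, roughly like this:

from torchtext.data import get_tokenizer

tokenizer = get_tokenizer('basic_english')
tokenizer("AG News is a topic classification dataset.")
# approximately: ['ag', 'news', 'is', 'a', 'topic', 'classification', 'dataset', '.']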

Character level tokenizer

def character_tokenizer(x: str):
    return list(x)

Character level tokenizer configuration

@option(NLPClassificationConfigs.tokenizer)
def character():
    return character_tokenizer

Get number of tokens

@option(NLPClassificationConfigs.n_tokens)
def _n_tokens(c: NLPClassificationConfigs):
    return len(c.vocab) + 2
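
The two extra ids beyond the vocabulary are for the padding token and the [CLS] token, matching how CollateFunc is constructed in the ag_news loader below (vocab here stands for that vocabulary):

# Token id layout assumed by this trainer, with V = len(vocab):
#   0 .. V - 1  -> vocabulary tokens
#   V           -> padding token
#   V + 1       -> [CLS] token
padding_token = len(vocab)
classifier_token = len(vocab) + 1
n_tokens = len(vocab) + 2  # size of the embedding table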

Function to load data into batches

class CollateFunc:
  • tokenizer is the tokenizer function
  • vocab is the vocabulary
  • seq_len is the length of the sequence
  • padding_token is the token used for padding when the seq_len is larger than the text length
  • classifier_token is the [CLS] token, which we set at the end of the input
    def __init__(self, tokenizer, vocab: Vocab, seq_len: int, padding_token: int, classifier_token: int):
        self.classifier_token = classifier_token
        self.padding_token = padding_token
        self.seq_len = seq_len
        self.vocab = vocab
        self.tokenizer = tokenizer
  • batch is the batch of data collected by the DataLoader
    def __call__(self, batch):

Input data tensor, initialized with padding_token

        data = torch.full((self.seq_len, len(batch)), self.padding_token, dtype=torch.long)

Empty labels tensor

        labels = torch.zeros(len(batch), dtype=torch.long)

Loop through the samples

        for (i, (_label, _text)) in enumerate(batch):

Set the label (AG News labels are 1-based, so shift them to start from 0)

            labels[i] = int(_label) - 1

Tokenize the input text

            _text = [self.vocab[token] for token in self.tokenizer(_text)]

Truncate up to seq_len

            _text = _text[:self.seq_len]

Copy the token ids into column i of the data tensor

            data[:len(_text), i] = data.new_tensor(_text)

Set the final token in the sequence to [CLS]

        data[-1, :] = self.classifier_token

        return data, labels
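
A small usage sketch (hypothetical; it uses the character tokenizer defined above and builds a tiny vocabulary the same way as the ag_news loader below):

from collections import Counter
import torchtext.vocab

# A toy batch in the (label, text) format produced by torchtext's AG_NEWS;
# labels are 1-based there, which is why __call__ subtracts 1
batch = [(1, "hello world"), (3, "hi")]

# Build a tiny character-level vocabulary from the batch
counter = Counter()
for _, text in batch:
    counter.update(character_tokenizer(text))
vocab = torchtext.vocab.vocab(counter, min_freq=1)

# Padding and [CLS] take the two ids just past the vocabulary
collate = CollateFunc(character_tokenizer, vocab, seq_len=8,
                      padding_token=len(vocab), classifier_token=len(vocab) + 1)

data, labels = collate(batch)
print(data.shape)  # torch.Size([8, 2]) -> [seq_len, batch_size]
print(labels)      # tensor([0, 2])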

AG News dataset

This loads the AG News dataset and sets the values for n_classes, vocab, train_loader, and valid_loader.

@option([NLPClassificationConfigs.n_classes,
         NLPClassificationConfigs.vocab,
         NLPClassificationConfigs.train_loader,
         NLPClassificationConfigs.valid_loader])
def ag_news(c: NLPClassificationConfigs):

Get training and validation datasets

    train, valid = torchtext.datasets.AG_NEWS(root=str(lab.get_data_path() / 'ag_news'), split=('train', 'test'))

Load data to memory

    with monit.section('Load data'):
        from labml_nn.utils import MapStyleDataset
        train, valid = MapStyleDataset(train), MapStyleDataset(valid)

Get tokenizer

    tokenizer = c.tokenizer

Create a counter

    counter = Counter()

Collect tokens from training dataset

    for (label, line) in train:
        counter.update(tokenizer(line))

Collect tokens from validation dataset

    for (label, line) in valid:
        counter.update(tokenizer(line))

Create vocabulary

    vocab = torchtext.vocab.vocab(counter, min_freq=1)

Create training data loader

    train_loader = DataLoader(train, batch_size=c.batch_size, shuffle=True,
                              collate_fn=CollateFunc(tokenizer, vocab, c.seq_len, len(vocab), len(vocab) + 1))

Create validation data loader

    valid_loader = DataLoader(valid, batch_size=c.batch_size, shuffle=True,
                              collate_fn=CollateFunc(tokenizer, vocab, c.seq_len, len(vocab), len(vocab) + 1))

Return n_classes, vocab, train_loader, and valid_loader

    return 4, vocab, train_loader, valid_loader
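
To tie these options together, an experiment would typically add a model option and start training. A hypothetical sketch, reusing the SimpleClassifier example from earlier; the exact configuration keys depend on your labml version:

from labml import experiment
from labml.configs import option

@option(NLPClassificationConfigs.model)
def simple_model(c: NLPClassificationConfigs):
    # n_tokens already accounts for the padding and [CLS] ids
    return SimpleClassifier(c.n_tokens, c.d_model, c.n_classes).to(c.device)

def main():
    experiment.create(name="ag_news_classification")
    conf = NLPClassificationConfigs()
    experiment.configs(conf, {'epochs': 8, 'batch_size': 32})
    with experiment.start():
        conf.run()

if __name__ == '__main__':
    main()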