"""
# NLP model trainer for classification
"""

from collections import Counter
from typing import Any, Callable

import torch
import torchtext
from torch import nn
from torch.utils.data import DataLoader
from torchtext.vocab import Vocab

from labml import lab, tracker, monit
from labml.configs import option
from labml_helpers.device import DeviceConfigs
from labml_helpers.metrics.accuracy import Accuracy
from labml_helpers.module import Module
from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
from labml_nn.optimizers.configs import OptimizerConfigs

class NLPClassificationConfigs(TrainValidConfigs):
    """
    ## Trainer configurations

    This has the basic configurations for NLP classification task training.
    All the properties are configurable.

    String defaults (e.g. `'ag_news'`, `'character'`) select the config option
    with that name; e.g. set `'tokenizer': 'basic_english'` in the
    configurations dictionary to switch tokenizers.
    """

    # Optimizer
    optimizer: torch.optim.Adam
    # Training device
    device: torch.device = DeviceConfigs()

    # Classification model; its first output is used as the class logits
    model: Module
    # Batch size
    batch_size: int = 16
    # Length of the sequence, or context size
    seq_len: int = 512
    # Vocabulary
    vocab: Vocab = 'ag_news'
    # Number of tokens in the vocabulary
    n_tokens: int
    # Number of classes
    n_classes: int = 'ag_news'
    # Tokenizer
    tokenizer: Callable = 'character'

    # Whether to periodically save models
    is_save_models = True

    # Loss function
    loss_func = nn.CrossEntropyLoss()
    # Accuracy function
    accuracy = Accuracy()
    # Model embedding size
    d_model: int = 512
    # Gradient clipping
    grad_norm_clip: float = 1.0

    # Training data loader
    train_loader: DataLoader = 'ag_news'
    # Validation data loader
    valid_loader: DataLoader = 'ag_news'

    def init(self):
        """
        ### Initialization

        Configures the tracker, hooks model outputs for logging and registers
        the accuracy metric as a state module.
        """
        # Set tracker configurations
        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)
        # Add a hook to log module outputs
        hook_model_outputs(self.mode, self.model, 'model')
        # Add accuracy as a state module.
        # The name is probably confusing, since it's meant to store
        # states between training and validation for RNNs.
        # This will keep the accuracy metric stats separate for training and validation.
        self.state_modules = [self.accuracy]

    def step(self, batch: Any, batch_idx: BatchIndex):
        """
        ### Training or validation step

        * `batch` is a `(data, target)` pair (e.g. as produced by `CollateFunc`)
        * `batch_idx` is the index of the batch; `batch_idx.is_last` enables
          activation and parameter logging
        """

        # Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # Update global step when in training mode.
        # With `CollateFunc` batches, `data` is `[seq_len, batch_size]`, so
        # `data.shape[1]` counts samples processed, not tokens.
        if self.mode.is_train:
            tracker.add_global_step(data.shape[1])

        # Whether to capture model outputs
        with self.mode.update(is_log_activations=batch_idx.is_last):
            # Get model outputs.
            # It's returning a tuple for states when using RNNs.
            # This is not implemented yet.
            output, *_ = self.model(data)

        # Calculate and log loss
        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

        # Calculate and log accuracy
        self.accuracy(output, target)
        self.accuracy.track()

        # Train the model
        if self.mode.is_train:
            # Calculate gradients
            loss.backward()
            # Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
            # Take optimizer step
            self.optimizer.step()
            # Log the model parameters and gradients on last batch of every epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
            # Clear the gradients
            self.optimizer.zero_grad()

        # Save the tracked metrics
        tracker.save()

@option(NLPClassificationConfigs.optimizer)
def _optimizer(c: NLPClassificationConfigs):
    """
    ### Default optimizer configurations

    Adam over all model parameters. `d_model` is forwarded to the optimizer
    configs (presumably for schedules that depend on model width; confirm in
    `OptimizerConfigs`).
    """
    optimizer = OptimizerConfigs()
    optimizer.parameters = c.model.parameters()
    optimizer.optimizer = 'Adam'
    optimizer.d_model = c.d_model

    return optimizer

@option(NLPClassificationConfigs.tokenizer)
def basic_english():
    """
    ### Basic English tokenizer

    We use a character level tokenizer in this experiment. You can switch to
    this one by setting

        'tokenizer': 'basic_english',

    in the configurations dictionary when starting the experiment.
    """
    # Imported locally; only resolved when this option is actually used
    from torchtext.data.utils import get_tokenizer
    return get_tokenizer('basic_english')

def character_tokenizer(x: str):
    """
    ### Character level tokenizer

    Splits `x` into a list of its characters, e.g. `'ab'` -> `['a', 'b']`.
    """
    return list(x)

@option(NLPClassificationConfigs.tokenizer)
def character():
    """
    ### Character level tokenizer configuration

    Returns the `character_tokenizer` function itself; this is the default
    `tokenizer` option.
    """
    return character_tokenizer

@option(NLPClassificationConfigs.n_tokens)
def _n_tokens(c: NLPClassificationConfigs):
    """
    ### Get number of tokens

    The vocabulary size plus two extra ids: `len(vocab)` is used for padding
    and `len(vocab) + 1` for the `[CLS]` token (see `ag_news`).
    """
    return len(c.vocab) + 2

class CollateFunc:
    """
    ## Function to load data into batches

    Used as the `collate_fn` of a `DataLoader`. It turns a list of
    `(label, text)` samples into a `(data, labels)` pair of tensors.
    """

    def __init__(self, tokenizer, vocab: Vocab, seq_len: int, padding_token: int, classifier_token: int):
        """
        * `tokenizer` is the tokenizer function
        * `vocab` is the vocabulary
        * `seq_len` is the length of the sequence
        * `padding_token` is the token used for padding when the `seq_len` is larger than the text length
        * `classifier_token` is the `[CLS]` token which we set at the end of the input
        """
        self.classifier_token = classifier_token
        self.padding_token = padding_token
        self.seq_len = seq_len
        self.vocab = vocab
        self.tokenizer = tokenizer

    def __call__(self, batch):
        """
        * `batch` is the batch of data collected by the `DataLoader`

        Returns `data` with shape `[seq_len, batch_size]` (dtype `torch.long`)
        and `labels` with shape `[batch_size]`.
        """
        # Input data tensor, initialized with `padding_token`
        data = torch.full((self.seq_len, len(batch)), self.padding_token, dtype=torch.long)
        # Empty labels tensor
        labels = torch.zeros(len(batch), dtype=torch.long)

        # Loop through the samples
        for (i, (_label, _text)) in enumerate(batch):
            # Set the label, shifted down by one to get a 0-based class index
            # (presumably the dataset's labels are 1-based)
            labels[i] = int(_label) - 1
            # Tokenize the input text
            _text = [self.vocab[token] for token in self.tokenizer(_text)]
            # Truncate up to `seq_len`
            _text = _text[:self.seq_len]
            # Transpose and add to data (one column per sample)
            data[:len(_text), i] = data.new_tensor(_text)

        # Set the final token in the sequence to `[CLS]`.
        # This overwrites the last text token of any sample that filled all `seq_len` positions.
        data[-1, :] = self.classifier_token

        return data, labels

@option([NLPClassificationConfigs.n_classes,
         NLPClassificationConfigs.vocab,
         NLPClassificationConfigs.train_loader,
         NLPClassificationConfigs.valid_loader])
def ag_news(c: NLPClassificationConfigs):
    """
    ### AG News dataset

    This loads the AG News dataset and sets the values for
    `n_classes`, `vocab`, `train_loader`, and `valid_loader`.
    """

    # Get training and validation datasets
    train, valid = torchtext.datasets.AG_NEWS(root=str(lab.get_data_path() / 'ag_news'), split=('train', 'test'))

    # Load data to memory
    with monit.section('Load data'):
        from labml_nn.utils import MapStyleDataset

        # Create map-style datasets
        train, valid = MapStyleDataset(train), MapStyleDataset(valid)

    # Get tokenizer
    tokenizer = c.tokenizer

    # Create a counter
    counter = Counter()
    # Collect tokens from training dataset
    for (label, line) in train:
        counter.update(tokenizer(line))
    # Collect tokens from validation dataset.
    # NOTE(review): this puts tokens that only occur in the validation (test)
    # split into the vocabulary; consider building it from `train` alone.
    for (label, line) in valid:
        counter.update(tokenizer(line))
    # Create vocabulary
    vocab = Vocab(counter, min_freq=1)

    # Padding and `[CLS]` use the two ids right after the vocabulary
    # (`n_tokens` is `len(vocab) + 2`). Both loaders share one collate function.
    collate_fn = CollateFunc(tokenizer, vocab, c.seq_len, len(vocab), len(vocab) + 1)

    # Create training data loader
    train_loader = DataLoader(train, batch_size=c.batch_size, shuffle=True, collate_fn=collate_fn)
    # Create validation data loader; no shuffling, so validation batches come in a fixed order
    valid_loader = DataLoader(valid, batch_size=c.batch_size, shuffle=False, collate_fn=collate_fn)

    # Return `n_classes` (AG News has 4 classes), `vocab`, `train_loader`, and `valid_loader`
    return 4, vocab, train_loader, valid_loader