NLP model trainer for classification

from collections import Counter
from typing import Callable

import torchtext
import torchtext.vocab
from torchtext.vocab import Vocab

import torch
from labml import lab, tracker, monit
from labml.configs import option
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.helpers.metrics import Accuracy
from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex
from labml_nn.optimizers.configs import OptimizerConfigs
from torch import nn
from torch.utils.data import DataLoader

Trainer configurations

These are the basic configurations for training an NLP classification task. All the properties are configurable.

class NLPClassificationConfigs(TrainValidConfigs):

Optimizer

    optimizer: torch.optim.Adam

Training device

    device: torch.device = DeviceConfigs()

Autoregressive model

    model: nn.Module

Batch size

    batch_size: int = 16

Length of the sequence, or context size

    seq_len: int = 512

Vocabulary

    vocab: Vocab = 'ag_news'

Number of tokens in the vocabulary

    n_tokens: int

Number of classes

    n_classes: int = 'ag_news'

Tokenizer

    tokenizer: Callable = 'character'

Whether to periodically save models

    is_save_models = True

Loss function

    loss_func = nn.CrossEntropyLoss()

Accuracy function

    accuracy = Accuracy()

Model embedding size

    d_model: int = 512

Gradient clipping

    grad_norm_clip: float = 1.0

Training data loader

    train_loader: DataLoader = 'ag_news'

Validation data loader

    valid_loader: DataLoader = 'ag_news'

Whether to log model parameters and gradients (once per epoch). These are summarized stats per layer, but they can still lead to many indicators for very deep networks.

    is_log_model_params_grads: bool = False

Whether to log model activations (once per epoch). These are summarized stats per layer, but they can still lead to many indicators for very deep networks.

    is_log_model_activations: bool = False

Initialization

    def init(self):

Set tracker configurations

        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)

Add accuracy as a state module. The name is probably confusing, since it's meant to store states between training and validation for RNNs. This will keep the accuracy metric stats separate for training and validation.

        self.state_modules = [self.accuracy]

Training or validation step

    def step(self, batch: any, batch_idx: BatchIndex):

Move data to the device

        data, target = batch[0].to(self.device), batch[1].to(self.device)

Update the global step (number of samples processed) when in training mode

        if self.mode.is_train:
            tracker.add_global_step(data.shape[1])

Get model outputs. The model returns a tuple so that it can also pass back states when using RNNs; that part is not implemented yet. 😜

        output, *_ = self.model(data)

Calculate and log loss

        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

Calculate and log accuracy

        self.accuracy(output, target)
        self.accuracy.track()

Train the model

        if self.mode.is_train:

Calculate gradients

            loss.backward()

Clip gradients

            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)

Take optimizer step

            self.optimizer.step()

Log the model parameters and gradients on the last batch of every epoch

            if batch_idx.is_last and self.is_log_model_params_grads:
                tracker.add('model', self.model)

Clear the gradients

            self.optimizer.zero_grad()

Save the tracked metrics

        tracker.save()

Default optimizer configurations

@option(NLPClassificationConfigs.optimizer)
def _optimizer(c: NLPClassificationConfigs):
    optimizer = OptimizerConfigs()
    optimizer.parameters = c.model.parameters()
    optimizer.optimizer = 'Adam'
    optimizer.d_model = c.d_model

    return optimizer
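
These defaults can be overridden from the configurations dictionary when starting an experiment. A minimal sketch, assuming the usual labml dotted-key overrides and the OptimizerConfigs key names:

from labml import experiment

conf = NLPClassificationConfigs()
experiment.create(name='nlp_classification')
experiment.configs(conf, {
    # Assumed OptimizerConfigs keys: pick the inner optimizer and its learning rate
    'optimizer.optimizer': 'Adam',
    'optimizer.learning_rate': 3e-4,
})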

Basic English tokenizer

We use a character-level tokenizer in this experiment. You can switch to this tokenizer by setting,

'tokenizer': 'basic_english',

in the configurations dictionary when starting the experiment.

@option(NLPClassificationConfigs.tokenizer)
def basic_english():
    from torchtext.data import get_tokenizer
    return get_tokenizer('basic_english')
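
As a rough illustration, the basic_english tokenizer lower-cases the text and splits out punctuation (the output shown is approximate):

from torchtext.data import get_tokenizer

tokenizer = get_tokenizer('basic_english')
# Lower-cases and separates punctuation
tokenizer('Wall St. Bears Claw Back Into the Black')
# -> ['wall', 'st', '.', 'bears', 'claw', 'back', 'into', 'the', 'black']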

Character-level tokenizer

def character_tokenizer(x: str):
    return list(x)

Character-level tokenizer configuration

@option(NLPClassificationConfigs.tokenizer)
def character():
    return character_tokenizer
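
Character-level tokenization keeps the vocabulary tiny at the cost of longer sequences, for example:

character_tokenizer('Wall St.')
# -> ['W', 'a', 'l', 'l', ' ', 'S', 't', '.']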

Get the number of tokens: the vocabulary size plus two extra ids, one for padding and one for the [CLS] classifier token

@option(NLPClassificationConfigs.n_tokens)
def _n_tokens(c: NLPClassificationConfigs):
    return len(c.vocab) + 2

Collate function to assemble a list of samples into a batch

class CollateFunc:
  • tokenizer is the tokenizer function
  • vocab is the vocabulary
  • seq_len is the length of the sequence
  • padding_token is the token used for padding when the seq_len is larger than the text length
  • classifier_token is the [CLS] token which we set at the end of the input
    def __init__(self, tokenizer, vocab: Vocab, seq_len: int, padding_token: int, classifier_token: int):
        self.classifier_token = classifier_token
        self.padding_token = padding_token
        self.seq_len = seq_len
        self.vocab = vocab
        self.tokenizer = tokenizer
  • batch is the batch of data collected by the DataLoader
    def __call__(self, batch):

Input data tensor, initialized with padding_token

        data = torch.full((self.seq_len, len(batch)), self.padding_token, dtype=torch.long)

Empty labels tensor

        labels = torch.zeros(len(batch), dtype=torch.long)

Loop through the samples

        for (i, (_label, _text)) in enumerate(batch):

Set the label (AG News labels are 1-based, so shift them to start at zero)

            labels[i] = int(_label) - 1

Tokenize the input text and map the tokens to vocabulary ids

            _text = [self.vocab[token] for token in self.tokenizer(_text)]

Truncate up to seq_len

            _text = _text[:self.seq_len]

Copy the token ids into the i-th column of data (sequences run along the first dimension)

            data[:len(_text), i] = data.new_tensor(_text)

Set the final token in the sequence to [CLS]

        data[-1, :] = self.classifier_token

        return data, labels
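
To make the shapes concrete, here is a small sketch (not part of the original code) that exercises the collate function with a toy vocabulary; a plain dict stands in for Vocab since only vocab[token] lookups are needed:

# Toy vocabulary: token -> id
toy_vocab = {ch: i for i, ch in enumerate('abcdefghijklmnopqrstuvwxyz ')}
padding_token = len(toy_vocab)          # id 27
classifier_token = len(toy_vocab) + 1   # id 28

collate = CollateFunc(character_tokenizer, toy_vocab, seq_len=16,
                      padding_token=padding_token, classifier_token=classifier_token)

# A batch of (label, text) pairs, as yielded by the AG News dataset (labels start at 1)
batch = [(1, 'hello world'), (3, 'short')]
data, labels = collate(batch)
print(data.shape)   # torch.Size([16, 2]) -> [seq_len, batch_size]
print(labels)       # tensor([0, 2])
print(data[-1])     # tensor([28, 28]) -> the [CLS] token for every sample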

AG News dataset

This loads the AG News dataset and sets the values for n_classes, vocab, train_loader, and valid_loader.

@option([NLPClassificationConfigs.n_classes,
         NLPClassificationConfigs.vocab,
         NLPClassificationConfigs.train_loader,
         NLPClassificationConfigs.valid_loader])
def ag_news(c: NLPClassificationConfigs):

Get training and validation datasets

    train, valid = torchtext.datasets.AG_NEWS(root=str(lab.get_data_path() / 'ag_news'), split=('train', 'test'))

Load data to memory

    with monit.section('Load data'):
        from labml_nn.utils import MapStyleDataset
        train, valid = MapStyleDataset(train), MapStyleDataset(valid)

Get tokenizer

    tokenizer = c.tokenizer

Create a counter

    counter = Counter()

Collect tokens from training dataset

    for (label, line) in train:
        counter.update(tokenizer(line))

Collect tokens from validation dataset

    for (label, line) in valid:
        counter.update(tokenizer(line))

Create vocabulary

    vocab = torchtext.vocab.vocab(counter, min_freq=1)
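
Roughly, this builds a Vocab that maps every token seen in the counter to an integer id. A tiny illustrative sketch:

from collections import Counter
import torchtext.vocab

counts = Counter(character_tokenizer('hello world'))
demo_vocab = torchtext.vocab.vocab(counts, min_freq=1)
len(demo_vocab)   # 8 distinct characters: 'h', 'e', 'l', 'o', ' ', 'w', 'r', 'd'
demo_vocab['l']   # the integer id assigned to 'l'

A vocabulary built this way has no default index, so looking up an unseen token raises an error; that is not a problem here because tokens are collected from both the training and validation sets.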

Create training data loader

    train_loader = DataLoader(train, batch_size=c.batch_size, shuffle=True,
                              collate_fn=CollateFunc(tokenizer, vocab, c.seq_len, len(vocab), len(vocab) + 1))

Create validation data loader

    valid_loader = DataLoader(valid, batch_size=c.batch_size, shuffle=True,
                              collate_fn=CollateFunc(tokenizer, vocab, c.seq_len, len(vocab), len(vocab) + 1))

Return n_classes, vocab, train_loader, and valid_loader

    return 4, vocab, train_loader, valid_loader
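
Putting it all together, here is a minimal sketch of an experiment that uses these configurations. The SimpleClassifier model and the experiment wiring below are illustrative assumptions, not part of this file; they only show the expected interface: the model receives data of shape [seq_len, batch_size] and must return a tuple whose first element has shape [batch_size, n_classes].

from labml import experiment


class SimpleClassifier(nn.Module):
    # A tiny bag-of-embeddings classifier (hypothetical), just to show the interface
    def __init__(self, n_tokens: int, d_model: int, n_classes: int):
        super().__init__()
        self.embedding = nn.Embedding(n_tokens, d_model)
        self.fc = nn.Linear(d_model, n_classes)

    def forward(self, x: torch.Tensor):
        # x has shape [seq_len, batch_size]
        emb = self.embedding(x)      # [seq_len, batch_size, d_model]
        pooled = emb.mean(dim=0)     # [batch_size, d_model]
        # Return a tuple so that `output, *_ = self.model(data)` in `step` works
        return self.fc(pooled), None


@option(NLPClassificationConfigs.model)
def simple_model(c: NLPClassificationConfigs):
    return SimpleClassifier(c.n_tokens, c.d_model, c.n_classes).to(c.device)


def main():
    experiment.create(name='ag_news_classification')
    conf = NLPClassificationConfigs()
    # 'epochs' comes from TrainValidConfigs; 'model' selects the option defined above
    experiment.configs(conf, {'model': 'simple_model', 'epochs': 5})
    experiment.add_pytorch_models({'model': conf.model})
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()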