from collections import Counter
from typing import Callable

import torch
import torchtext
from torch import nn
from torch.utils.data import DataLoader
import torchtext.vocab
from torchtext.vocab import Vocab

from labml import lab, tracker, monit
from labml.configs import option
from labml_helpers.device import DeviceConfigs
from labml_helpers.metrics.accuracy import Accuracy
from labml_helpers.module import Module
from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
from labml_nn.optimizers.configs import OptimizerConfigs
This class holds the basic configurations for training an NLP classification task. All the properties are configurable.
class NLPClassificationConfigs(TrainValidConfigs):
Optimizer
    optimizer: torch.optim.Adam
Training device
    device: torch.device = DeviceConfigs()
Classification model
    model: Module
Batch size
    batch_size: int = 16
Length of the sequence, or context size
    seq_len: int = 512
Vocabulary
    vocab: Vocab = 'ag_news'
Number of tokens in the vocabulary
    n_tokens: int
Number of classes
    n_classes: int = 'ag_news'
Tokenizer
    tokenizer: Callable = 'character'
Whether to periodically save models
    is_save_models = True
Loss function
    loss_func = nn.CrossEntropyLoss()
Accuracy function
    accuracy = Accuracy()
Model embedding size
    d_model: int = 512
Gradient clipping
    grad_norm_clip: float = 1.0
Training data loader
    train_loader: DataLoader = 'ag_news'
Validation data loader
    valid_loader: DataLoader = 'ag_news'
Whether to log model parameters and gradients (once per epoch). These are summarized stats per layer, but they could still lead to many indicators for very deep networks.
    is_log_model_params_grads: bool = False
Whether to log model activations (once per epoch). These are summarized stats per layer, but they could still lead to many indicators for very deep networks.
    is_log_model_activations: bool = False
    def init(self):
Set tracker configurations
        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)
Add a hook to log module outputs
        hook_model_outputs(self.mode, self.model, 'model')
Add accuracy as a state module. The name is probably confusing, since it's meant to store states between training and validation for RNNs. This will keep the accuracy metric stats separate for training and validation.
        self.state_modules = [self.accuracy]
    def step(self, batch: any, batch_idx: BatchIndex):
Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)
Update global step (number of samples processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(data.shape[1])
Whether to capture model outputs
        with self.mode.update(is_log_activations=batch_idx.is_last and self.is_log_model_activations):
Get model outputs. It's returning a tuple for states when using RNNs. This is not implemented yet. 😜
            output, *_ = self.model(data)
Calculate and log loss
        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)
Calculate and log accuracy
        self.accuracy(output, target)
        self.accuracy.track()
Train the model
        if self.mode.is_train:
Calculate gradients
            loss.backward()
Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
Take an optimizer step
            self.optimizer.step()
Log the model parameters and gradients on the last batch of every epoch
            if batch_idx.is_last and self.is_log_model_params_grads:
                tracker.add('model', self.model)
Clear the gradients
            self.optimizer.zero_grad()
Save the tracked metrics
        tracker.save()
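No option for model is defined in this file; each concrete experiment registers its own. Below is a hedged sketch of what such an option could look like, using a hypothetical GRU-based classifier (the name simple_classifier and the architecture are illustrative, not part of this module). Note that the model must return a tuple, because step unpacks it with output, *_ = self.model(data).
@option(NLPClassificationConfigs.model)
def simple_classifier(c: NLPClassificationConfigs):
    class SimpleClassifier(Module):
        def __init__(self):
            super().__init__()
            # Token embeddings
            self.embedding = nn.Embedding(c.n_tokens, c.d_model)
            # A single-layer GRU encoder
            self.rnn = nn.GRU(c.d_model, c.d_model)
            # Classification head
            self.fc = nn.Linear(c.d_model, c.n_classes)

        def forward(self, x: torch.Tensor):
            # x has shape [seq_len, batch_size]
            x = self.embedding(x)
            out, _ = self.rnn(x)
            # Classify from the final position (the [CLS] token) and return a tuple
            return self.fc(out[-1]), None

    return SimpleClassifier().to(c.device)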
Default optimizer configurations
@option(NLPClassificationConfigs.optimizer)
def _optimizer(c: NLPClassificationConfigs):
    optimizer = OptimizerConfigs()
    optimizer.parameters = c.model.parameters()
    optimizer.optimizer = 'Adam'
    optimizer.d_model = c.d_model

    return optimizer
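These defaults (Adam) can be overridden through the experiment's configurations dictionary using labml's nested keys. A minimal sketch, assuming a configs object conf and labml's experiment module (see the usage sketch at the end of this file):
from labml import experiment

experiment.configs(conf, {
    'optimizer.optimizer': 'Noam',  # use the Noam learning-rate schedule instead of plain Adam
    'optimizer.learning_rate': 1.0,
})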
We use a character-level tokenizer in this experiment. You can switch to basic English tokenization by setting 'tokenizer': 'basic_english' in the configurations dictionary when starting the experiment.
@option(NLPClassificationConfigs.tokenizer)
def basic_english():
    from torchtext.data import get_tokenizer
    return get_tokenizer('basic_english')
Character-level tokenizer
def character_tokenizer(x: str):
    return list(x)
Character level tokenizer configuration
@option(NLPClassificationConfigs.tokenizer)
def character():
    return character_tokenizer
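For instance, the character tokenizer simply splits a string into its characters:
assert character_tokenizer('AG News') == ['A', 'G', ' ', 'N', 'e', 'w', 's']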
Get the number of tokens: the vocabulary size plus two, for the padding token and the [CLS] token.
@option(NLPClassificationConfigs.n_tokens)
def _n_tokens(c: NLPClassificationConfigs):
    return len(c.vocab) + 2
Collate function for the DataLoader.
class CollateFunc:
The constructor takes: tokenizer, the tokenizer function; vocab, the vocabulary; seq_len, the length of the sequence; padding_token, the token used for padding when seq_len is larger than the text length; and classifier_token, the [CLS] token which we set at the end of the input.
    def __init__(self, tokenizer, vocab: Vocab, seq_len: int, padding_token: int, classifier_token: int):
        self.classifier_token = classifier_token
        self.padding_token = padding_token
        self.seq_len = seq_len
        self.vocab = vocab
        self.tokenizer = tokenizer
batch is the batch of data collected by the DataLoader.
    def __call__(self, batch):
Input data tensor, initialized with padding_token
        data = torch.full((self.seq_len, len(batch)), self.padding_token, dtype=torch.long)
Empty labels tensor
        labels = torch.zeros(len(batch), dtype=torch.long)
Loop through the samples
        for (i, (_label, _text)) in enumerate(batch):
Set the label
            labels[i] = int(_label) - 1
Tokenize the input text
            _text = [self.vocab[token] for token in self.tokenizer(_text)]
Truncate up to seq_len
            _text = _text[:self.seq_len]
Write the token ids into column i of the data tensor
            data[:len(_text), i] = data.new_tensor(_text)
Set the final token in the sequence to [CLS]
        data[-1, :] = self.classifier_token

        return data, labels
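A small sketch of how this collate function behaves, with a toy character-level vocabulary (the strings and the seq_len here are illustrative only):
from collections import Counter
import torchtext.vocab

toy_vocab = torchtext.vocab.vocab(Counter("sports news business world"), min_freq=1)
collate = CollateFunc(character_tokenizer, toy_vocab, seq_len=16,
                      padding_token=len(toy_vocab), classifier_token=len(toy_vocab) + 1)
# Each sample is a (label, text) pair, as yielded by the AG News dataset
data, labels = collate([(1, "world news"), (2, "sports news")])
# data has shape [seq_len, batch_size] = [16, 2]; unused positions hold the padding token,
# the last row holds the [CLS] token, and labels is tensor([0, 1])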
This loads the AG News dataset and sets the values for n_classes, vocab, train_loader, and valid_loader.
@option([NLPClassificationConfigs.n_classes,
         NLPClassificationConfigs.vocab,
         NLPClassificationConfigs.train_loader,
         NLPClassificationConfigs.valid_loader])
def ag_news(c: NLPClassificationConfigs):
Get training and validation datasets
    train, valid = torchtext.datasets.AG_NEWS(root=str(lab.get_data_path() / 'ag_news'), split=('train', 'test'))
Load data to memory
    with monit.section('Load data'):
        from labml_nn.utils import MapStyleDataset
Create map-style datasets
        train, valid = MapStyleDataset(train), MapStyleDataset(valid)
Get tokenizer
    tokenizer = c.tokenizer
Create a counter
    counter = Counter()
Collect tokens from training dataset
    for (label, line) in train:
        counter.update(tokenizer(line))
Collect tokens from validation dataset
    for (label, line) in valid:
        counter.update(tokenizer(line))
Create vocabulary
    vocab = torchtext.vocab.vocab(counter, min_freq=1)
Create training data loader
    train_loader = DataLoader(train, batch_size=c.batch_size, shuffle=True,
                              collate_fn=CollateFunc(tokenizer, vocab, c.seq_len, len(vocab), len(vocab) + 1))
Create validation data loader
    valid_loader = DataLoader(valid, batch_size=c.batch_size, shuffle=True,
                              collate_fn=CollateFunc(tokenizer, vocab, c.seq_len, len(vocab), len(vocab) + 1))
Return n_classes, vocab, train_loader, and valid_loader. AG News has 4 classes.
    return 4, vocab, train_loader, valid_loader
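Finally, a hedged sketch of how these configurations might be wired into an experiment script. The experiment name, the override values, and the simple_classifier model option (from the sketch after the configurations class above) are illustrative; a real experiment registers its own model.
from labml import experiment


def main():
    # Create the configurations
    conf = NLPClassificationConfigs()
    # Create the experiment
    experiment.create(name='nlp_classification_sketch')
    # Override configurations; 'simple_classifier' refers to the hypothetical model option above
    experiment.configs(conf, {
        'model': 'simple_classifier',
        'tokenizer': 'basic_english',
        'epochs': 5,
    })
    # Start the experiment and run the training/validation loop
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()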