from collections import Counter
from typing import Callable

import torchtext
import torchtext.vocab
from torchtext.vocab import Vocab

import torch
from labml import lab, tracker, monit
from labml.configs import option
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.helpers.metrics import Accuracy
from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex
from labml_nn.optimizers.configs import OptimizerConfigs
from torch import nn
from torch.utils.data import DataLoader
This class has the basic configurations for training an NLP classification task. All the properties are configurable.
class NLPClassificationConfigs(TrainValidConfigs):
Optimizer
    optimizer: torch.optim.Adam
Training device
    device: torch.device = DeviceConfigs()
Classification model
    model: nn.Module
Batch size
    batch_size: int = 16
Length of the sequence, or context size
    seq_len: int = 512
Vocabulary
    vocab: Vocab = 'ag_news'
Number of tokens in the vocabulary
    n_tokens: int
Number of classes
    n_classes: int = 'ag_news'
Tokenizer
    tokenizer: Callable = 'character'
Whether to periodically save models
    is_save_models = True
Loss function
    loss_func = nn.CrossEntropyLoss()
Accuracy function
    accuracy = Accuracy()
Model embedding size
    d_model: int = 512
Gradient clipping
    grad_norm_clip: float = 1.0
Training data loader
    train_loader: DataLoader = 'ag_news'
Validation data loader
    valid_loader: DataLoader = 'ag_news'
Whether to log model parameters and gradients (once per epoch). These are summarized stats per layer, but they could still lead to many indicators for very deep networks.
    is_log_model_params_grads: bool = False
Whether to log model activations (once per epoch). These are summarized stats per layer, but they could still lead to many indicators for very deep networks.
    is_log_model_activations: bool = False
    def init(self):
Set tracker configurations
        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)
Add accuracy as a state module. The name is probably confusing, since it's meant to store states between training and validation for RNNs. This will keep the accuracy metric stats separate for training and validation.
        self.state_modules = [self.accuracy]
    def step(self, batch: any, batch_idx: BatchIndex):
Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)
Update global step (number of samples processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(data.shape[1])
Get model outputs. It returns a tuple of states when using RNNs. This is not implemented yet. 😜
        output, *_ = self.model(data)
Calculate and log loss
        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)
Calculate and log accuracy
        self.accuracy(output, target)
        self.accuracy.track()
Train the model
        if self.mode.is_train:
Calculate gradients
            loss.backward()
Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
Take optimizer step
            self.optimizer.step()
Log the model parameters and gradients on the last batch of every epoch, if enabled
            if batch_idx.is_last and self.is_log_model_params_grads:
                tracker.add('model', self.model)
Clear the gradients
            self.optimizer.zero_grad()
Save the tracked metrics
        tracker.save()
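The model configuration can be any module that takes a tensor of token indices shaped [seq_len, batch_size] and returns the class logits as the first element of a tuple, so that output, *_ = self.model(data) unpacks correctly. Here is a minimal sketch of such a module (a hypothetical embedding-plus-linear classifier for illustration, not the transformer used in the actual experiments):

```python
import torch
from torch import nn


class ClassifierStub(nn.Module):
    """Hypothetical minimal model matching the interface expected by `step`."""

    def __init__(self, n_tokens: int, d_model: int, n_classes: int):
        super().__init__()
        # Token embeddings
        self.embedding = nn.Embedding(n_tokens, d_model)
        # Classification head applied to the final position
        self.fc = nn.Linear(d_model, n_classes)

    def forward(self, x: torch.Tensor):
        # `x` has shape `[seq_len, batch_size]`
        x = self.embedding(x)
        # Classify from the last token, which the collate function sets to `[CLS]`
        logits = self.fc(x[-1])
        # Return a tuple so that `output, *_ = self.model(data)` works
        return logits, None
```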
@option(NLPClassificationConfigs.optimizer)
def _optimizer(c: NLPClassificationConfigs):
    optimizer = OptimizerConfigs()
    optimizer.parameters = c.model.parameters()
    optimizer.optimizer = 'Adam'
    optimizer.d_model = c.d_model

    return optimizer
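Because the optimizer is itself an OptimizerConfigs instance, its fields can be overridden from the experiment's configuration dictionary using dotted keys. A rough sketch; the specific keys and values here (such as 'Noam' and learning_rate) are assumptions about OptimizerConfigs and should be checked against labml_nn.optimizers.configs:

```python
# Hypothetical overrides to pass to `experiment.configs(conf, {...})`;
# these assume `OptimizerConfigs` exposes `optimizer` and `learning_rate`.
optimizer_overrides = {
    'optimizer.optimizer': 'Noam',   # pick the optimizer/schedule by name
    'optimizer.learning_rate': 1.0,  # base learning rate for the schedule
}
```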
We use a character-level tokenizer in this experiment. You can switch to a word-level tokenizer by setting 'tokenizer': 'basic_english' in the configurations dictionary when starting the experiment (a short comparison of the two tokenizers is shown below).
@option(NLPClassificationConfigs.tokenizer)
def basic_english():
    from torchtext.data import get_tokenizer
    return get_tokenizer('basic_english')
def character_tokenizer(x: str):
    return list(x)
Character-level tokenizer configuration
@option(NLPClassificationConfigs.tokenizer)
def character():
    return character_tokenizer
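To illustrate the difference between the two tokenizers, here is a small sketch (the exact basic_english output depends on its normalization rules, so the word-level result is only indicative):

```python
from torchtext.data import get_tokenizer

basic = get_tokenizer('basic_english')
# Word-level: lowercases and splits on words/punctuation,
# e.g. something like ['wall', 'st', '.', 'bears', 'claw', 'back']
print(basic("Wall St. Bears Claw Back"))

# Character-level: simply splits the string into characters
print(character_tokenizer("AG News"))
# ['A', 'G', ' ', 'N', 'e', 'w', 's']
```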
Get the number of tokens: the vocabulary size plus two, for the padding and [CLS] tokens
@option(NLPClassificationConfigs.n_tokens)
def _n_tokens(c: NLPClassificationConfigs):
    return len(c.vocab) + 2
class CollateFunc:
* tokenizer is the tokenizer function
* vocab is the vocabulary
* seq_len is the length of the sequence
* padding_token is the token used for padding when the seq_len is larger than the text length
* classifier_token is the [CLS] token which we set at the end of the input
A short usage sketch is given after the __call__ method below.
    def __init__(self, tokenizer, vocab: Vocab, seq_len: int, padding_token: int, classifier_token: int):
        self.classifier_token = classifier_token
        self.padding_token = padding_token
        self.seq_len = seq_len
        self.vocab = vocab
        self.tokenizer = tokenizer
* batch is the batch of data collected by the DataLoader
    def __call__(self, batch):
Input data tensor, initialized with padding_token
        data = torch.full((self.seq_len, len(batch)), self.padding_token, dtype=torch.long)
Empty labels tensor
        labels = torch.zeros(len(batch), dtype=torch.long)
Loop through the samples
        for (i, (_label, _text)) in enumerate(batch):
Set the label
            labels[i] = int(_label) - 1
Tokenize the input text
            _text = [self.vocab[token] for token in self.tokenizer(_text)]
Truncate up to seq_len
            _text = _text[:self.seq_len]
Transpose and add to data
            data[:len(_text), i] = data.new_tensor(_text)
Set the final token in the sequence to [CLS]
        data[-1, :] = self.classifier_token
        return data, labels
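A quick illustration of what the collate function produces, as mentioned above. This is a minimal sketch that uses a plain Python dict in place of the torchtext Vocab (only __getitem__ lookups are needed) together with the character tokenizer defined earlier:

```python
# Toy vocabulary of 3 tokens, so the padding token is 3 and the [CLS] token is 4
toy_vocab = {'a': 0, 'b': 1, 'c': 2}
collate = CollateFunc(character_tokenizer, toy_vocab,
                      seq_len=6, padding_token=3, classifier_token=4)

# A batch of two (label, text) pairs, as yielded by the AG News dataset
data, labels = collate([(1, 'abc'), (2, 'cb')])

print(data)
# tensor([[0, 2],
#         [1, 1],
#         [2, 3],
#         [3, 3],
#         [3, 3],
#         [4, 4]])   columns are samples, padded with 3; the last row is [CLS]
print(labels)
# tensor([0, 1])
```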
This loads the AG News dataset and sets the values for n_classes, vocab, train_loader, and valid_loader.
@option([NLPClassificationConfigs.n_classes,
         NLPClassificationConfigs.vocab,
         NLPClassificationConfigs.train_loader,
         NLPClassificationConfigs.valid_loader])
def ag_news(c: NLPClassificationConfigs):
Get training and validation datasets
    train, valid = torchtext.datasets.AG_NEWS(root=str(lab.get_data_path() / 'ag_news'), split=('train', 'test'))
Load data to memory
    with monit.section('Load data'):
        from labml_nn.utils import MapStyleDataset
Create map-style datasets
        train, valid = MapStyleDataset(train), MapStyleDataset(valid)
Get tokenizer
    tokenizer = c.tokenizer
Create a counter
    counter = Counter()
Collect tokens from training dataset
    for (label, line) in train:
        counter.update(tokenizer(line))
Collect tokens from validation dataset
    for (label, line) in valid:
        counter.update(tokenizer(line))
Create vocabulary
    vocab = torchtext.vocab.vocab(counter, min_freq=1)
Create training data loader
    train_loader = DataLoader(train, batch_size=c.batch_size, shuffle=True,
                              collate_fn=CollateFunc(tokenizer, vocab, c.seq_len, len(vocab), len(vocab) + 1))
Create validation data loader
    valid_loader = DataLoader(valid, batch_size=c.batch_size, shuffle=True,
                              collate_fn=CollateFunc(tokenizer, vocab, c.seq_len, len(vocab), len(vocab) + 1))
Return n_classes, vocab, train_loader, and valid_loader
    return 4, vocab, train_loader, valid_loader
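Roughly, this is how these configurations can be wired together with labml when defining a concrete experiment. This is a sketch, not part of this module: the model option must be provided by the actual experiment (e.g. a transformer classifier), and the experiment name and override values here are illustrative:

```python
from labml import experiment

# Create the configurations; the concrete experiment is expected to define
# an @option for `model` before running.
conf = NLPClassificationConfigs()

experiment.create(name='ag_news_classification')
experiment.configs(conf, {
    'tokenizer': 'basic_english',  # switch from the default character tokenizer
    'seq_len': 256,
    'batch_size': 32,
})

# Start the experiment and run the training/validation loop
with experiment.start():
    conf.run()
```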