from typing import Callable

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, RandomSampler

from labml import lab, monit, logger, tracker
from labml.configs import option
from labml.logger import Text
from labml_helpers.datasets.text import TextDataset, SequentialDataLoader, SequentialUnBatchedDataset, TextFileDataset
from labml_helpers.device import DeviceConfigs
from labml_helpers.metrics.accuracy import Accuracy
from labml_helpers.module import Module
from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
from labml_nn.optimizers.configs import OptimizerConfigs

Cross entropy loss. Outputs of shape [seq_len, batch_size, n_tokens] and targets of shape [seq_len, batch_size] are flattened before applying nn.CrossEntropyLoss.
class CrossEntropyLoss(Module):
    def __init__(self):
        super().__init__()
        self.loss = nn.CrossEntropyLoss()

    def forward(self, outputs, targets):
        return self.loss(outputs.view(-1, outputs.shape[-1]), targets.view(-1))
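
As a rough illustration (not part of the original code), the loss expects outputs of shape [seq_len, batch_size, n_tokens] and targets of shape [seq_len, batch_size]; the tensor sizes below are made-up values.
def _illustrate_loss_shapes():
    # Made-up sizes: seq_len=32, batch_size=4, n_tokens=65
    loss_func = CrossEntropyLoss()
    outputs = torch.randn(32, 4, 65)
    targets = torch.randint(0, 65, (32, 4))
    # Both tensors get flattened inside CrossEntropyLoss; the result is a scalar tensor
    return loss_func(outputs, targets)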

This class has the basic configurations for training auto-regressive NLP models. All of the properties are configurable.
class NLPAutoRegressionConfigs(TrainValidConfigs):
Optimizer
    optimizer: torch.optim.Adam
Training device
    device: torch.device = DeviceConfigs()
Autoregressive model
    model: Module
Text dataset
    text: TextDataset
Batch size
    batch_size: int = 16
Length of the sequence, or context size
    seq_len: int = 512
Number of tokens in the vocabulary
    n_tokens: int
Tokenizer
    tokenizer: Callable = 'character'
Text prompt to start sampling (for illustration)
    prompt: str
The token separator when sampling (blank for character-level tokenization)
    prompt_separator: str
Whether to periodically save models
    is_save_models = True
Loss function
    loss_func = CrossEntropyLoss()
Accuracy function
    accuracy = Accuracy()
Model embedding size
    d_model: int = 512
Gradient clipping
    grad_norm_clip: float = 1.0
Training data loader
    train_loader: DataLoader = 'shuffled_train_loader'
Validation data loader
    valid_loader: DataLoader = 'shuffled_valid_loader'
Whether data loaders should shuffle with replacement
    dataloader_shuffle_with_replacement: bool = False
Whether to log model parameters and gradients (once per epoch). These are summarized stats per layer, but they could still add up to many indicators for very deep networks.
    is_log_model_params_grads: bool = False
Whether to log model activations (once per epoch). These are summarized stats per layer, but they could still add up to many indicators for very deep networks.
    is_log_model_activations: bool = False

    def init(self):
Set tracker configurations
        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)
        tracker.set_text("sampled", False)
Add a hook to log module outputs
        hook_model_outputs(self.mode, self.model, 'model')
Add accuracy as a state module. The name is probably confusing, since it's meant to store states between training and validation for RNNs. This will keep the accuracy metric stats separate for training and validation.
        self.state_modules = [self.accuracy]

Override to calculate and log other metrics
    def other_metrics(self, output: torch.Tensor, target: torch.Tensor):
        pass

    def step(self, batch: any, batch_idx: BatchIndex):
Set training/eval mode
        self.model.train(self.mode.is_train)
Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)
Update global step (number of tokens processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])
Whether to capture model outputs
        with self.mode.update(is_log_activations=batch_idx.is_last and self.is_log_model_activations):
Get model outputs. The model returns a tuple so that RNNs can also pass back their states; that part is not implemented yet. 😜
            output, *_ = self.model(data)
Calculate and log loss
        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)
Calculate and log accuracy
        self.accuracy(output, target)
        self.accuracy.track()

        self.other_metrics(output, target)
Train the model
        if self.mode.is_train:
Calculate gradients
            loss.backward()
Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
Take optimizer step
            self.optimizer.step()
Log the model parameters and gradients on the last batch of every epoch
            if batch_idx.is_last and self.is_log_model_params_grads:
                tracker.add('model', self.model)
Clear the gradients
            self.optimizer.zero_grad()
Save the tracked metrics
        tracker.save()

    def sample(self):
Starting prompt
        prompt = self.prompt
Collect output for printing
        log = [(prompt, Text.subtle)]
Sample 25 tokens
        for i in monit.iterate('Sample', 25):
Tokenize the prompt
            data = self.text.text_to_i(prompt).unsqueeze(-1)
            data = data.to(self.device)
Get the model output
            output, *_ = self.model(data)
Get the model prediction (greedy)
            output = output.argmax(dim=-1).squeeze()
Add the prediction to the prompt
            prompt += self.prompt_separator + self.text.itos[output[-1]]
Add the prediction for logging
            log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]

        tracker.add({'sampled': prompt})
Print the sampled output
        logger.log(log)
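
The sampler above decodes greedily with argmax. As a hedged alternative (a sketch only, not used by this class), one could instead sample the next token from the softmax distribution with a temperature:
def sample_next_token(logits: torch.Tensor, temperature: float = 1.0) -> int:
    # logits: unnormalized scores for the last position, shape [n_tokens]
    probs = torch.softmax(logits / temperature, dim=-1)
    # Draw one token index from the distribution
    return int(torch.multinomial(probs, num_samples=1).item())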

Default optimizer configurations
@option(NLPAutoRegressionConfigs.optimizer)
def _optimizer(c: NLPAutoRegressionConfigs):
    optimizer = OptimizerConfigs()
    optimizer.parameters = c.model.parameters()
    optimizer.optimizer = 'Adam'
    optimizer.d_model = c.d_model

    return optimizer
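
The OptimizerConfigs indirection means the optimizer can also be swapped from the configurations dictionary. The snippet below is an assumption about labml's dotted-key override convention (as used elsewhere in labml_nn), not code from this file:
# Hypothetical overrides switching to the Noam learning-rate schedule;
# they would be merged into the dictionary passed to experiment.configs(...)
_noam_optimizer_overrides = {
    'optimizer.optimizer': 'Noam',
    'optimizer.learning_rate': 1.0,
}
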
Get number of tokens
@option(NLPAutoRegressionConfigs.n_tokens)
def _n_tokens(c: NLPAutoRegressionConfigs):
    return c.text.n_tokens

Basic English tokenizer. We use a character-level tokenizer in this experiment; you can switch to this one by setting 'tokenizer': 'basic_english' in the configurations dictionary when starting the experiment.
@option(NLPAutoRegressionConfigs.tokenizer)
def basic_english():
    from torchtext.data import get_tokenizer
    return get_tokenizer('basic_english')

Character-level tokenizer
def character_tokenizer(x: str):
    return list(x)

@option(NLPAutoRegressionConfigs.tokenizer)
def character():
    return character_tokenizer
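
To make the difference concrete, here is a small illustrative comparison of the two tokenizers (the word-level output is what basic_english is expected to produce):
def _illustrate_tokenizers():
    # Character-level: every character, including spaces, is a token
    assert character_tokenizer('To be') == ['T', 'o', ' ', 'b', 'e']
    # Word-level with torchtext's basic_english tokenizer (lower-cases and
    # splits out punctuation); expected: ['to', 'be', ',', 'or', 'not']
    from torchtext.data import get_tokenizer
    return get_tokenizer('basic_english')('To be, or not')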

Tiny Shakespeare dataset. The text file is downloaded from the given URL if it is not already in the data directory.
@option(NLPAutoRegressionConfigs.text)
def tiny_shakespeare(c: NLPAutoRegressionConfigs):
    return TextFileDataset(
        lab.get_data_path() / 'tiny_shakespeare.txt',
        c.tokenizer,
        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')

Sequential training data loader
@option(NLPAutoRegressionConfigs.train_loader)
def sequential_train_loader(c: NLPAutoRegressionConfigs):
    return SequentialDataLoader(text=c.text.train,
                                dataset=c.text,
                                batch_size=c.batch_size,
                                seq_len=c.seq_len)

Sequential validation data loader
@option(NLPAutoRegressionConfigs.valid_loader)
def sequential_valid_loader(c: NLPAutoRegressionConfigs):
    return SequentialDataLoader(text=c.text.valid,
                                dataset=c.text,
                                batch_size=c.batch_size,
                                seq_len=c.seq_len)

DataLoader collates batches along the first dimension. We need to transpose them so that the sequence dimension comes first.
def transpose_batch(batch):
    transposed_data = list(zip(*batch))
Stack the batch along the second dimension dim=1
    src = torch.stack(transposed_data[0], dim=1)
    tgt = torch.stack(transposed_data[1], dim=1)

    return src, tgt
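
A quick illustrative shape check for transpose_batch with made-up sizes, showing batch-first samples turning into sequence-first tensors:
def _illustrate_transpose_batch():
    # Four (input, target) pairs, each a 16-token sequence
    batch = [(torch.zeros(16, dtype=torch.long), torch.zeros(16, dtype=torch.long))
             for _ in range(4)]
    src, tgt = transpose_batch(batch)
    # A default DataLoader collation would give [batch_size, seq_len];
    # here both tensors come out as [seq_len, batch_size]
    assert src.shape == (16, 4) and tgt.shape == (16, 4)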

Shuffled training data loader
@option(NLPAutoRegressionConfigs.train_loader)
def shuffled_train_loader(c: NLPAutoRegressionConfigs):
    dataset = SequentialUnBatchedDataset(text=c.text.train,
                                         dataset=c.text,
                                         seq_len=c.seq_len)
    sampler = RandomSampler(dataset, replacement=c.dataloader_shuffle_with_replacement)

    return DataLoader(dataset,
                      batch_size=c.batch_size,
                      collate_fn=transpose_batch,
                      sampler=sampler)

Shuffled validation data loader
@option(NLPAutoRegressionConfigs.valid_loader)
def shuffled_valid_loader(c: NLPAutoRegressionConfigs):
    dataset = SequentialUnBatchedDataset(text=c.text.valid,
                                         dataset=c.text,
                                         seq_len=c.seq_len)
    sampler = RandomSampler(dataset, replacement=c.dataloader_shuffle_with_replacement)

    return DataLoader(dataset,
                      batch_size=c.batch_size,
                      collate_fn=transpose_batch,
                      sampler=sampler)
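
Finally, a rough sketch of how these configurations might be wired into a labml experiment. This is not part of the original module: the GRU baseline model, the experiment name, and the override values are made-up placeholders; real experiments typically subclass NLPAutoRegressionConfigs and register their own model option.
from labml import experiment


class SimpleRNNModel(Module):
    # Made-up baseline: embedding -> GRU -> projection back to the vocabulary
    def __init__(self, n_tokens: int, d_model: int):
        super().__init__()
        self.embedding = nn.Embedding(n_tokens, d_model)
        self.rnn = nn.GRU(d_model, d_model)
        self.fc = nn.Linear(d_model, n_tokens)

    def forward(self, x: torch.Tensor):
        x = self.embedding(x)
        out, state = self.rnn(x)
        # Return a tuple so that `output, *_ = self.model(data)` in `step` works
        return self.fc(out), state


@option(NLPAutoRegressionConfigs.model)
def simple_rnn_model(c: NLPAutoRegressionConfigs):
    return SimpleRNNModel(c.n_tokens, c.d_model).to(c.device)


def main():
    experiment.create(name='nlp_autoregression_sketch')
    conf = NLPAutoRegressionConfigs()
    experiment.configs(conf, {'tokenizer': 'character',
                              'prompt_separator': '',
                              'prompt': 'It is ',
                              'model': 'simple_rnn_model',
                              'seq_len': 128,
                              'epochs': 8,
                              'batch_size': 16})
    # Register the model for checkpoint saving/loading
    experiment.add_pytorch_models({'model': conf.model})
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()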