This trains a simple transformer model for auto-regression. We try different variants of the position-wise feedforward network.

This is a simpler implementation that doesn't use the labml.configs module. We decided to write a simpler implementation to make it easier for readers who are not familiar with it.
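For reference, the gated variants below follow the GLU Variants paper (Shazeer, 2020): the FFN's first linear transformation is replaced by the element-wise product of two linear projections, one of them passed through an activation, with the bias terms dropped (which matches the False bias flags passed to FeedForward further down, under the assumption that those positional arguments are the bias switches):

$\text{FFN}_{\text{GLU}}(x, W_1, V, W_2) = (\sigma(x W_1) \otimes x V) W_2$
$\text{FFN}_{\text{Bilinear}}(x, W_1, V, W_2) = ((x W_1) \otimes x V) W_2$
$\text{FFN}_{\text{ReGLU}}(x, W_1, V, W_2) = (\max(0, x W_1) \otimes x V) W_2$
$\text{FFN}_{\text{GEGLU}}(x, W_1, V, W_2) = (\text{GELU}(x W_1) \otimes x V) W_2$
$\text{FFN}_{\text{SwiGLU}}(x, W_1, V, W_2) = (\text{Swish}_1(x W_1) \otimes x V) W_2$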
import dataclasses

import torch
from labml_helpers.module import Module
from torch import nn
from torch.utils.data import Dataset, DataLoader

from labml import experiment, lab, tracker, monit, logger
from labml.logger import Text
from labml.utils.download import download_file
from labml_nn.experiments.nlp_autoregression import transpose_batch
from labml_nn.optimizers.noam import Noam
from labml_nn.transformers import Encoder, MultiHeadAttention
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.models import EmbeddingsWithPositionalEncoding, TransformerLayer
from labml_nn.transformers.utils import subsequent_mask
class AutoregressiveModel(Module):
    def __init__(self, src_embed: Module, encoder: Encoder, generator: Module):
        super().__init__()
Token embedding module
        self.src_embed = src_embed
Transformer-based encoder
        self.encoder = encoder
Next token generation layer; this gives the logits of the next token
        self.generator = generator
This will be initialized on the first call
        self.src_mask = None

    def forward(self, src: torch.Tensor):
Create a subsequent mask, so that the transformer can only pay attention to past tokens.
        if self.src_mask is None or self.src_mask.size(0) != len(src):
            self.src_mask = subsequent_mask(len(src)).to(src.device)
Embed the tokens (src) and run them through the transformer
        res = self.encoder(self.src_embed(src), self.src_mask)
Generate logits of the next token
        return self.generator(res)
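As a side note, subsequent_mask builds the standard causal (lower-triangular) mask used above. A minimal sketch of such a mask, assuming a plain [seq_len, seq_len] boolean layout (the exact shape returned by labml_nn's subsequent_mask may differ):

import torch

def causal_mask(seq_len: int) -> torch.Tensor:
    # Illustrative sketch, not the labml_nn helper: position i may attend only to positions j <= i
    return torch.tril(torch.ones(seq_len, seq_len)).bool()

print(causal_mask(4))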
@dataclasses.dataclass
class Configs:
    d_model: int = 512
    seq_len: int = 128
    batch_size: int = 32
    n_layers: int = 6
    n_heads: int = 8
    dropout: float = 0.1
    d_ff: int = 2048
    glu_variant: str = 'GLU'
    epochs: int = 5
    grad_norm_clip: float = 0.5
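Since Configs is a plain dataclass, the defaults above can be overridden at construction time. An illustrative (hypothetical) usage:

# Illustrative only: pick the SwiGLU variant and a smaller model for a quick run
configs = Configs(glu_variant='SwiGLU', d_model=256, n_layers=4)
assert dataclasses.asdict(configs)['glu_variant'] == 'SwiGLU'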
class TinyShakespeareDataset(Dataset):
    def __init__(self, seq_len: int):
Location of the text file
        path = lab.get_data_path() / 'tiny_shakespeare.txt'
Download the file
        download_file('https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt', path)
Read the downloaded file
        with open(str(path), 'r') as f:
            text = f.read()
Extract the characters
        chars = list(set(text))
Character to id (integer) map
        self.stoi = {c: i for i, c in enumerate(chars)}
Id to character map
        self.itos = {i: c for i, c in enumerate(chars)}
Length of a training sample
        self.seq_len = seq_len
Data in the form of a tensor of ids
        self.data = self.text_to_i(text)

Transform the text into a tensor of ids
    def text_to_i(self, text: str):
        return torch.tensor([self.stoi[c] for c in text], dtype=torch.long)

    def __len__(self):
        return len(self.data) - self.seq_len - 1

Return a sample
    def __getitem__(self, idx):
        return self.data[idx:idx + self.seq_len], self.data[idx + 1:idx + self.seq_len + 1]
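To make the indexing in __getitem__ concrete, here is a tiny self-contained illustration (toy text, not part of the dataset) showing that the target is the input shifted one character ahead:

import torch

# Illustrative only: a toy corpus and vocabulary
text = 'hello '
stoi = {c: i for i, c in enumerate(sorted(set(text)))}
data = torch.tensor([stoi[c] for c in text], dtype=torch.long)

seq_len = 3
x, y = data[0:seq_len], data[1:seq_len + 1]
# x encodes 'hel' and y encodes 'ell': the model predicts the next character at every position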
class Trainer:
    def __init__(self, configs: Configs):
Get the device
        self.device = torch.device('cpu')
        if torch.cuda.is_available():
            self.device = torch.device('cuda:0')
Initialize the dataset
        self.dataset = TinyShakespeareDataset(configs.seq_len)
Initialize the dataloader
        self.dataloader = DataLoader(self.dataset,
                                     batch_size=configs.batch_size,
                                     collate_fn=transpose_batch,
                                     shuffle=True)
FFN with Gated Linear Unit
        if configs.glu_variant == 'GLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Sigmoid(), True, False, False, False)
FFN with Bilinear hidden layer
        elif configs.glu_variant == 'Bilinear':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Identity(), True, False, False, False)
FFN with ReLU gate
        elif configs.glu_variant == 'ReGLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU(), True, False, False, False)
FFN with GELU gate
        elif configs.glu_variant == 'GEGLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU(), True, False, False, False)
FFN with Swish gate where $\text{Swish}_\beta(x) = x \sigma(\beta x)$
        elif configs.glu_variant == 'SwiGLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.SiLU(), True, False, False, False)
FFN with ReLU activation
        elif configs.glu_variant == 'ReLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU())
FFN with GELU activation
        elif configs.glu_variant == 'GELU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU())
        else:
            raise ValueError(f'Unknown variant {configs.glu_variant}')
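As a rough sketch of what the gated FeedForward computes (an assumption based on the boolean arguments above selecting a gated, bias-free path; see labml_nn.transformers.feed_forward for the actual implementation):

import torch
from torch import nn

class GatedFFNSketch(nn.Module):
    # Illustrative sketch only, not the labml_nn FeedForward class
    def __init__(self, d_model: int, d_ff: int, dropout: float, activation: nn.Module):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)  # first projection
        self.v = nn.Linear(d_model, d_ff, bias=False)   # gate projection
        self.w2 = nn.Linear(d_ff, d_model, bias=False)  # output projection
        self.activation = activation
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor):
        # e.g. activation=nn.SiLU() corresponds to the SwiGLU variant
        hidden = self.activation(self.w1(x)) * self.v(x)
        return self.w2(self.dropout(hidden))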
Number of different characters
        n_chars = len(self.dataset.stoi)
Initialize the Multi-Head Attention module
        mha = MultiHeadAttention(configs.n_heads, configs.d_model, configs.dropout)
Initialize the Transformer Block
        transformer_layer = TransformerLayer(d_model=configs.d_model, self_attn=mha, src_attn=None,
                                             feed_forward=ffn, dropout_prob=configs.dropout)
Initialize the model with an embedding layer (with fixed positional encoding), a transformer encoder, and a linear layer to generate logits.
        self.model = AutoregressiveModel(EmbeddingsWithPositionalEncoding(configs.d_model, n_chars),
                                         Encoder(transformer_layer, configs.n_layers),
                                         nn.Linear(configs.d_model, n_chars))
Move the model to the current device
        self.model.to(self.device)
Initialize the Noam optimizer
        self.optimizer = Noam(self.model.parameters(), lr=1.0, warmup=2_000, d_model=configs.d_model)
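The Noam schedule, from "Attention Is All You Need", warms the learning rate up for warmup steps and then decays it with the inverse square root of the step count; up to the lr factor above it is roughly

$\eta(step) = d_{model}^{-0.5} \cdot \min\big(step^{-0.5},\; step \cdot warmup^{-1.5}\big)$

(see labml_nn.optimizers.noam for the exact implementation).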
Cross-entropy loss
        self.loss_func = nn.CrossEntropyLoss()
Number of training epochs; note that our dataset definition repeats the data seq_len times in a single epoch
        self.epochs = configs.epochs
Gradient clipping norm
        self.grad_norm_clip = configs.grad_norm_clip
Set tracker configurations
        tracker.set_scalar("loss.*", True)
    def sample(self):
Starting prompt
        prompt = 'It is'
Collect output for printing
        log = [(prompt, Text.subtle)]
Sample 25 tokens
        for i in monit.iterate('Sample', 25):
Tokenize the prompt
            data = self.dataset.text_to_i(prompt).unsqueeze(-1)
            data = data.to(self.device)
Get the model output
            output = self.model(data)
Get the model prediction (greedy)
            output = output.argmax(dim=-1).squeeze()
Add the prediction to the prompt
            prompt += self.dataset.itos[output[-1].item()]
Add the prediction for logging
            log += [(self.dataset.itos[output[-1].item()], Text.value)]
Print the sampled output
        logger.log(log)
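The loop above decodes greedily with argmax. As an alternative that this implementation does not use, one could sample from the predicted distribution with a temperature; a hypothetical helper:

import torch

def sample_next(logits: torch.Tensor, temperature: float = 1.0) -> int:
    # Illustrative alternative to greedy decoding: sample from the temperature-scaled distribution
    probs = torch.softmax(logits / temperature, dim=-1)
    return int(torch.multinomial(probs, num_samples=1).item())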
    def train(self):
Loop for the given number of epochs
        for _ in monit.loop(self.epochs):
Iterate over the minibatches
            for i, batch in monit.enum('Train', self.dataloader):
Move data to the device
                data, target = batch[0].to(self.device), batch[1].to(self.device)
Set tracker step, as the number of characters trained on
                tracker.add_global_step(data.shape[0] * data.shape[1])
Set model state to training
                self.model.train()
Run the model
                output = self.model(data)
Calculate loss
                loss = self.loss_func(output.view(-1, output.shape[-1]), target.view(-1))
Log the loss
                tracker.add("loss.train", loss)
Calculate gradients
                loss.backward()
Clip gradients
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
Take optimizer step
                self.optimizer.step()
Log the model parameters and gradients
                if (i + 1) % 100 == 0:
                    tracker.add('model', self.model)
Clear the gradients
                self.optimizer.zero_grad()
Generate a sample
                if (i + 1) % 100 == 0:
                    self.model.eval()
                    with torch.no_grad():
                        self.sample()
Save the tracked metrics
                if (i + 1) % 10 == 0:
                    tracker.save()
Save the model
            experiment.save_checkpoint()
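To make the loss reshape in the loop above concrete, here is a small self-contained sketch (the [seq_len, batch_size] shape order is an assumption about what transpose_batch produces):

import torch
from torch import nn

# Illustrative shapes only
seq_len, batch_size, n_chars = 4, 2, 10
output = torch.randn(seq_len, batch_size, n_chars)          # model logits
target = torch.randint(0, n_chars, (seq_len, batch_size))   # next-character ids

loss_func = nn.CrossEntropyLoss()
# Flatten the sequence and batch dimensions so every position becomes one classification example
loss = loss_func(output.view(-1, n_chars), target.view(-1))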
def main():
Create the experiment
    experiment.create(name="glu_variants")
Create the configs
    configs = Configs()
Load the configurations
    experiment.configs(dataclasses.asdict(configs))
Create the trainer
    trainer = Trainer(configs)
Set models for saving and loading
    experiment.add_pytorch_models({'model': trainer.model})
Start the experiment
    with experiment.start():
Train the model
        trainer.train()


if __name__ == '__main__':
    main()