This trains a simple transformer model for auto-regression. We try different variants of the position-wise feed-forward network (FFN).
This is a simpler implementation that doesn't use the labml.configs
module. We wrote it this way to make it easier for readers who are not familiar with that module.
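For reference, these are the position-wise FFN variants we compare, as defined in the GLU Variants Improve Transformer paper (bias terms omitted; $\otimes$ is element-wise multiplication and $\sigma$ is the sigmoid function):

$$\begin{aligned}
\text{FFN}_{\text{ReLU}}(x) &= \max(0, x W_1)\, W_2 \\
\text{FFN}_{\text{GELU}}(x) &= \text{GELU}(x W_1)\, W_2 \\
\text{FFN}_{\text{GLU}}(x) &= (\sigma(x W_1) \otimes x V)\, W_2 \\
\text{FFN}_{\text{Bilinear}}(x) &= (x W_1 \otimes x V)\, W_2 \\
\text{FFN}_{\text{ReGLU}}(x) &= (\max(0, x W_1) \otimes x V)\, W_2 \\
\text{FFN}_{\text{GEGLU}}(x) &= (\text{GELU}(x W_1) \otimes x V)\, W_2 \\
\text{FFN}_{\text{SwiGLU}}(x) &= (\text{Swish}(x W_1) \otimes x V)\, W_2
\end{aligned}$$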
import dataclasses

import torch
from labml import experiment, lab, tracker, monit, logger
from labml.logger import Text
from labml.utils.download import download_file
from labml_nn.experiments.nlp_autoregression import transpose_batch
from labml_nn.optimizers.noam import Noam
from labml_nn.transformers import Encoder, MultiHeadAttention
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.models import EmbeddingsWithPositionalEncoding, TransformerLayer
from labml_nn.transformers.utils import subsequent_mask
from torch import nn
from torch.utils.data import Dataset, DataLoader
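Auto-regressive model: wraps a token embedding module, a transformer encoder, and a generator layer that produces the logits of the next token.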
class AutoregressiveModel(nn.Module):
    def __init__(self, src_embed: nn.Module, encoder: Encoder, generator: nn.Module):
        super().__init__()
Token embedding module
        self.src_embed = src_embed
Transformer based encoder
        self.encoder = encoder
Next token generation layer; this gives the logits of the next token
        self.generator = generator
This will be initialized on the first call
        self.src_mask = None
    def forward(self, src: torch.Tensor):
Create subsequent mask, so that the transformer can only pay attention to past tokens.
        if self.src_mask is None or self.src_mask.size(0) != len(src):
            self.src_mask = subsequent_mask(len(src)).to(src.device)
Embed the tokens (src) and run them through the transformer encoder
        res = self.encoder(self.src_embed(src), self.src_mask)
Generate logits of the next token
        return self.generator(res)
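Configurations for the model and the trainer, kept in a plain dataclass instead of labml.configs.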
@dataclasses.dataclass
class Configs:
    d_model: int = 512
    seq_len: int = 128
    batch_size: int = 32
    n_layers: int = 6
    n_heads: int = 8
    dropout: float = 0.1
    d_ff: int = 2048
    glu_variant: str = 'GLU'
    epochs: int = 5
    grad_norm_clip: float = 0.5
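Tiny Shakespeare dataset: downloads the text, builds character-to-id maps, and returns (input, target) pairs of length seq_len, where the target is the input shifted by one character.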
class TinyShakespeareDataset(Dataset):
    def __init__(self, seq_len: int):
Location of the text file
        path = lab.get_data_path() / 'tiny_shakespeare.txt'
Download the file
        download_file('https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt', path)
Read the downloaded file
        with open(str(path), 'r') as f:
            text = f.read()
Extract the characters
        chars = list(set(text))
Character to id (integer) map
        self.stoi = {c: i for i, c in enumerate(chars)}
Id to character map
        self.itos = {i: c for i, c in enumerate(chars)}
Length of a training sample
        self.seq_len = seq_len
Data in the form of a tensor of ids
        self.data = self.text_to_i(text)
Transform the text into a tensor of ids
    def text_to_i(self, text: str):
        return torch.tensor([self.stoi[c] for c in text], dtype=torch.long)
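Number of samples in the dataset; a sample can start at any character, so the data effectively gets repeated seq_len times in a single epoch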
    def __len__(self):
        return len(self.data) - self.seq_len - 1
Return a sample
    def __getitem__(self, idx):
        return self.data[idx:idx + self.seq_len], self.data[idx + 1:idx + self.seq_len + 1]
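Trainer: builds the dataset, model and optimizer from the configurations, and runs the training loop.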
class Trainer:
    def __init__(self, configs: Configs):
Get the device
        self.device = torch.device('cpu')
        if torch.cuda.is_available():
            self.device = torch.device('cuda:0')
Initialize the dataset
        self.dataset = TinyShakespeareDataset(configs.seq_len)
Initialize the dataloader
        self.dataloader = DataLoader(self.dataset,
                                     batch_size=configs.batch_size,
                                     collate_fn=transpose_batch,
                                     shuffle=True)
FFN with Gated Linear Unit
        if configs.glu_variant == 'GLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Sigmoid(), True, False, False, False)
FFN with Bilinear hidden layer
        elif configs.glu_variant == 'Bilinear':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Identity(), True, False, False, False)
FFN with ReLU gate
        elif configs.glu_variant == 'ReGLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU(), True, False, False, False)
FFN with GELU gate
        elif configs.glu_variant == 'GEGLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU(), True, False, False, False)
FFN with Swish gate, where $\text{Swish}(x) = x \sigma(x)$ (SiLU)
        elif configs.glu_variant == 'SwiGLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.SiLU(), True, False, False, False)
FFN with ReLU activation
        elif configs.glu_variant == 'ReLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU())
FFN with GELU activation
        elif configs.glu_variant == 'GELU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU())
        else:
            raise ValueError(f'Unknown variant {configs.glu_variant}')
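In the gated variants above, the four trailing boolean arguments to FeedForward switch it into its gated form and drop the bias terms (presumably the is_gated, bias1, bias2 and bias_gate parameters of labml_nn's FeedForward), matching the bias-free formulations listed at the top.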
Number of different characters
        n_chars = len(self.dataset.stoi)
Initialize Multi-Head Attention module
        mha = MultiHeadAttention(configs.n_heads, configs.d_model, configs.dropout)
Initialize the Transformer Block
        transformer_layer = TransformerLayer(d_model=configs.d_model, self_attn=mha, src_attn=None,
                                             feed_forward=ffn, dropout_prob=configs.dropout)
Initialize the model with an embedding layer (with fixed positional encoding), a transformer encoder, and a linear layer to generate logits.
        self.model = AutoregressiveModel(EmbeddingsWithPositionalEncoding(configs.d_model, n_chars),
                                         Encoder(transformer_layer, configs.n_layers),
                                         nn.Linear(configs.d_model, n_chars))
Move the model to the current device
        self.model.to(self.device)
Initialize Noam optimizer
        self.optimizer = Noam(self.model.parameters(), lr=1.0, warmup=2_000, d_model=configs.d_model)
Cross-entropy loss
        self.loss_func = nn.CrossEntropyLoss()
Number of training epochs; note that our dataset definition repeats the data seq_len times in a single epoch
        self.epochs = configs.epochs
Gradient clipping norm
        self.grad_norm_clip = configs.grad_norm_clip
Set tracker configurations
        tracker.set_scalar("loss.*", True)
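Sampling function to periodically generate and log a short text sample while training.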
    def sample(self):
Starting prompt
        prompt = 'It is'
Collect output for printing
        log = [(prompt, Text.subtle)]
Sample 25 tokens
        for i in monit.iterate('Sample', 25):
Tokenize the prompt
            data = self.dataset.text_to_i(prompt).unsqueeze(-1)
            data = data.to(self.device)
Get the model output
            output = self.model(data)
Get the model prediction (greedy)
            output = output.argmax(dim=-1).squeeze()
Add the prediction to the prompt
            prompt += self.dataset.itos[output[-1].item()]
Add the prediction for logging
            log += [(self.dataset.itos[output[-1].item()], Text.value)]
Print the sampled output
        logger.log(log)
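Training loop: train the model for the configured number of epochs.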
    def train(self):
Loop for the given number of epochs
        for _ in monit.loop(self.epochs):
Iterate over the minibatches
            for i, batch in monit.enum('Train', self.dataloader):
Move data to the device
                data, target = batch[0].to(self.device), batch[1].to(self.device)
Set the tracker step to the number of characters trained on
                tracker.add_global_step(data.shape[0] * data.shape[1])
Set model state to training
                self.model.train()
Evaluate the model
                output = self.model(data)
Calculate loss
                loss = self.loss_func(output.view(-1, output.shape[-1]), target.view(-1))
Log the loss
                tracker.add("loss.train", loss)
Calculate gradients
                loss.backward()
Clip gradients
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
Take optimizer step
                self.optimizer.step()
Log the model parameters and gradients
                if (i + 1) % 100 == 0:
                    tracker.add('model', self.model)
Clear the gradients
                self.optimizer.zero_grad()
Generate a sample
                if (i + 1) % 100 == 0:
                    self.model.eval()
                    with torch.no_grad():
                        self.sample()
Save the tracked metrics
                if (i + 1) % 10 == 0:
                    tracker.save()
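Create and run the experiment.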
def main():
Create experiment
    experiment.create(name="glu_variants")
Create configs
    configs = Configs()
Load configurations
    experiment.configs(dataclasses.asdict(configs))
Create trainer
    trainer = Trainer(configs)
Set models for training and loading
    experiment.add_pytorch_models({'model': trainer.model})
Start the experiment
    with experiment.start():
Train the model
        trainer.train()
if __name__ == '__main__':
    main()