This trains a simple transformer model for auto-regression, and tries different variants for the position-wise feedforward network (FFN).
This is a simpler implementation that doesn't use the labml.configs module; we wrote it this way to make it easier to follow for readers who are not familiar with that module.
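All of the gated variants tried below share one computation: the FFN hidden layer is an element-wise product of a linear projection with an (optionally activated) gate. As a point of reference, here is a minimal, self-contained sketch of that computation; GatedFFN is a hypothetical class for illustration only, not the labml_nn FeedForward class used in the implementation below.

import torch
from torch import nn


class GatedFFN(nn.Module):
    # Computes FFN(x) = (activation(x W1) * (x V)) W2.
    # nn.Sigmoid gives GLU, nn.Identity gives Bilinear, nn.ReLU gives ReGLU,
    # nn.GELU gives GEGLU, and nn.SiLU gives SwiGLU.
    def __init__(self, d_model: int, d_ff: int, activation: nn.Module):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)  # projection that gets gated
        self.v = nn.Linear(d_model, d_ff, bias=False)   # linear branch
        self.w2 = nn.Linear(d_ff, d_model, bias=False)  # output projection
        self.activation = activation

    def forward(self, x: torch.Tensor):
        return self.w2(self.activation(self.w1(x)) * self.v(x))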
import dataclasses

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

from labml import experiment, lab, tracker, monit, logger
from labml.logger import Text
from labml.utils.download import download_file
from labml_nn.experiments.nlp_autoregression import transpose_batch
from labml_nn.optimizers.noam import Noam
from labml_nn.transformers import Encoder, MultiHeadAttention
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.models import EmbeddingsWithPositionalEncoding, TransformerLayer
from labml_nn.transformers.utils import subsequent_mask


class AutoregressiveModel(nn.Module):
    def __init__(self, src_embed: nn.Module, encoder: Encoder, generator: nn.Module):
        super().__init__()
        # Token embedding module
        self.src_embed = src_embed
        # Transformer-based encoder
        self.encoder = encoder
        # Next-token generation layer; this gives the logits of the next token
        self.generator = generator
        # The mask will be initialized on the first call
        self.src_mask = None

    def __call__(self, src: torch.Tensor):
        # Create the subsequent mask, so that the transformer can only pay attention to past tokens
        if self.src_mask is None or self.src_mask.size(0) != len(src):
            self.src_mask = subsequent_mask(len(src)).to(src.device)
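        # `subsequent_mask(n)` masks out future positions: position `i` may only
        # attend to positions `0..i`. Conceptually, the allowed attention
        # pattern for `n = 3` is:
        #     [[1, 0, 0],
        #      [1, 1, 0],
        #      [1, 1, 1]]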
        # Embed the tokens (`src`) and run them through the transformer
        res = self.encoder(self.src_embed(src), self.src_mask)
        # Generate the logits of the next token
        return self.generator(res)


@dataclasses.dataclass
class Configs:
    d_model: int = 512
    seq_len: int = 128
    batch_size: int = 32
    n_layers: int = 6
    n_heads: int = 8
    dropout: float = 0.1
    d_ff: int = 2048
    glu_variant: str = 'GLU'
    epochs: int = 5
    grad_norm_clip: float = 0.5


class TinyShakespeareDataset(Dataset):
    def __init__(self, seq_len: int):
        # Location of the text file
        path = lab.get_data_path() / 'tiny_shakespeare.txt'
        # Download the file
        download_file('https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt', path)
        # Read the downloaded file
        with open(str(path), 'r') as f:
            text = f.read()

        # Extract the characters
        chars = list(set(text))
        # Character-to-id (integer) map
        self.stoi = {c: i for i, c in enumerate(chars)}
        # Id-to-character map
        self.itos = {i: c for i, c in enumerate(chars)}
        # Length of a training sample
        self.seq_len = seq_len
        # Data in the form of a tensor of ids
        self.data = self.text_to_i(text)

    def text_to_i(self, text: str):
        # Transform the text into a tensor of ids
        return torch.tensor([self.stoi[c] for c in text], dtype=torch.long)

    def __len__(self):
        # Number of samples; every position in the data (except the tail) starts a sample
        return len(self.data) - self.seq_len - 1

    def __getitem__(self, idx):
        # Return a sample
        return self.data[idx:idx + self.seq_len], self.data[idx + 1:idx + self.seq_len + 1]
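# For example, with `seq_len = 4` and the text "hello!", `__getitem__(0)` would
# return the ids of "hell" as the input and the ids of "ello" as the target:
# the target is the input sequence shifted one character ahead.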
class Trainer:
    def __init__(self, configs: Configs):
        # Get the device
        self.device = torch.device('cpu')
        if torch.cuda.is_available():
            self.device = torch.device('cuda:0')
        # Initialize the dataset
        self.dataset = TinyShakespeareDataset(configs.seq_len)
        # Initialize the dataloader
        self.dataloader = DataLoader(self.dataset,
                                     batch_size=configs.batch_size,
                                     collate_fn=transpose_batch,
                                     shuffle=True)
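        # The four booleans passed to `FeedForward` in the gated branches below
        # correspond, as far as I read the `labml_nn` `FeedForward` signature, to
        # `is_gated`, `bias1`, `bias2` and `bias_gate`: a gated hidden layer with
        # all biases disabled.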
        # FFN with Gated Linear Unit
        if configs.glu_variant == 'GLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Sigmoid(), True, False, False, False)
        # FFN with Bilinear hidden layer
        elif configs.glu_variant == 'Bilinear':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Identity(), True, False, False, False)
        # FFN with ReLU gate
        elif configs.glu_variant == 'ReGLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU(), True, False, False, False)
        # FFN with GELU gate
        elif configs.glu_variant == 'GEGLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU(), True, False, False, False)
        # FFN with Swish gate, where $\text{Swish}_\beta(x) = x \sigma(\beta x)$
        elif configs.glu_variant == 'SwiGLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.SiLU(), True, False, False, False)
        # FFN with ReLU activation
        elif configs.glu_variant == 'ReLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU())
        # FFN with GELU activation
        elif configs.glu_variant == 'GELU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU())
        else:
            raise ValueError(f'Unknown variant {configs.glu_variant}')
        # Number of different characters
        n_chars = len(self.dataset.stoi)

        # Initialize the Multi-Head Attention module
        mha = MultiHeadAttention(configs.n_heads, configs.d_model, configs.dropout)
        # Initialize the transformer block
        transformer_layer = TransformerLayer(d_model=configs.d_model, self_attn=mha, src_attn=None,
                                             feed_forward=ffn, dropout_prob=configs.dropout)
        # Initialize the model with an embedding layer (with fixed positional encoding),
        # a transformer encoder, and a linear layer to generate logits
        self.model = AutoregressiveModel(EmbeddingsWithPositionalEncoding(configs.d_model, n_chars),
                                         Encoder(transformer_layer, configs.n_layers),
                                         nn.Linear(configs.d_model, n_chars))
        # Move the model to the current device
        self.model.to(self.device)

        # Initialize the Noam optimizer
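        # (The Noam schedule from "Attention Is All You Need" warms the learning
        # rate up and then decays it:
        # lr = d_model^{-1/2} * min(step^{-1/2}, step * warmup^{-3/2}),
        # scaled here by the base `lr` of 1.0.)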
        self.optimizer = Noam(self.model.parameters(), lr=1.0, warmup=2_000, d_model=configs.d_model)
        # Cross-entropy loss
        self.loss_func = nn.CrossEntropyLoss()
        # Number of training epochs;
        # note that our dataset definition repeats the data `seq_len` times in a single epoch
        self.epochs = configs.epochs
        # Gradient clipping norm
        self.grad_norm_clip = configs.grad_norm_clip

        # Set tracker configurations
        tracker.set_scalar("loss.*", True)

    def sample(self):
        # Starting prompt
        prompt = 'It is'
        # Collect output for printing
        log = [(prompt, Text.subtle)]
        # Sample 25 tokens
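        # Note: each step below re-runs the whole (growing) prompt through the
        # model; there is no key/value caching, which is fine for 25 characters.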
        for i in monit.iterate('Sample', 25):
            # Tokenize the prompt
            data = self.dataset.text_to_i(prompt).unsqueeze(-1)
            data = data.to(self.device)
            # Get the model output
            output = self.model(data)
            # Get the model prediction (greedy)
            output = output.argmax(dim=-1).squeeze()
            # Add the prediction to the prompt
            prompt += self.dataset.itos[output[-1].item()]
            # Add the prediction for logging
            log += [(self.dataset.itos[output[-1].item()], Text.value)]

        # Print the sampled output
        logger.log(log)

    def train(self):
        # Loop for the given number of epochs
        for _ in monit.loop(self.epochs):
            # Iterate over the mini-batches
            for i, batch in monit.enum('Train', self.dataloader):
                # Move the data to the device
                data, target = batch[0].to(self.device), batch[1].to(self.device)

                # Set the tracker step to the number of characters trained on
                tracker.add_global_step(data.shape[0] * data.shape[1])

                # Set model state to training
                self.model.train()
                # Evaluate the model
                output = self.model(data)
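                # `output` holds a logit vector per position (seq-first, since
                # `transpose_batch` is used); flatten positions and batch into
                # one dimension so cross-entropy compares `[N, n_chars]` logits
                # with `[N]` targets.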
                # Calculate the loss
                loss = self.loss_func(output.view(-1, output.shape[-1]), target.view(-1))
                # Log the loss
                tracker.add("loss.train", loss)

                # Calculate gradients
                loss.backward()
                # Clip gradients
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
                # Take an optimizer step
                self.optimizer.step()
                # Log the model parameters and gradients every 100 iterations
                if (i + 1) % 100 == 0:
                    tracker.add('model', self.model)
                # Clear the gradients
                self.optimizer.zero_grad()

                # Generate a sample every 100 iterations
                if (i + 1) % 100 == 0:
                    self.model.eval()
                    with torch.no_grad():
                        self.sample()

                # Save the tracked metrics every 10 iterations
                if (i + 1) % 10 == 0:
                    tracker.save()

            # Save the model checkpoint at the end of each epoch
            experiment.save_checkpoint()


def main():
    # Create the experiment
    experiment.create(name="glu_variants")
    # Create configs
    configs = Configs()
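    # To train a different variant, change it here before the configurations
    # are loaded, e.g. `configs.glu_variant = 'SwiGLU'`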
    # Load configurations
    experiment.configs(dataclasses.asdict(configs))

    # Create the trainer
    trainer = Trainer(configs)
    # Set models for saving and loading
    experiment.add_pytorch_models({'model': trainer.model})

    # Start the experiment
    with experiment.start():
        # Train the model
        trainer.train()
if __name__ == '__main__':
    main()