This trains a simple transformer model for auto-regression. We try different variants for the position-wise feedforward network.
This is a simpler implementation that doesn't use the labml.configs module. We wrote it this way to make the code easier to follow for readers who are not familiar with labml.configs.
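For reference, the variants differ only in how the position-wise FFN's hidden layer is computed. Below is a minimal sketch of a gated FFN (an illustration of the idea, not the repository's FeedForward class):

import torch
from torch import nn

class GatedFFNSketch(nn.Module):
    # Sketch only: computes (activation(x W1) * (x V)) W2, the gated FFN from
    # "GLU Variants Improve Transformer". Sigmoid -> GLU, Identity -> Bilinear,
    # ReLU -> ReGLU, GELU -> GEGLU, SiLU -> SwiGLU.
    def __init__(self, d_model: int, d_ff: int, activation: nn.Module):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)  # projection passed through the gate activation
        self.v = nn.Linear(d_model, d_ff, bias=False)   # projection that gets gated
        self.w2 = nn.Linear(d_ff, d_model, bias=False)  # projection back to d_model
        self.activation = activation

    def forward(self, x: torch.Tensor):
        return self.w2(self.activation(self.w1(x)) * self.v(x))

The plain 'ReLU' and 'GELU' variants below use the ordinary two-layer FFN without a gate.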
import dataclasses

import torch
from labml import experiment, lab, tracker, monit, logger
from labml.logger import Text
from labml.utils.download import download_file
from labml_nn.experiments.nlp_autoregression import transpose_batch
from labml_nn.optimizers.noam import Noam
from labml_nn.transformers import Encoder, MultiHeadAttention
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.models import EmbeddingsWithPositionalEncoding, TransformerLayer
from labml_nn.transformers.utils import subsequent_mask
from torch import nn
from torch.utils.data import Dataset, DataLoader


class AutoregressiveModel(nn.Module):
    def __init__(self, src_embed: nn.Module, encoder: Encoder, generator: nn.Module):
        super().__init__()
        # Token embedding module
        self.src_embed = src_embed
        # Transformer based encoder
        self.encoder = encoder
        # Next token generation layer; this gives logits of the next token
        self.generator = generator
        # This will be initialized on the first call
        self.src_mask = None

    def forward(self, src: torch.Tensor):
        # Create the subsequent mask, so that the transformer can only pay attention to past tokens
        if self.src_mask is None or self.src_mask.size(0) != len(src):
            self.src_mask = subsequent_mask(len(src)).to(src.device)
        # Embed the tokens (`src`) and run them through the transformer
        res = self.encoder(self.src_embed(src), self.src_mask)
        # Generate logits of the next token
        return self.generator(res)
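For intuition, `subsequent_mask(seq_len)` builds a causal mask: position i may only attend to positions 0…i. Ignoring the exact shape and dtype the library returns (an assumption here), an equivalent mask looks like:

# Sketch only: a lower-triangular boolean mask; True means attention is allowed.
seq_len = 5
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
# causal[i, j] is True exactly when j <= i, so future tokens are masked out.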
@dataclasses.dataclass
class Configs:
    d_model: int = 512
    seq_len: int = 128
    batch_size: int = 32
    n_layers: int = 6
    n_heads: int = 8
    dropout: float = 0.1
    d_ff: int = 2048
    glu_variant: str = 'GLU'
    epochs: int = 5
    grad_norm_clip: float = 0.5
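Since `Configs` is a plain dataclass, switching variants is just a constructor argument (an illustrative call, not from the original code):

configs = Configs(glu_variant='SwiGLU')  # any of 'GLU', 'Bilinear', 'ReGLU', 'GEGLU', 'SwiGLU', 'ReLU', 'GELU'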
class TinyShakespeareDataset(Dataset):
    def __init__(self, seq_len: int):
        # Location of the text file
        path = lab.get_data_path() / 'tiny_shakespeare.txt'
        # Download the file
        download_file('https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt', path)
        # Read the downloaded file
        with open(str(path), 'r') as f:
            text = f.read()

        # Extract the characters
        chars = list(set(text))
        # Character to id (integer) map
        self.stoi = {c: i for i, c in enumerate(chars)}
        # Id to character map
        self.itos = {i: c for i, c in enumerate(chars)}
        # Length of a training sample
        self.seq_len = seq_len
        # Data in the form of a tensor of ids
        self.data = self.text_to_i(text)

    def text_to_i(self, text: str):
        # Transform the text into a tensor of ids
        return torch.tensor([self.stoi[c] for c in text], dtype=torch.long)

    def __len__(self):
        return len(self.data) - self.seq_len - 1

    def __getitem__(self, idx):
        # Return a sample
        return self.data[idx:idx + self.seq_len], self.data[idx + 1:idx + self.seq_len + 1]
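Each sample is a `seq_len`-character window paired with the same window shifted right by one character, so every position has a next-character target; because consecutive indices overlap, one epoch effectively passes over the text about `seq_len` times (as the trainer notes below). A tiny illustration with made-up ids:

data = torch.tensor([10, 11, 12, 13, 14])  # pretend character ids
seq_len, idx = 3, 0
inp, tgt = data[idx:idx + seq_len], data[idx + 1:idx + seq_len + 1]
# inp = tensor([10, 11, 12]), tgt = tensor([11, 12, 13])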
class Trainer:
    def __init__(self, configs: Configs):
        # Get the device
        self.device = torch.device('cpu')
        if torch.cuda.is_available():
            self.device = torch.device('cuda:0')
        # Initialize the dataset
        self.dataset = TinyShakespeareDataset(configs.seq_len)
        # Initialize the dataloader
        self.dataloader = DataLoader(self.dataset,
                                     batch_size=configs.batch_size,
                                     collate_fn=transpose_batch,
                                     shuffle=True)
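        # Note: `transpose_batch` collates samples so batches come out sequence-first,
        # i.e. roughly `[seq_len, batch_size]`, which is the layout the model's forward
        # pass assumes (it uses `len(src)` as the sequence length).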
        # FFN with Gated Linear Unit
        if configs.glu_variant == 'GLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Sigmoid(), True, False, False, False)
        # FFN with Bilinear hidden layer
        elif configs.glu_variant == 'Bilinear':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Identity(), True, False, False, False)
        # FFN with ReLU gate
        elif configs.glu_variant == 'ReGLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU(), True, False, False, False)
        # FFN with GELU gate
        elif configs.glu_variant == 'GEGLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU(), True, False, False, False)
        # FFN with Swish (SiLU) gate
        elif configs.glu_variant == 'SwiGLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.SiLU(), True, False, False, False)
        # FFN with ReLU activation (no gate)
        elif configs.glu_variant == 'ReLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU())
        # FFN with GELU activation (no gate)
        elif configs.glu_variant == 'GELU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU())
        else:
            raise ValueError(f'Unknown variant {configs.glu_variant}')
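        # Note on the gated branches above: the trailing positional arguments `True, False, False, False`
        # are assumed to select a gated FFN with its bias terms disabled (roughly `is_gated=True`,
        # no biases); see `labml_nn.transformers.feed_forward.FeedForward` for the exact parameter order.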
        # Number of different characters
        n_chars = len(self.dataset.stoi)

        # Initialize the Multi-Head Attention module
        mha = MultiHeadAttention(configs.n_heads, configs.d_model, configs.dropout)
        # Initialize the Transformer Block
        transformer_layer = TransformerLayer(d_model=configs.d_model, self_attn=mha, src_attn=None,
                                             feed_forward=ffn, dropout_prob=configs.dropout)
        # Initialize the model with an embedding layer (with fixed positional encoding),
        # a transformer encoder and a linear layer to generate logits
        self.model = AutoregressiveModel(EmbeddingsWithPositionalEncoding(configs.d_model, n_chars),
                                         Encoder(transformer_layer, configs.n_layers),
                                         nn.Linear(configs.d_model, n_chars))

        # Move the model to the current device
        self.model.to(self.device)

        # Initialize the Noam optimizer
        self.optimizer = Noam(self.model.parameters(), lr=1.0, warmup=2_000, d_model=configs.d_model)
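        # The Noam schedule warms the learning rate up for `warmup` steps and then decays it,
        # roughly lr = factor * d_model^(-0.5) * min(step^(-0.5), step * warmup^(-1.5));
        # `lr=1.0` here is assumed to be that multiplicative factor.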
        # Cross-entropy loss
        self.loss_func = nn.CrossEntropyLoss()
        # Number of training epochs; note that our dataset definition repeats
        # the data `seq_len` times in a single epoch
        self.epochs = configs.epochs
        # Gradient clipping norm
        self.grad_norm_clip = configs.grad_norm_clip

        # Set tracker configurations
        tracker.set_scalar("loss.*", True)
    def sample(self):
        # Starting prompt
        prompt = 'It is'
        # Collect output for printing
        log = [(prompt, Text.subtle)]
        # Sample 25 tokens
        for i in monit.iterate('Sample', 25):
            # Tokenize the prompt
            data = self.dataset.text_to_i(prompt).unsqueeze(-1)
            data = data.to(self.device)
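            # `data` has shape `[len(prompt), 1]` (sequence-first, a batch of one),
            # so the model output below is `[len(prompt), 1, n_chars]`; the last position
            # holds the logits for the character that follows the prompt.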
            # Get the model output
            output = self.model(data)
            # Get the model prediction (greedy)
            output = output.argmax(dim=-1).squeeze()
            # Add the prediction to the prompt
            prompt += self.dataset.itos[output[-1].item()]
            # Add the prediction for logging
            log += [(self.dataset.itos[output[-1].item()], Text.value)]

        # Print the sampled output
        logger.log(log)
    def train(self):
        # Loop for the given number of epochs
        for _ in monit.loop(self.epochs):
            # Iterate over the minibatches
            for i, batch in monit.enum('Train', self.dataloader):
                # Move the data to the device
                data, target = batch[0].to(self.device), batch[1].to(self.device)

                # Set the tracker step to the number of characters trained on
                tracker.add_global_step(data.shape[0] * data.shape[1])

                # Set the model to training mode
                self.model.train()
                # Run the model
                output = self.model(data)
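                # `output` is `[seq_len, batch_size, n_chars]`; flattening it to `[N, n_chars]`
                # and the targets to `[N]` matches the input form `nn.CrossEntropyLoss` expects.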
                # Calculate the loss
                loss = self.loss_func(output.view(-1, output.shape[-1]), target.view(-1))
                # Log the loss
                tracker.add("loss.train", loss)

                # Calculate gradients
                loss.backward()
                # Clip gradients
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
                # Take an optimizer step
                self.optimizer.step()
                # Log the model parameters and gradients every 100 iterations
                if (i + 1) % 100 == 0:
                    tracker.add('model', self.model)
                # Clear the gradients
                self.optimizer.zero_grad()

                # Generate a sample every 100 iterations
                if (i + 1) % 100 == 0:
                    self.model.eval()
                    with torch.no_grad():
                        self.sample()

                # Save the tracked metrics every 10 iterations
                if (i + 1) % 10 == 0:
                    tracker.save()
def main():
    # Create the experiment
    experiment.create(name="glu_variants")
    # Create the configs
    configs = Configs()
    # Load the configurations
    experiment.configs(dataclasses.asdict(configs))

    # Create the trainer
    trainer = Trainer(configs)
    # Register the model for saving and loading
    experiment.add_pytorch_models({'model': trainer.model})

    # Start the experiment
    with experiment.start():
        # Train the model
        trainer.train()


if __name__ == '__main__':
    main()