Gated Linear Units and Variants

This trains a simple transformer model for auto-regression. We try different variants of the position-wise feed-forward network (FFN).
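
For reference, these are the feed-forward variants compared here, as defined in the GLU Variants paper (bias terms omitted, $\otimes$ is element-wise multiplication, and $\sigma$ is the sigmoid function):

$$\begin{aligned}
\text{FFN}(x) &= \max(0, xW_1)\,W_2 \\
\text{FFN}_{\text{GELU}}(x) &= \text{GELU}(xW_1)\,W_2 \\
\text{FFN}_{\text{GLU}}(x) &= (\sigma(xW_1) \otimes xV)\,W_2 \\
\text{FFN}_{\text{Bilinear}}(x) &= (xW_1 \otimes xV)\,W_2 \\
\text{FFN}_{\text{ReGLU}}(x) &= (\max(0, xW_1) \otimes xV)\,W_2 \\
\text{FFN}_{\text{GEGLU}}(x) &= (\text{GELU}(xW_1) \otimes xV)\,W_2 \\
\text{FFN}_{\text{SwiGLU}}(x) &= (\text{Swish}(xW_1) \otimes xV)\,W_2
\end{aligned}$$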

This is a simpler implementation that doesn't use the labml.configs module. We wrote it this way to make it easier to follow for readers who aren't familiar with that module.


import dataclasses

import torch
from labml import experiment, lab, tracker, monit, logger
from labml.logger import Text
from labml.utils.download import download_file
from labml_nn.experiments.nlp_autoregression import transpose_batch
from labml_nn.optimizers.noam import Noam
from labml_nn.transformers import Encoder, MultiHeadAttention
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.models import EmbeddingsWithPositionalEncoding, TransformerLayer
from labml_nn.transformers.utils import subsequent_mask
from torch import nn
from torch.utils.data import Dataset, DataLoader

Auto-regressive model

class AutoregressiveModel(nn.Module):
    def __init__(self, src_embed: nn.Module, encoder: Encoder, generator: nn.Module):
        super().__init__()

Token embedding module

        self.src_embed = src_embed

Transformer based encoder

        self.encoder = encoder

Next token generation layer; this gives the logits of the next token

        self.generator = generator

This will be initialized on the first call

        self.src_mask = None

    def forward(self, src: torch.Tensor):

Create subsequent mask, so that the transformer can only pay attention to past tokens.

        if self.src_mask is None or self.src_mask.size(0) != len(src):
            self.src_mask = subsequent_mask(len(src)).to(src.device)

Embed the tokens (src) and run them through the transformer encoder

        res = self.encoder(self.src_embed(src), self.src_mask)

Generate logits of the next token

        return self.generator(res)
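
As a quick illustration of the subsequent mask used in forward above, here is a minimal causal-mask sketch; the actual subsequent_mask helper in labml_nn may differ in shape or dtype, so treat this as an assumption:

```python
import torch

def causal_mask(seq_len: int) -> torch.Tensor:
    # Lower-triangular boolean matrix: position i may only attend to positions j <= i
    return torch.tril(torch.ones(seq_len, seq_len)).bool()

print(causal_mask(4))
# tensor([[ True, False, False, False],
#         [ True,  True, False, False],
#         [ True,  True,  True, False],
#         [ True,  True,  True,  True]])
```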

Configurations

@dataclasses.dataclass
class Configs:
    d_model: int = 512
    seq_len: int = 128
    batch_size: int = 32
    n_layers: int = 6
    n_heads: int = 8
    dropout: float = 0.1
    d_ff: int = 2048
    glu_variant: str = 'GLU'
    epochs: int = 5
    grad_norm_clip: float = 0.5
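
Since Configs is a plain dataclass, hyper-parameters can be overridden with ordinary keyword arguments; for example (a hypothetical smaller run):

```python
# Hypothetical example: a smaller, quicker SwiGLU run
configs = Configs(glu_variant='SwiGLU', d_model=256, n_layers=4, epochs=1)
```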

Tiny Shakespeare Dataset

class TinyShakespeareDataset(Dataset):
    def __init__(self, seq_len: int):

Location of the text file

        path = lab.get_data_path() / 'tiny_shakespeare.txt'

Download the file

        download_file('https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt', path)

Read the downloaded file

        with open(str(path), 'r') as f:
            text = f.read()

Extract the characters

        chars = list(set(text))

Character to id (integer) map

        self.stoi = {c: i for i, c in enumerate(chars)}

Id to character map

        self.itos = {i: c for i, c in enumerate(chars)}

Length of a training sample

        self.seq_len = seq_len

Data in the form of a tensor of ids

        self.data = self.text_to_i(text)

Transform the text into a tensor of ids

    def text_to_i(self, text: str):
        return torch.tensor([self.stoi[c] for c in text], dtype=torch.long)

Number of samples in the dataset.

Since every starting offset is a separate sample, a single epoch effectively goes over the text about seq_len times.

    def __len__(self):
        return len(self.data) - self.seq_len - 1

Return a sample

    def __getitem__(self, idx):
        return self.data[idx:idx + self.seq_len], self.data[idx + 1:idx + self.seq_len + 1]
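
To make the one-character shift between inputs and targets concrete, here is a tiny standalone sketch with a toy string (the string and sizes are illustrative only):

```python
import torch

text = "hello"
stoi = {c: i for i, c in enumerate(sorted(set(text)))}
data = torch.tensor([stoi[c] for c in text], dtype=torch.long)

seq_len, idx = 3, 0
x = data[idx:idx + seq_len]          # ids of "hel"
y = data[idx + 1:idx + seq_len + 1]  # ids of "ell" -- each target is the next character
```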

Trainer

class Trainer:
    def __init__(self, configs: Configs):

Get the device

        self.device = torch.device('cpu')
        if torch.cuda.is_available():
            self.device = torch.device('cuda:0')

Initialize the dataset

        self.dataset = TinyShakespeareDataset(configs.seq_len)

Initialize the dataloader

        self.dataloader = DataLoader(self.dataset,
                                     batch_size=configs.batch_size,
                                     collate_fn=transpose_batch,
                                     shuffle=True)

FFN with Gated Linear Unit

        if configs.glu_variant == 'GLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Sigmoid(), True, False, False, False)

FFN with Bilinear hidden layer

        elif configs.glu_variant == 'Bilinear':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Identity(), True, False, False, False)

FFN with ReLU gate

        elif configs.glu_variant == 'ReGLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU(), True, False, False, False)

FFN with GELU gate

        elif configs.glu_variant == 'GEGLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU(), True, False, False, False)

FFN with Swish gate, where $\text{Swish}(x) = x\,\sigma(x)$ (implemented with nn.SiLU)

        elif configs.glu_variant == 'SwiGLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.SiLU(), True, False, False, False)

FFN with ReLU activation

        elif configs.glu_variant == 'ReLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU())

FFN with GELU activation

        elif configs.glu_variant == 'GELU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU())
        else:
            raise ValueError(f'Unknown variant {configs.glu_variant}')
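
The extra positional arguments passed to FeedForward above switch on the gating path and turn the biases off. The sketch below shows roughly what such a gated position-wise feed-forward layer computes; it is an illustrative stand-in under that assumption, not the actual labml_nn.transformers.feed_forward.FeedForward implementation:

```python
import torch
from torch import nn

class GatedFFNSketch(nn.Module):
    """Illustrative gated feed-forward: (activation(x W1) * (x V)) W2."""
    def __init__(self, d_model: int, d_ff: int, dropout: float, activation: nn.Module):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)  # transform branch
        self.v = nn.Linear(d_model, d_ff, bias=False)   # gate branch
        self.w2 = nn.Linear(d_ff, d_model, bias=False)  # projection back to d_model
        self.activation = activation
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Element-wise product of the activated branch and the linear gate branch
        return self.w2(self.dropout(self.activation(self.w1(x)) * self.v(x)))
```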

Number of different characters

        n_chars = len(self.dataset.stoi)

Initialize the Multi-Head Attention module

        mha = MultiHeadAttention(configs.n_heads, configs.d_model, configs.dropout)

Initialize the Transformer Block

        transformer_layer = TransformerLayer(d_model=configs.d_model, self_attn=mha, src_attn=None,
                                             feed_forward=ffn, dropout_prob=configs.dropout)

Initialize the model with an embedding layer (with fixed positional encodings), a transformer encoder, and a linear layer to generate the logits of the next token.

        self.model = AutoregressiveModel(EmbeddingsWithPositionalEncoding(configs.d_model, n_chars),
                                         Encoder(transformer_layer, configs.n_layers),
                                         nn.Linear(configs.d_model, n_chars))

Move the model to the current device

        self.model.to(self.device)

Initialize Noam optimizer

        self.optimizer = Noam(self.model.parameters(), lr=1.0, warmup=2_000, d_model=configs.d_model)
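
The Noam optimizer follows the learning-rate schedule from Attention Is All You Need: the rate grows linearly over the warm-up steps and then decays as the inverse square root of the step. A minimal sketch of that factor (the labml_nn Noam implementation presumably scales it by the lr argument, so take the exact scaling as an assumption):

```python
def noam_lr(step: int, d_model: int = 512, warmup: int = 2_000, lr: float = 1.0) -> float:
    # Linear warm-up followed by 1/sqrt(step) decay, scaled by d_model^-0.5
    step = max(step, 1)
    return lr * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)
```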

Cross-entropy loss

        self.loss_func = nn.CrossEntropyLoss()

Number of training epochs; note that our dataset definition repeats the data seq_len times in a single epoch

        self.epochs = configs.epochs

Gradient clipping norm

        self.grad_norm_clip = configs.grad_norm_clip

Set tracker configurations

        tracker.set_scalar("loss.*", True)

Sampling function to generate samples periodically while training

    def sample(self):

Starting prompt

        prompt = 'It is'

Collect output for printing

        log = [(prompt, Text.subtle)]

Sample 25 tokens

        for i in monit.iterate('Sample', 25):

Tokenize the prompt

            data = self.dataset.text_to_i(prompt).unsqueeze(-1)
            data = data.to(self.device)

Get the model output

            output = self.model(data)

Get the model prediction (greedy)

            output = output.argmax(dim=-1).squeeze()

Add the prediction to the prompt

            prompt += self.dataset.itos[output[-1].item()]

Add the prediction for logging

            log += [(self.dataset.itos[output[-1].item()], Text.value)]

Print the sampled output

        logger.log(log)

Train the model

    def train(self):

Loop for the given number of epochs

        for _ in monit.loop(self.epochs):

Iterate over the minibatches

            for i, batch in monit.enum('Train', self.dataloader):

Move data to the device

                data, target = batch[0].to(self.device), batch[1].to(self.device)

Advance the tracker's global step by the number of characters trained on

                tracker.add_global_step(data.shape[0] * data.shape[1])

Set model state to training

                self.model.train()

Run the model to get the output logits

                output = self.model(data)

Calculate loss

                loss = self.loss_func(output.view(-1, output.shape[-1]), target.view(-1))
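
The view calls flatten the sequence and batch dimensions so that logits and targets line up per character. A small shape sketch, assuming transpose_batch yields [seq_len, batch_size] tensors so the model output is [seq_len, batch_size, n_chars]:

```python
import torch
from torch import nn

seq_len, batch_size, n_chars = 128, 32, 65           # illustrative sizes
output = torch.randn(seq_len, batch_size, n_chars)   # model logits
target = torch.randint(n_chars, (seq_len, batch_size))

loss = nn.CrossEntropyLoss()(output.view(-1, n_chars),  # [seq_len * batch_size, n_chars]
                             target.view(-1))           # [seq_len * batch_size]
```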

Log the loss

                tracker.add("loss.train", loss)

Calculate gradients

                loss.backward()

Clip gradients

                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)

Take optimizer step

                self.optimizer.step()

Log the model parameters and gradients

                if (i + 1) % 100 == 0:
                    tracker.add('model', self.model)

Clear the gradients

                self.optimizer.zero_grad()

Generate a sample

                if (i + 1) % 100 == 0:
                    self.model.eval()
                    with torch.no_grad():
                        self.sample()

Save the tracked metrics

                if (i + 1) % 10 == 0:
                    tracker.save()


def main():

Create experiment

    experiment.create(name="glu_variants")

Create configs

    configs = Configs()

Load configurations

    experiment.configs(dataclasses.asdict(configs))

Create trainer

    trainer = Trainer(configs)

Set the model for saving and loading

    experiment.add_pytorch_models({'model': trainer.model})

Start the experiment

    with experiment.start():

Train the model

        trainer.train()


if __name__ == '__main__':
    main()