Gated Linear Units and Variants

This trains a simple transformer model for auto-regression. We try different variants for the position-wise feedforward network.

This is a simpler implementation that doesn't use the labml.configs module. We wrote this simpler version to make it easier for readers who are not familiar with that module.
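
For reference, the gated variants compared here (from the GLU Variants Improve Transformer paper) replace the first linear transformation of the position-wise FFN with an element-wise product of two linear projections, one of which is passed through an activation (biases are omitted in this experiment):

$\text{GLU}(x) = \sigma(x W_1) \otimes x V$
$\text{Bilinear}(x) = (x W_1) \otimes x V$
$\text{ReGLU}(x) = \text{ReLU}(x W_1) \otimes x V$
$\text{GEGLU}(x) = \text{GELU}(x W_1) \otimes x V$
$\text{SwiGLU}(x) = \text{Swish}_\beta(x W_1) \otimes x V$

The full FFN is then, for example, $\text{FFN}_\text{GLU}(x) = \big(\sigma(x W_1) \otimes x V\big) W_2$.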


import dataclasses

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

from labml import experiment, lab, tracker, monit, logger
from labml.logger import Text
from labml.utils.download import download_file
from labml_nn.experiments.nlp_autoregression import transpose_batch
from labml_nn.optimizers.noam import Noam
from labml_nn.transformers import Encoder, MultiHeadAttention
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.models import EmbeddingsWithPositionalEncoding, TransformerLayer
from labml_nn.transformers.utils import subsequent_mask

Auto-regressive model

class AutoregressiveModel(nn.Module):
    def __init__(self, src_embed: nn.Module, encoder: Encoder, generator: nn.Module):
        super().__init__()

Token embedding module

        self.src_embed = src_embed

Transformer based encoder

        self.encoder = encoder

Next token generation layer; this gives the logits of the next token

        self.generator = generator

This will be initialized on the first call

        self.src_mask = None

    def __call__(self, src: torch.Tensor):

Create a subsequent mask so that the transformer can only pay attention to past tokens.

        if self.src_mask is None or self.src_mask.size(0) != len(src):
            self.src_mask = subsequent_mask(len(src)).to(src.device)

Embed the tokens (src) and run them through the transformer

        res = self.encoder(self.src_embed(src), self.src_mask)

Generate logits of the next token

        return self.generator(res)
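
As a rough illustration of what a causal (subsequent) mask looks like (a sketch only; the actual subsequent_mask helper may return a different shape), each position may attend only to itself and earlier positions:

import torch

def causal_mask(seq_len: int) -> torch.Tensor:
    # True where attention is allowed: position i may attend to positions j <= i
    return torch.tril(torch.ones(seq_len, seq_len)).bool()

print(causal_mask(4))
# tensor([[ True, False, False, False],
#         [ True,  True, False, False],
#         [ True,  True,  True, False],
#         [ True,  True,  True,  True]])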

Configurations

@dataclasses.dataclass
class Configs:
    d_model: int = 512
    seq_len: int = 128
    batch_size: int = 32
    n_layers: int = 6
    n_heads: int = 8
    dropout: float = 0.1
    d_ff: int = 2048
    glu_variant: str = 'GLU'
    epochs: int = 5
    grad_norm_clip: float = 0.5
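
To try a different variant you only need to change the dataclass fields; for instance (a usage sketch mirroring main() at the bottom of this file):

configs = Configs(glu_variant='SwiGLU', epochs=10)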

Tiny Shakespeare Dataset

class TinyShakespeareDataset(Dataset):
    def __init__(self, seq_len: int):

Location of the text file

        path = lab.get_data_path() / 'tiny_shakespeare.txt'

Download the file

        download_file('https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt', path)

Read the downloaded file

        with open(str(path), 'r') as f:
            text = f.read()

Extract the characters

        chars = list(set(text))

Character to id (integer) map

        self.stoi = {c: i for i, c in enumerate(chars)}

Id to character map

        self.itos = {i: c for i, c in enumerate(chars)}

Length of a training sample

        self.seq_len = seq_len

Data in the form of a tensor of ids

        self.data = self.text_to_i(text)

Transform the text into a tensor of ids

    def text_to_i(self, text: str):
        return torch.tensor([self.stoi[c] for c in text], dtype=torch.long)

Number of samples in the dataset.

This will read the dataset seq_len times in a single epoch.

    def __len__(self):
        return len(self.data) - self.seq_len - 1

Return a sample

    def __getitem__(self, idx):
        return self.data[idx:idx + self.seq_len], self.data[idx + 1:idx + self.seq_len + 1]
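
As a small worked example (hypothetical text, not the actual dataset): with seq_len = 4, sample 0 pairs the first four character ids with the ids shifted one position ahead, so the model learns to predict each character from the ones before it:

text = 'hello world'
stoi = {c: i for i, c in enumerate(sorted(set(text)))}
data = [stoi[c] for c in text]
seq_len = 4
x, y = data[0:seq_len], data[1:seq_len + 1]  # ids of 'hell' and 'ello'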

Trainer

class Trainer:
    def __init__(self, configs: Configs):

Get the device

        self.device = torch.device('cpu')
        if torch.cuda.is_available():
            self.device = torch.device('cuda:0')

Initialize the dataset

        self.dataset = TinyShakespeareDataset(configs.seq_len)

Initialize the dataloader

        self.dataloader = DataLoader(self.dataset,
                                     batch_size=configs.batch_size,
                                     collate_fn=transpose_batch,
                                     shuffle=True)
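
transpose_batch is the collate function imported from labml_nn.experiments.nlp_autoregression; it stacks the samples so that the sequence dimension comes first, giving batches of shape [seq_len, batch_size], which is the layout the rest of this code assumes (e.g. output[-1] in sample() below picks the last time step). A rough, assumed equivalent:

import torch

def transpose_batch_sketch(batch):
    # batch is a list of (input, target) pairs, each of shape [seq_len];
    # stacking along dim=1 gives tensors of shape [seq_len, batch_size]
    inputs = torch.stack([x for x, _ in batch], dim=1)
    targets = torch.stack([y for _, y in batch], dim=1)
    return inputs, targets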

FFN with Gated Linear Unit

        if configs.glu_variant == 'GLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Sigmoid(), True, False, False, False)

FFN with Bilinear hidden layer

        elif configs.glu_variant == 'Bilinear':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Identity(), True, False, False, False)

FFN with ReLU gate

        elif configs.glu_variant == 'ReGLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU(), True, False, False, False)

FFN with GELU gate

        elif configs.glu_variant == 'GEGLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU(), True, False, False, False)

FFN with Swish gate where $\text{Swish}_\beta(x) = x \sigma(\beta x)$; nn.SiLU is Swish with $\beta = 1$

        elif configs.glu_variant == 'SwiGLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.SiLU(), True, False, False, False)

FFN with ReLU activation

        elif configs.glu_variant == 'ReLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU())

FFN with GELU activation

        elif configs.glu_variant == 'GELU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU())
        else:
            raise ValueError(f'Unknown variant {configs.glu_variant}')
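
The boolean flags passed to FeedForward above select a gated, bias-free hidden layer (in the labml_nn FeedForward signature these correspond, as I understand it, to is_gated, bias1, bias2, bias_gate). A minimal sketch of what such a gated FFN computes, with hypothetical names:

import torch
from torch import nn

class GatedFFNSketch(nn.Module):
    """Sketch: activation(x W1) * (x V), projected back with W2 (no biases)."""
    def __init__(self, d_model: int, d_ff: int, activation: nn.Module, dropout: float = 0.1):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)  # gate path
        self.v = nn.Linear(d_model, d_ff, bias=False)   # linear path
        self.w2 = nn.Linear(d_ff, d_model, bias=False)
        self.activation = activation
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(self.dropout(self.activation(self.w1(x)) * self.v(x)))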

Number of different characters

        n_chars = len(self.dataset.stoi)

Initialize Multi-Head Attention

        mha = MultiHeadAttention(configs.n_heads, configs.d_model, configs.dropout)

Initialize the Transformer Block

        transformer_layer = TransformerLayer(d_model=configs.d_model, self_attn=mha, src_attn=None,
                                             feed_forward=ffn, dropout_prob=configs.dropout)

Initialize the model with an embedding layer (with fixed positional encoding), a transformer encoder, and a linear layer to generate logits.

        self.model = AutoregressiveModel(EmbeddingsWithPositionalEncoding(configs.d_model, n_chars),
                                         Encoder(transformer_layer, configs.n_layers),
                                         nn.Linear(configs.d_model, n_chars))

Move the model to the current device

        self.model.to(self.device)

Initialize Noam optimizer

        self.optimizer = Noam(self.model.parameters(), lr=1.0, warmup=2_000, d_model=configs.d_model)
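
For reference, the Noam schedule (from Attention Is All You Need) scales the learning rate as

$\text{lr} = \text{factor} \cdot d_\text{model}^{-0.5} \cdot \min\big(\text{step}^{-0.5},\ \text{step} \cdot \text{warmup}^{-1.5}\big)$

with factor = 1.0 here, so the rate warms up linearly for the first 2,000 steps and then decays with the inverse square root of the step count.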

Cross-entropy loss

        self.loss_func = nn.CrossEntropyLoss()

Number of training epochs; note that our dataset definition repeats the data seq_len times in a single epoch

        self.epochs = configs.epochs

Gradient clipping norm

        self.grad_norm_clip = configs.grad_norm_clip

Set tracker configurations

        tracker.set_scalar("loss.*", True)

Sampling function to generate samples periodically while training

    def sample(self):

Starting prompt

        prompt = 'It is'

Collect output for printing

        log = [(prompt, Text.subtle)]

Sample 25 tokens

        for i in monit.iterate('Sample', 25):

Tokenize the prompt

            data = self.dataset.text_to_i(prompt).unsqueeze(-1)
            data = data.to(self.device)

Get the model output

            output = self.model(data)

Get the model prediction (greedy)

            output = output.argmax(dim=-1).squeeze()

Add the prediction to the prompt

            prompt += self.dataset.itos[output[-1].item()]

Add the prediction for logging

            log += [(self.dataset.itos[output[-1].item()], Text.value)]

Print the sampled output

        logger.log(log)

Train the model

    def train(self):

Loop for the given number of epochs

        for _ in monit.loop(self.epochs):

Iterate over the minibatches

            for i, batch in monit.enum('Train', self.dataloader):

Move data to the device

                data, target = batch[0].to(self.device), batch[1].to(self.device)

Set the tracker step to the number of characters trained on

                tracker.add_global_step(data.shape[0] * data.shape[1])

Set model state to training

                self.model.train()

Evaluate the model

                output = self.model(data)

Calculate loss

                loss = self.loss_func(output.view(-1, output.shape[-1]), target.view(-1))
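
The view calls flatten the sequence and batch dimensions so the logits and targets match what nn.CrossEntropyLoss expects. A standalone shape check, using the default sizes from Configs and an assumed vocabulary of 65 characters:

import torch
from torch import nn

seq_len, batch_size, n_chars = 128, 32, 65           # assumed sizes for illustration
output = torch.randn(seq_len, batch_size, n_chars)   # sequence-first logits
target = torch.randint(0, n_chars, (seq_len, batch_size))

loss = nn.CrossEntropyLoss()(output.view(-1, n_chars), target.view(-1))
print(loss)  # a scalar tensor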

Log the loss

                tracker.add("loss.train", loss)

Calculate gradients

                loss.backward()

Clip gradients

                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
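
clip_grad_norm_ rescales all gradients together when their global norm exceeds the threshold, i.e. $g \leftarrow g \cdot \min\left(1, \frac{\text{max\_norm}}{\lVert g \rVert_2}\right)$, with max_norm = 0.5 here.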

Take optimizer step

                self.optimizer.step()

Log the model parameters and gradients

                if (i + 1) % 100 == 0:
                    tracker.add('model', self.model)

Clear the gradients

                self.optimizer.zero_grad()

Generate a sample

                if (i + 1) % 100 == 0:
                    self.model.eval()
                    with torch.no_grad():
                        self.sample()

Save the tracked metrics

                if (i + 1) % 10 == 0:
                    tracker.save()

Save the model

            experiment.save_checkpoint()


def main():

Create experiment

    experiment.create(name="glu_variants")

Create configs

    configs = Configs()

Load configurations

    experiment.configs(dataclasses.asdict(configs))

Create trainer

    trainer = Trainer(configs)

Set models for saving and loading

    experiment.add_pytorch_models({'model': trainer.model})

Start the experiment

    with experiment.start():

Train the model

        trainer.train()


if __name__ == '__main__':
    main()