Gated Linear Units and Variants

This trains a simple transformer model for auto-regression. We try different variants of the position-wise feed-forward network.

This is a simpler implementation that doesn’t use the labml.configs module. We wrote it this way to make it easier for readers who are not familiar with that module.


import dataclasses

import torch
from labml_helpers.module import Module
from torch import nn
from torch.utils.data import Dataset, DataLoader

from labml import experiment, lab, tracker, monit, logger
from labml.logger import Text
from labml.utils.download import download_file
from labml_nn.experiments.nlp_autoregression import transpose_batch
from labml_nn.optimizers.noam import Noam
from labml_nn.transformers import Encoder, MultiHeadAttention
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.models import EmbeddingsWithPositionalEncoding, TransformerLayer
from labml_nn.transformers.utils import subsequent_mask

Autoregressive model

class AutoregressiveModel(Module):
    def __init__(self, src_embed: Module, encoder: Encoder, generator: Module):
        super().__init__()

Token embedding module

        self.src_embed = src_embed

Transformer based encoder

        self.encoder = encoder

Next-token generation layer; this gives the logits of the next token

        self.generator = generator

This will be initialized on the first call

        self.src_mask = None

    def forward(self, src: torch.Tensor):

Create subsequent mask, so that the transformer can only pay attention to past tokens.

        if self.src_mask is None or self.src_mask.size(0) != len(src):
            self.src_mask = subsequent_mask(len(src)).to(src.device)
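For illustration only (this is not part of the annotated source), a causal mask of this kind can be built with torch.tril; the exact shape subsequent_mask returns may differ, but the semantics are the same: position $i$ may only attend to positions $\le i$.

import torch

def causal_mask(seq_len: int) -> torch.Tensor:
    # True where attention is allowed: query position i attends to key positions <= i
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

print(causal_mask(4).int())
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]], dtype=torch.int32)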

Embed the tokens (src) and run them through the transformer encoder

        res = self.encoder(self.src_embed(src), self.src_mask)

Generate logits of the next token

        return self.generator(res)
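As a small usage sketch (dimensions chosen arbitrarily for illustration, built only from components this file already uses), we expect the model to map token ids of shape [seq_len, batch_size] to logits of shape [seq_len, batch_size, n_chars]:

import torch
from torch import nn
from labml_nn.transformers import Encoder, MultiHeadAttention
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.models import EmbeddingsWithPositionalEncoding, TransformerLayer

d_model, n_heads, d_ff, n_chars = 64, 4, 256, 65
ffn = FeedForward(d_model, d_ff, 0.1, nn.ReLU())
layer = TransformerLayer(d_model=d_model,
                         self_attn=MultiHeadAttention(n_heads, d_model, 0.1),
                         src_attn=None, feed_forward=ffn, dropout_prob=0.1)
model = AutoregressiveModel(EmbeddingsWithPositionalEncoding(d_model, n_chars),
                            Encoder(layer, 2),
                            nn.Linear(d_model, n_chars))

tokens = torch.randint(0, n_chars, (16, 4))  # [seq_len=16, batch_size=4]
logits = model(tokens)                       # expected shape: [16, 4, n_chars]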

Configurations

@dataclasses.dataclass
class Configs:
    d_model: int = 512
    seq_len: int = 128
    batch_size: int = 32
    n_layers: int = 6
    n_heads: int = 8
    dropout: float = 0.1
    d_ff: int = 2048
    glu_variant: str = 'GLU'
    epochs: int = 5
    grad_norm_clip: float = 0.5
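Since Configs is a plain dataclass, any hyper-parameter can be overridden when constructing it. For example, to train the SwiGLU variant with a smaller batch size (the valid glu_variant values, per the Trainer below, are 'GLU', 'Bilinear', 'ReGLU', 'GEGLU', 'SwiGLU', 'ReLU' and 'GELU'):

configs = Configs(glu_variant='SwiGLU', batch_size=16)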

Tiny Shakespeare Dataset

class TinyShakespeareDataset(Dataset):
    def __init__(self, seq_len: int):

Location of the text file

        path = lab.get_data_path() / 'tiny_shakespeare.txt'

Download the file

        download_file('https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt', path)

Read the downloaded file

        with open(str(path), 'r') as f:
            text = f.read()

Extract the characters

        chars = list(set(text))

Character to id (integer) map

        self.stoi = {c: i for i, c in enumerate(chars)}

Id to character map

        self.itos = {i: c for i, c in enumerate(chars)}

Length of a training sample

        self.seq_len = seq_len

Data in the form of a tensor of ids

        self.data = self.text_to_i(text)

Transform the text into a tensor of ids

    def text_to_i(self, text: str):
        return torch.tensor([self.stoi[c] for c in text], dtype=torch.long)

Number of samples in the dataset.

Since a sample starts at every character offset, each character is read about seq_len times in a single epoch.

    def __len__(self):
        return len(self.data) - self.seq_len - 1

Return a sample

    def __getitem__(self, idx):
        return self.data[idx:idx + self.seq_len], self.data[idx + 1:idx + self.seq_len + 1]
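To make the shift-by-one relationship concrete, here is a small illustrative check with a toy string (not the actual dataset):

import torch

text = "First Citizen"
stoi = {c: i for i, c in enumerate(sorted(set(text)))}
data = torch.tensor([stoi[c] for c in text], dtype=torch.long)

seq_len = 5
inp, tgt = data[0:seq_len], data[1:seq_len + 1]
# inp encodes "First" and tgt encodes "irst " -- the target is the input shifted by one character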

Trainer

class Trainer:
    def __init__(self, configs: Configs):

Get the device

        self.device = torch.device('cpu')
        if torch.cuda.is_available():
            self.device = torch.device('cuda:0')

Initialize the dataset

        self.dataset = TinyShakespeareDataset(configs.seq_len)

Initialize the dataloader

        self.dataloader = DataLoader(self.dataset,
                                     batch_size=configs.batch_size,
                                     collate_fn=transpose_batch,
                                     shuffle=True)
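transpose_batch (from labml_nn.experiments.nlp_autoregression) stacks the samples along dimension 1, so batches come out sequence-first, i.e. [seq_len, batch_size] rather than PyTorch's default [batch_size, seq_len]. A rough sketch of such a collate function (an illustration, not the library's exact code):

import torch

def seq_first_collate(batch):
    # batch is a list of (input, target) pairs, each a tensor of shape [seq_len]
    inputs, targets = zip(*batch)
    # Stack along dim=1 so each result has shape [seq_len, batch_size]
    return torch.stack(inputs, dim=1), torch.stack(targets, dim=1)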

FFN with Gated Linear Unit

        if configs.glu_variant == 'GLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Sigmoid(), True, False, False, False)

FFN with Bilinear hidden layer

        elif configs.glu_variant == 'Bilinear':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Identity(), True, False, False, False)

FFN with ReLU gate

        elif configs.glu_variant == 'ReGLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU(), True, False, False, False)

FFN with GELU gate

        elif configs.glu_variant == 'GEGLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU(), True, False, False, False)

FFN with Swish gate where $\text{Swish}_\beta(x) = x \sigma(\beta x)$; nn.SiLU is Swish with $\beta = 1$

        elif configs.glu_variant == 'SwiGLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.SiLU(), True, False, False, False)

FFN with ReLU activation

        elif configs.glu_variant == 'ReLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU())

FFN with GELU activation

        elif configs.glu_variant == 'GELU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU())
        else:
            raise ValueError(f'Unknown variant {configs.glu_variant}')
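In each gated branch above, the boolean flags passed to FeedForward select the gated form and (as we read the signature) disable the biases, so the layer computes roughly $\text{FFN}_{\text{GLU}}(x) = (f(x W_1) \otimes x V)\, W_2$, where $f$ is the chosen activation and $\otimes$ is element-wise multiplication. A minimal standalone sketch of such a gated feed-forward layer (for illustration only; the real implementation is in labml_nn.transformers.feed_forward):

import torch
from torch import nn

class GatedFFN(nn.Module):
    def __init__(self, d_model: int, d_ff: int, activation: nn.Module, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff, bias=False)   # W1, goes through the activation
        self.linear_v = nn.Linear(d_model, d_ff, bias=False)  # V, the linear gate path
        self.linear2 = nn.Linear(d_ff, d_model, bias=False)   # W2, projects back to d_model
        self.activation = activation
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        g = self.activation(self.linear1(x))  # f(x W1)
        x = g * self.linear_v(x)              # element-wise gating with x V
        return self.linear2(self.dropout(x))

With activation = nn.Sigmoid() this is the original GLU; nn.ReLU() gives ReGLU, nn.GELU() gives GEGLU, nn.SiLU() gives SwiGLU, and nn.Identity() gives the Bilinear variant.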

Number of different characters

        n_chars = len(self.dataset.stoi)

Initialize the multi-head attention module

        mha = MultiHeadAttention(configs.n_heads, configs.d_model, configs.dropout)

Initialize the Transformer Block

        transformer_layer = TransformerLayer(d_model=configs.d_model, self_attn=mha, src_attn=None,
                                             feed_forward=ffn, dropout_prob=configs.dropout)

Initialize the model with an embedding layer (with fixed positional encoding), a transformer encoder, and a linear layer to generate logits.

        self.model = AutoregressiveModel(EmbeddingsWithPositionalEncoding(configs.d_model, n_chars),
                                         Encoder(transformer_layer, configs.n_layers),
                                         nn.Linear(configs.d_model, n_chars))

Move the model to the current device

        self.model.to(self.device)

Initialize Noam optimizer

        self.optimizer = Noam(self.model.parameters(), lr=1.0, warmup=2_000, d_model=configs.d_model)
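Noam wraps an Adam-style optimizer with the learning-rate schedule from the original transformer paper; with the factor lr = 1.0 used here the rate is, approximately, $d_{\text{model}}^{-0.5} \cdot \min(\text{step}^{-0.5},\ \text{step} \cdot \text{warmup}^{-1.5})$, i.e. a linear warm-up over the first 2,000 steps followed by inverse-square-root decay.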

Cross-entropy loss

        self.loss_func = nn.CrossEntropyLoss()

Number of training epochs; note that our dataset definition repeats the data about seq_len times in a single epoch.

        self.epochs = configs.epochs

Gradient clipping norm

        self.grad_norm_clip = configs.grad_norm_clip

Set tracker configurations

        tracker.set_scalar("loss.*", True)

Sampling function to generate samples periodically while training

    def sample(self):

Starting prompt

        prompt = 'It is'

Collect output for printing

        log = [(prompt, Text.subtle)]

Sample 25 tokens

        for i in monit.iterate('Sample', 25):

Tokenize the prompt and add a batch dimension of size 1 (the model expects input of shape [seq_len, batch_size])

            data = self.dataset.text_to_i(prompt).unsqueeze(-1)
            data = data.to(self.device)

Get the model output

            output = self.model(data)

Get the model prediction (greedy): the token with the highest logit at each position

            output = output.argmax(dim=-1).squeeze()

Add the predicted character to the prompt

            prompt += self.dataset.itos[output[-1].item()]

Add the prediction for logging

            log += [(self.dataset.itos[output[-1].item()], Text.value)]

Print the sampled output

        logger.log(log)

Train the model

    def train(self):

Loop for the given number of epochs

        for _ in monit.loop(self.epochs):

Iterate over the minibatches

            for i, batch in monit.enum('Train', self.dataloader):

Move data to the device

                data, target = batch[0].to(self.device), batch[1].to(self.device)

Set tracker step, as the number of characters trained on

                tracker.add_global_step(data.shape[0] * data.shape[1])

Set model state to training

                self.model.train()

Run the model to get the output logits

                output = self.model(data)

Calculate the cross-entropy loss; the outputs are flattened to [seq_len * batch_size, n_chars] and the targets to [seq_len * batch_size]

                loss = self.loss_func(output.view(-1, output.shape[-1]), target.view(-1))

Log the loss

                tracker.add("loss.train", loss)

Calculate gradients

                loss.backward()

Clip gradients

                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
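clip_grad_norm_ rescales all gradients together when their global $L_2$ norm exceeds the threshold, $g \leftarrow g \cdot \min\!\left(1, \frac{c}{\lVert g \rVert_2}\right)$, with $c = 0.5$ here.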

Take optimizer step

                self.optimizer.step()

Log the model parameters and gradients

                if (i + 1) % 100 == 0:
                    tracker.add('model', self.model)

Clear the gradients

                self.optimizer.zero_grad()

Generate a sample

                if (i + 1) % 100 == 0:
                    self.model.eval()
                    with torch.no_grad():
                        self.sample()

Save the tracked metrics

                if (i + 1) % 10 == 0:
                    tracker.save()

Save the model

            experiment.save_checkpoint()


def main():

Create experiment

    experiment.create(name="glu_variants")

Create configs

    configs = Configs()

Load configurations

    experiment.configs(dataclasses.asdict(configs))

Create trainer

    trainer = Trainer(configs)

Register the model for saving and loading checkpoints

    experiment.add_pytorch_models({'model': trainer.model})

Start the experiment

    with experiment.start():

Train the model

        trainer.train()


if __name__ == '__main__':
    main()