RETRO training

This is the training code for RETRO (Retrieval-Enhanced Transformer). It trains a small model on the Tiny Shakespeare dataset.

14import torch
15from labml import monit, lab, tracker, experiment, logger
16from labml.logger import Text
17from labml_nn.helpers.datasets import TextFileDataset
18from labml_nn.optimizers.noam import Noam
19from labml_nn.transformers.retro import model as retro
20from labml_nn.transformers.retro.dataset import Dataset, RetroIndex
21from labml_nn.transformers.retro.model import RetroModel, NearestNeighborEncoder
22from torch import nn
23from torch.utils.data import DataLoader, RandomSampler

Sampler

This class greedily samples from a model.

class Sampler:
    """
    ## Sampler

    This class greedily samples text from a RETRO model, retrieving nearest
    neighbors from the training text for every complete chunk of the prompt.
    """

    def __init__(self, device: torch.device, model: retro.RetroModel, tds: TextFileDataset, chunk_len: int):
        """
        * `device` is the device of the model
        * `model` is the Retro model
        * `tds` is the text dataset (used to get neighbor chunks)
        * `chunk_len` is the length of a chunk
        """
        self.chunk_len = chunk_len
        self.tds = tds
        self.model = model
        self.device = device
        # Index used to look up the nearest neighbors of a chunk
        self.index = RetroIndex()

    def retrieve_nearest_neighbours(self, chunk: str):
        """
        ### Retrieve nearest neighbors of a given chunk

        Returns a list of slices of the training text, each up to
        `chunk_len * 2` characters long, starting at the offsets the index
        returns for `chunk`.
        """
        # Retrieve the offsets of the nearest neighbors (one query chunk)
        neighbor_offsets = self.index([chunk], None)

        # Get the neighbors (with neighbor length equal to `chunk_len * 2`)
        text = self.tds.train
        neighbors = [text[j: j + self.chunk_len * 2] for j in neighbor_offsets[0]]

        return neighbors

    def sample(self, prompt: str, sample_len: int):
        """
        ### Sample text from the given prompt

        Greedily generates `sample_len` tokens after `prompt` and returns only
        the generated text (not the prompt).

        Generation runs under `torch.no_grad()` with the model in evaluation
        mode. The model's previous train/eval mode is restored afterwards, so
        calling this between training epochs does not affect training.
        """
        # To store nearest neighbors as strings (one list of neighbors per chunk)
        neighbors_str = []
        # Sampled text
        sampled = ''

        # Remember the current mode so it can be restored after sampling
        was_training = self.model.training
        # Evaluation mode affects layers such as dropout (if the model has any),
        # keeping greedy decoding deterministic
        self.model.eval()
        try:
            # No gradients are needed for generation; this avoids building an
            # autograd graph at every step
            with torch.no_grad():
                # Sample `sample_len` tokens
                for _ in range(sample_len):
                    # We need to retrieve neighbors if there are more complete
                    # chunks in the prompt than we have already retrieved for
                    while len(neighbors_str) < len(prompt) // self.chunk_len:
                        # Get the first chunk for which we haven't retrieved neighbors
                        off = len(neighbors_str) * self.chunk_len
                        chunk = prompt[off: off + self.chunk_len]
                        # Retrieve nearest neighbors
                        neighbors_str.append(self.retrieve_nearest_neighbours(chunk))

                    # Tokenize the input
                    src = self.tds.text_to_i(prompt)
                    # Tokenize the retrieved neighbors.
                    # NOTE(review): if `prompt` is shorter than `chunk_len`,
                    # `neighbors_str` is empty and `torch.stack` raises.
                    neighbors = torch.stack([torch.stack([self.tds.text_to_i(n) for n in chunk])
                                             for chunk in neighbors_str])

                    # Move them to the same device as the model
                    src = src.to(self.device)
                    neighbors = neighbors.to(self.device)

                    # Get model output for a batch of size one
                    res = self.model(src[None, :], neighbors[None, :, :, :])

                    # Greedily sample the last token
                    token = res[0, -1, :].argmax(dim=-1)

                    # Add the sampled token text to the prompt and sample text
                    next_text = self.tds.itos[token.item()]
                    prompt += next_text
                    sampled += next_text
        finally:
            self.model.train(was_training)

        return sampled

Retro trainer

class Trainer:
    """
    ## Retro trainer

    Each call runs one pass over `dataloader`, optimizing the model with a
    cross-entropy next-token loss.
    """

    def __init__(self, device: torch.device, model: retro.RetroModel,
                 dataloader: DataLoader, optimizer: torch.optim.Optimizer):
        """
        * `device` is the device the model lives on; every batch is moved there
        * `model` is the Retro model being trained
        * `dataloader` yields `(src, tgt, neighbors)` batches
        * `optimizer` updates the model parameters
        """
        self.model = model
        self.dataloader = dataloader
        self.device = device
        self.optimizer = optimizer
        # Loss between the model's per-position scores and the target tokens
        self.loss_func = nn.CrossEntropyLoss()

    def __call__(self):
        """
        ### Train the model for an epoch
        """
        for _, (src, tgt, neighbors) in monit.enum('Train', self.dataloader):
            # Put the batch on the model's device
            src = src.to(self.device)
            tgt = tgt.to(self.device)
            neighbors = neighbors.to(self.device)

            # Forward pass
            logits = self.model(src, neighbors)

            # Flatten all leading dimensions so every position is one loss sample
            n_classes = logits.shape[-1]
            loss = self.loss_func(logits.view(-1, n_classes), tgt.view(-1))

            # Clear old gradients, backpropagate, then take an optimizer step
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Record the loss and advance the global step by `len(src)`
            tracker.save({'loss.train': loss})
            tracker.add_global_step(len(src))

Create and train a small model

def train():
    """
    ## Create and train a small model

    Trains a small RETRO model on Tiny Shakespeare for 32 epochs. After every
    epoch it logs a greedy 128-token sample for a fixed prompt and saves a
    checkpoint of the model registered with the experiment.
    """
    # Create an experiment
    experiment.create(name='retro_small')

    # Use the first GPU when one is available; otherwise fall back to the CPU
    # so the script does not crash on machines without CUDA
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Load Tiny Shakespeare dataset. `list` is passed as the tokenizer, which
    # presumably splits the text into single characters — confirm against
    # `TextFileDataset`
    tds = TextFileDataset(
        lab.get_data_path() / 'tiny_shakespeare.txt',
        list,
        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')
    # Load the RETRO training dataset from `retro_train_dataset.json`
    train_dataset = Dataset(lab.get_data_path() / 'retro_train_dataset.json', tds)

    # Create dataloader that draws random samples with replacement
    train_dl = DataLoader(train_dataset,
                          batch_size=4,
                          sampler=RandomSampler(train_dataset, replacement=True))

    # Hyper-parameters
    chunk_len = 16
    d_model = 128
    d_ff = 512
    n_heads = 16
    d_k = 16

    # Create the nearest neighbor encoder.
    # NOTE(review): `6` and `{3}` are presumably the number of layers and the
    # set of layers with cross-attention — confirm against `NearestNeighborEncoder`
    nearest_neighbor_encoder = NearestNeighborEncoder(chunk_len, 6, {3}, d_model, n_heads, d_k, d_ff)

    # Create the model
    model = RetroModel(tds.n_tokens, d_model, 6,
                       {3, 5},
                       chunk_len, n_heads, d_k, d_ff,
                       encoder=nearest_neighbor_encoder)

    # Move the model to the device
    model = model.to(device)

    # Create the optimizer
    optimizer = Noam(model.parameters(), lr=1., d_model=d_model, warmup=2_000)

    # Create the Trainer
    trainer = Trainer(device, model, train_dl, optimizer)

    # Create the Sampler
    sampler = Sampler(device, model, tds, chunk_len)

    # Prompt used to inspect generation quality after every epoch
    prompt = '''Second Citizen:\nOne word, good citizens.\n\nFirst Citizen:'''

    # Set models for saving and loading
    experiment.add_pytorch_models(model=model)

    # Start the experiment
    with experiment.start():
        # Train for 32 epochs
        for _ in monit.loop(32):
            # Train
            trainer()
            # Print a new line
            tracker.new_line()
            # Sample from the prompt; each newline is shown as a literal `\n`
            # followed by a real line break
            logger.log([(prompt.replace('\n', '\\n\n'), Text.subtle),
                        (sampler.sample(prompt, 128).replace('\n', '\\n\n'), Text.none)])
            # Save models (registered above with `add_pytorch_models`)
            experiment.save_checkpoint()

if __name__ == '__main__':
    # Start training only when this file is run as a script, not on import
    train()