RETRO training

This is the training code for RETRO (Retrieval-Enhanced Transformer).

14import torch
15from torch import nn
16from torch.utils.data import DataLoader, RandomSampler
17
18from labml import monit, lab, tracker, experiment, logger
19from labml.logger import Text
20from labml_helpers.datasets.text import TextFileDataset
21from labml_nn.optimizers.noam import Noam
22from labml_nn.transformers.retro import model as retro
23from labml_nn.transformers.retro.dataset import Dataset, RetroIndex
24from labml_nn.transformers.retro.model import RetroModel, NearestNeighborEncoder

Sampler

This class greedily samples from a model.

class Sampler:
    """
    ## Sampler

    This class greedily samples from a model.
    """

    def __init__(self, device: torch.device, model: retro.RetroModel, tds: TextFileDataset, chunk_len: int):
        """
        * `device` is the device of the model
        * `model` is the Retro model
        * `tds` is the text dataset (used to get neighbor chunks)
        * `chunk_len` is the length of a chunk
        """
        self.chunk_len = chunk_len
        self.tds = tds
        self.model = model
        self.device = device

        # Retro index used to look up nearest-neighbor offsets
        self.index = RetroIndex()

    def retrieve_nearest_neighbours(self, chunk: str):
        """
        Retrieve nearest neighbors of a given chunk.

        Returns a list of strings, one per retrieved offset, each sliced from the
        training text starting at that offset and spanning (up to)
        `chunk_len * 2` characters.
        """
        # Retrieve the offsets of the nearest neighbors
        neighbor_offsets = self.index([chunk], None)

        # Get the neighbors (with neighbor length equal to `chunk_len * 2`).
        # NOTE(review): an offset near the end of the text yields a shorter slice,
        # which would make the later `torch.stack` fail — confirm the index never
        # returns such offsets.
        text = self.tds.train
        neighbors = [text[j: j + self.chunk_len * 2] for j in neighbor_offsets[0]]

        return neighbors

    def sample(self, prompt: str, sample_len: int):
        """
        Greedily sample `sample_len` tokens continuing `prompt`.

        Returns only the sampled text (the prompt itself is not included).

        Sampling runs under `torch.no_grad()` with the model in evaluation mode.
        This avoids building autograd graphs and disables any train-only layers
        such as dropout. The model's previous train/eval mode is restored
        afterwards, even if sampling raises.
        """
        # Tokenized neighbors already on `self.device`: one stacked tensor per
        # prompt chunk for which neighbors have been retrieved. They never change
        # once retrieved, so they are tokenized only once.
        chunk_neighbors = []
        # Sampled text
        sampled = ''

        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                # Sample `sample_len` tokens
                for _ in range(sample_len):
                    # Retrieve neighbors for every complete prompt chunk that
                    # doesn't have them yet
                    while len(chunk_neighbors) < len(prompt) // self.chunk_len:
                        # Get the first chunk for which we haven't retrieved neighbors
                        off = len(chunk_neighbors) * self.chunk_len
                        chunk = prompt[off: off + self.chunk_len]
                        # Retrieve, tokenize and cache its nearest neighbors
                        neighbor_strs = self.retrieve_nearest_neighbours(chunk)
                        tokens = torch.stack([self.tds.text_to_i(n) for n in neighbor_strs])
                        chunk_neighbors.append(tokens.to(self.device))

                    # Tokenize the input and move it to the model's device
                    src = self.tds.text_to_i(prompt).to(self.device)
                    neighbors = torch.stack(chunk_neighbors)

                    # Get model output (add a batch dimension to both inputs)
                    res = self.model(src[None, :], neighbors[None, :, :, :])

                    # Greedily pick the most likely token at the last position
                    token = res[0, -1, :].argmax(dim=-1)

                    # Append the sampled token's text to the prompt and the sample
                    char = self.tds.itos[token.item()]
                    prompt += char
                    sampled += char
        finally:
            # Restore the caller's mode so training continues unaffected
            self.model.train(was_training)

        return sampled

Retro trainer

class Trainer:
    """
    ## Retro trainer
    """

    def __init__(self, device: torch.device, model: retro.RetroModel,
                 dataloader: DataLoader, optimizer: torch.optim.Optimizer):
        """
        * `device` is the device of the model
        * `model` is the Retro model
        * `dataloader` yields `(src, tgt, neighbors)` training batches
        * `optimizer` is the optimizer
        """
        self.optimizer = optimizer
        self.device = device
        self.dataloader = dataloader
        self.model = model
        self.loss_func = nn.CrossEntropyLoss()

    def __call__(self):
        """
        Train the model for an epoch.
        """
        # The same model is used for sampling between epochs, so explicitly
        # switch to training mode instead of relying on the mode it was left in.
        self.model.train()

        # Iterate through training data
        for _, (src, tgt, neighbors) in monit.enum('Train', self.dataloader):
            # Move data to the device
            src, tgt, neighbors = src.to(self.device), tgt.to(self.device), neighbors.to(self.device)

            # Forward pass
            res = self.model(src, neighbors)

            # Calculate loss; flatten all leading dimensions so each position is
            # one classification over the last (vocabulary) dimension
            loss = self.loss_func(res.view(-1, res.shape[-1]), tgt.view(-1))

            # Clear the gradients
            self.optimizer.zero_grad()
            # Backward pass
            loss.backward()
            # Optimize the model
            self.optimizer.step()

            # Save training statistics and increment the global step counter
            # by the number of samples in the batch
            tracker.save({'loss.train': loss})
            tracker.add_global_step(len(src))

Create and train a small model

def train():
    """
    ## Create and train a small model
    """

    # Create an experiment
    experiment.create(name='retro_small')

    # Use the first GPU when available; fall back to CPU so the script
    # still runs on machines without CUDA.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Load Tiny Shakespeare dataset.
    # NOTE(review): the positional `list` argument presumably acts as the
    # tokenizer (it splits text into characters) — confirm against TextFileDataset.
    tds = TextFileDataset(
        lab.get_data_path() / 'tiny_shakespeare.txt',
        list,
        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')
    # Load the RETRO training dataset built on top of the text dataset
    train_dataset = Dataset(lab.get_data_path() / 'retro_train_dataset.json', tds)

    # Create dataloader
    train_dl = DataLoader(train_dataset,
                          batch_size=4,
                          sampler=RandomSampler(train_dataset, replacement=True))

    # Hyper-parameters
    chunk_len = 16
    d_model = 128
    d_ff = 512
    n_heads = 16
    d_k = 16

    # Create the nearest neighbor encoder
    nearest_neighbor_encoder = NearestNeighborEncoder(chunk_len, 6, {3}, d_model, n_heads, d_k, d_ff)
    # Create the model
    model = RetroModel(tds.n_tokens, d_model, 6,
                       {3, 5},
                       chunk_len, n_heads, d_k, d_ff,
                       encoder=nearest_neighbor_encoder)
    # Move the model to the device
    model = model.to(device)

    # Create the optimizer
    optimizer = Noam(model.parameters(), lr=1., d_model=d_model, warmup=2_000)
    # Create the Trainer
    trainer = Trainer(device, model, train_dl, optimizer)
    # Create the Sampler
    sampler = Sampler(device, model, tds, chunk_len)

    # Prompt used to sample after each epoch
    prompt = '''Second Citizen:\nOne word, good citizens.\n\nFirst Citizen:'''

    # Set models for saving and loading
    experiment.add_pytorch_models(model=model)

    # Start the experiment
    with experiment.start():
        # Train for 32 epochs
        for _ in monit.loop(32):
            # Train
            trainer()
            # Print a new line
            tracker.new_line()
            # Sample from the prompt
            logger.log([(prompt.replace('\n', '\\n\n'), Text.subtle),
                        (sampler.sample(prompt, 128).replace('\n', '\\n\n'), Text.none)])
            # Save models
            experiment.save_checkpoint()

if __name__ == '__main__':
    # Script entry point: build the small RETRO model and run training
    train()