14import torch
15from labml import monit, lab, tracker, experiment, logger
16from labml.logger import Text
17from labml_nn.helpers.datasets import TextFileDataset
18from labml_nn.optimizers.noam import Noam
19from labml_nn.transformers.retro import model as retro
20from labml_nn.transformers.retro.dataset import Dataset, RetroIndex
21from labml_nn.transformers.retro.model import RetroModel, NearestNeighborEncoder
22from torch import nn
23from torch.utils.data import DataLoader, RandomSampler
class Sampler:
    """
    ## Sampler

    Greedily samples text from a Retro model. Nearest neighbors are retrieved
    for every *complete* `chunk_len`-sized chunk of the (growing) prompt and
    fed to the model together with the tokenized prompt.
    """

    def __init__(self, device: torch.device, model: retro.RetroModel, tds: TextFileDataset, chunk_len: int):
        """
        * `device` is the device of the model
        * `model` is the Retro model
        * `tds` is the text dataset (used to get neighbor chunks)
        * `chunk_len` is the length of a chunk
        """
        self.chunk_len = chunk_len
        self.tds = tds
        self.model = model
        self.device = device

        # Nearest-neighbor index used to look up chunks similar to a query chunk
        self.index = RetroIndex()

    def retrieve_nearest_neighbours(self, chunk: str):
        """
        ### Retrieve nearest neighbors of a given chunk

        Returns a list of neighbor strings, each a slice of up to
        `chunk_len * 2` characters of the training text starting at an offset
        returned by the index.
        """
        # Retrieve the offsets of the nearest neighbors
        neighbor_offsets = self.index([chunk], None)

        # Get the neighbors (with neighbor length equal to `chunk_len * 2`)
        text = self.tds.train
        neighbors = [text[j: j + self.chunk_len * 2] for j in neighbor_offsets[0]]

        return neighbors

    def sample(self, prompt: str, sample_len: int):
        """
        ### Sample text from the given prompt

        * `prompt` is the text to continue; it must be at least `chunk_len`
          characters long so that there is a chunk to retrieve neighbors for
        * `sample_len` is the number of tokens to sample greedily

        Returns only the newly sampled text (the prompt is not included).
        Raises `ValueError` if `prompt` is shorter than `chunk_len`.
        """
        if len(prompt) < self.chunk_len:
            # Without a complete chunk there are no neighbors and `torch.stack`
            # on an empty list would fail with an opaque error.
            raise ValueError(f'prompt must be at least chunk_len={self.chunk_len} characters long, '
                             f'got {len(prompt)}')

        # Tokenized nearest neighbors, one tensor per completed prompt chunk.
        # Cached so each chunk is retrieved and tokenized only once.
        neighbor_tokens = []
        # Sampled token strings; joined once at the end
        sampled = []

        # This is pure inference: disable autograd and switch to evaluation
        # mode, restoring the previous mode so training is unaffected.
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                for _ in range(sample_len):
                    # Retrieve neighbors for every chunk completed since the last step
                    while len(neighbor_tokens) < len(prompt) // self.chunk_len:
                        # Start of the first chunk we haven't retrieved neighbors for
                        off = len(neighbor_tokens) * self.chunk_len
                        chunk = prompt[off: off + self.chunk_len]
                        neighbor_strs = self.retrieve_nearest_neighbours(chunk)
                        neighbor_tokens.append(torch.stack([self.tds.text_to_i(n) for n in neighbor_strs]))

                    # Tokenize the input and move inputs to the model's device
                    src = self.tds.text_to_i(prompt).to(self.device)
                    neighbors = torch.stack(neighbor_tokens).to(self.device)

                    # Get model output for a batch of size one
                    res = self.model(src[None, :], neighbors[None, :, :, :])

                    # Greedily pick the most likely next token at the last position
                    token = res[0, -1, :].argmax(dim=-1).item()
                    token_str = self.tds.itos[token]

                    # Extend the prompt (so the next step conditions on it) and the sample
                    prompt += token_str
                    sampled.append(token_str)
        finally:
            self.model.train(was_training)

        return ''.join(sampled)
class Trainer:
    """
    ## Retro trainer

    Runs one epoch over a dataloader of `(src, tgt, neighbors)` batches with
    cross-entropy loss, logging the loss to the labml tracker.
    """

    def __init__(self, device: torch.device, model: retro.RetroModel,
                 dataloader: DataLoader, optimizer: torch.optim.Optimizer):
        """
        * `device` is the device of the model
        * `model` is the Retro model
        * `dataloader` is the dataloader for the dataset with pre-retrieved neighbors
        * `optimizer` is the optimizer
        """
        self.optimizer = optimizer
        self.device = device
        self.dataloader = dataloader
        self.model = model
        self.loss_func = nn.CrossEntropyLoss()

    def __call__(self):
        """
        ### Train the model for an epoch
        """
        # Explicitly enter training mode: other code (e.g. sampling) may have
        # left the model in evaluation mode, which would disable dropout.
        self.model.train()

        # Iterate through training data
        for _, (src, tgt, neighbors) in monit.enum('Train', self.dataloader):
            # Move data to the device
            src, tgt, neighbors = src.to(self.device), tgt.to(self.device), neighbors.to(self.device)

            # Forward pass
            res = self.model(src, neighbors)
            # Flatten all leading dimensions so every position is one
            # classification sample; `reshape` also handles non-contiguous tensors
            loss = self.loss_func(res.reshape(-1, res.shape[-1]), tgt.reshape(-1))

            # Clear the gradients
            self.optimizer.zero_grad()
            # Backward pass
            loss.backward()
            # Optimize the model
            self.optimizer.step()

            # Save training statistics and advance the global step by the batch size
            tracker.save({'loss.train': loss})
            tracker.add_global_step(len(src))
def train():
    """
    ## Create and train a small Retro model

    Trains on Tiny Shakespeare for 32 epochs. After each epoch it greedily
    samples a continuation of a fixed prompt and saves a checkpoint.
    """

    # Create an experiment
    experiment.create(name='retro_small')

    # Use the first GPU when available; fall back to the CPU so the script
    # does not crash on machines without CUDA
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Load Tiny Shakespeare dataset (`list` is passed as the tokenizer,
    # presumably giving a character-level vocabulary)
    tds = TextFileDataset(
        lab.get_data_path() / 'tiny_shakespeare.txt',
        list,
        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')

    # Load Retro dataset
    train_dataset = Dataset(lab.get_data_path() / 'retro_train_dataset.json', tds)

    # Create dataloader, sampling training examples randomly with replacement
    train_dl = DataLoader(train_dataset,
                          batch_size=4,
                          sampler=RandomSampler(train_dataset, replacement=True))

    # Hyper-parameters
    chunk_len = 16
    d_model = 128
    d_ff = 512
    n_heads = 16
    d_k = 16

    # Create the nearest neighbor encoder
    # NOTE(review): `6` / `{3}` presumably are the layer count and the set of
    # layers with cross-attention — confirm against NearestNeighborEncoder
    nearest_neighbor_encoder = NearestNeighborEncoder(chunk_len, 6, {3}, d_model, n_heads, d_k, d_ff)
    # Create the model
    model = RetroModel(tds.n_tokens, d_model, 6,
                       {3, 5},
                       chunk_len, n_heads, d_k, d_ff,
                       encoder=nearest_neighbor_encoder)
    # Move the model to the device
    model = model.to(device)
    # Create the optimizer
    optimizer = Noam(model.parameters(), lr=1., d_model=d_model, warmup=2_000)
    # Create the Trainer
    trainer = Trainer(device, model, train_dl, optimizer)
    # Create the Sampler
    sampler = Sampler(device, model, tds, chunk_len)
    # Prompt to continue when sampling after each epoch
    prompt = '''Second Citizen:\nOne word, good citizens.\n\nFirst Citizen:'''

    # Set models for saving and loading
    experiment.add_pytorch_models(model=model)

    # Start the experiment
    with experiment.start():
        # Train for 32 epochs
        for _ in monit.loop(32):
            # Train
            trainer()
            # Print a new line
            tracker.new_line()
            # Sample from the prompt; newlines are shown as a visible `\n`
            # followed by an actual line break
            logger.log([(prompt.replace('\n', '\\n\n'), Text.subtle),
                        (sampler.sample(prompt, 128).replace('\n', '\\n\n'), Text.none)])
            # Save models
            experiment.save_checkpoint()
if __name__ == '__main__':
    # Script entry point: build the small Retro model and train it
    train()