1import torch
2import torch.nn as nn
3from labml import experiment
4from labml.configs import option
5from labml.utils.pytorch import get_modules
6
7from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
8from labml_nn.hypernetworks.hyper_lstm import HyperLSTM
9from labml_nn.lstm import LSTM
class AutoregressiveModel(nn.Module):
    """
    Auto-regressive language model: token embedding -> recurrent module -> linear generator.

    The recurrent module is injected by the caller; `forward` only requires that
    calling it on the embedded input returns a `(output, state)` pair.
    """

    def __init__(self, n_vocab: int, d_model: int, rnn_model: nn.Module):
        """
        :param n_vocab: number of embedding rows and number of output logits
        :param d_model: embedding width and input width of the generator layer
        :param rnn_model: recurrent module called between the embedding and the generator
        """
        super().__init__()
        # Sub-modules are registered in this order under these attribute names;
        # both determine the `state_dict` layout, so keep them stable.
        # Token embedding module
        self.src_embed = nn.Embedding(n_vocab, d_model)
        # Recurrent module supplied by the caller (stored under the name `lstm`)
        self.lstm = rnn_model
        # Projection from the recurrent output to vocabulary logits
        self.generator = nn.Linear(d_model, n_vocab)

    def forward(self, x: torch.Tensor):
        """
        :param x: tensor of token indices fed to the embedding layer
        :return: tuple `(logits, state)` where `state` is whatever the recurrent module returned
        """
        # Embed the tokens
        embedded = self.src_embed(x)
        # Run the embeddings through the recurrent module
        # (no previous state is passed in, so every call starts from the module's default state)
        hidden, state = self.lstm(embedded)
        # Generate logits of the next token
        logits = self.generator(hidden)
        return logits, state
class Configs(NLPAutoRegressionConfigs):
    """
    Configuration for this experiment, extending `NLPAutoRegressionConfigs`.

    The values below are defaults; `main` passes a dictionary of overrides
    (for example the choice of `rnn_model`) when it loads the configs.
    """

    # The full model; produced by the option registered for `Configs.model`
    model: AutoregressiveModel
    # The recurrent module placed inside the model; selected by option name
    rnn_model: nn.Module

    # Embedding width, also passed as the first two arguments of both recurrent modules
    d_model: int = 512
    # Third argument to `HyperLSTM` (presumably the hyper-network size - confirm against HyperLSTM)
    n_rhn: int = 16
    # Fourth argument to `HyperLSTM` (presumably the size of the `z` vectors - confirm against HyperLSTM)
    n_z: int = 16
@option(Configs.model)
def autoregressive_model(c: Configs):
    """
    Initialize the auto-regressive model.

    Registered as the option for `Configs.model`: builds an `AutoregressiveModel`
    from the token count, model width and recurrent module held by `c`, then
    moves it to `c.device`.
    """
    model = AutoregressiveModel(n_vocab=c.n_tokens,
                                d_model=c.d_model,
                                rnn_model=c.rnn_model)
    return model.to(c.device)
@option(Configs.rnn_model)
def hyper_lstm(c: Configs):
    """
    `rnn_model` option named `hyper_lstm`: build a `HyperLSTM`.

    `d_model` is used for both of the first two arguments, followed by `n_rhn`
    and `n_z`. The trailing `1` is presumably the number of layers - confirm
    against the `HyperLSTM` constructor.
    """
    width = c.d_model
    return HyperLSTM(width, width, c.n_rhn, c.n_z, 1)
59
60
@option(Configs.rnn_model)
def lstm(c: Configs):
    """
    `rnn_model` option named `lstm`: build a plain `LSTM`.

    `d_model` is used for both of the first two arguments. The trailing `1`
    is presumably the number of layers - confirm against the `LSTM` constructor.
    """
    width = c.d_model
    return LSTM(width, width, 1)
64
65
def main():
    """
    Create the `hyper_lstm` experiment, load the configuration with this
    script's overrides, register the models for saving/loading, and run it.
    """
    # Create experiment
    experiment.create(name="hyper_lstm", comment='')

    # Create configs
    conf = Configs()

    # A dictionary of configurations to override
    overrides = {
        'tokenizer': 'character',
        'text': 'tiny_shakespeare',
        'optimizer.learning_rate': 2.5e-4,
        'optimizer.optimizer': 'Adam',
        'prompt': 'It is',
        'prompt_separator': '',

        # Which `Configs.rnn_model` option to use
        'rnn_model': 'hyper_lstm',

        'train_loader': 'shuffled_train_loader',
        'valid_loader': 'shuffled_valid_loader',

        'seq_len': 512,
        'epochs': 128,
        'batch_size': 2,
        'inner_iterations': 25,
    }

    # Load configurations
    experiment.configs(conf, overrides)

    # Set models for saving and loading
    experiment.add_pytorch_models(get_modules(conf))

    # Start the experiment
    with experiment.start():
        # Run the configured loop (the original note points at `TrainValidConfigs.run`)
        conf.run()
98
99
# Run the experiment only when this file is executed as a script, not on import
if __name__ == "__main__":
    main()