1import torch
2import torch.nn as nn
3from labml import experiment
4from labml.configs import option
5from labml.utils.pytorch import get_modules
6
7from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
8from labml_nn.hypernetworks.hyper_lstm import HyperLSTM
9from labml_nn.lstm import LSTM

Auto-regressive model.

class AutoregressiveModel(nn.Module):
    """
    Auto-regressive language model around a recurrent core.

    Embeds token ids, feeds them through the supplied recurrent module
    (e.g. `LSTM` or `HyperLSTM`), and projects the recurrent outputs back
    to vocabulary logits for next-token prediction.
    """

    def __init__(self, n_vocab: int, d_model: int, rnn_model: nn.Module):
        """
        * `n_vocab` is the vocabulary size
        * `d_model` is the embedding / hidden size
        * `rnn_model` is the recurrent module; calling it must return
          a `(output, state)` pair
        """
        super().__init__()
        # Token embedding module
        self.src_embed = nn.Embedding(n_vocab, d_model)
        # Recurrent core (LSTM / HyperLSTM)
        self.lstm = rnn_model
        # Projection from hidden size back to vocabulary logits
        self.generator = nn.Linear(d_model, n_vocab)

    def forward(self, x: torch.Tensor):
        # Embed the token ids
        embedded = self.src_embed(x)
        # Run the embeddings through the recurrent model
        output, state = self.lstm(embedded)
        # Logits for the next token, plus the recurrent state
        return self.generator(output), state

Configurations.

These default configurations can (and will) be overridden when we start the experiment.

class Configs(NLPAutoRegressionConfigs):
    """
    Configurations for the HyperLSTM experiment.

    Extends `NLPAutoRegressionConfigs`; the defaults declared here can be
    overridden when the experiment is started.
    """

    # The auto-regressive model (built by the `autoregressive_model` option)
    model: AutoregressiveModel
    # The recurrent core; selected by option name ('hyper_lstm' or 'lstm')
    rnn_model: nn.Module

    # Embedding / hidden size of the model
    d_model: int = 512
    # Forwarded to `HyperLSTM` — presumably the hyper-network recurrent
    # state size; confirm against `labml_nn.hypernetworks.hyper_lstm`
    n_rhn: int = 16
    # Forwarded to `HyperLSTM` — presumably the size of the weight-scaling
    # `z` vectors; confirm against `labml_nn.hypernetworks.hyper_lstm`
    n_z: int = 16

Initialize the auto-regressive model

@option(Configs.model)
def autoregressive_model(c: Configs):
    """Create the auto-regressive model and move it to the configured device."""
    return AutoregressiveModel(c.n_tokens, c.d_model, c.rnn_model).to(c.device)
@option(Configs.rnn_model)
def hyper_lstm(c: Configs):
    """Build a single-layer HyperLSTM sized by the configuration."""
    size = c.d_model
    return HyperLSTM(size, size, c.n_rhn, c.n_z, 1)
59
60
@option(Configs.rnn_model)
def lstm(c: Configs):
    """Build a plain single-layer LSTM as the baseline alternative."""
    size = c.d_model
    return LSTM(size, size, 1)
64
65
def main():
    """Set up and run the character-level HyperLSTM language-model experiment."""
    # Create the experiment
    experiment.create(name="hyper_lstm", comment='')

    # Create the configurations
    conf = Configs()

    # A dictionary of configurations to override
    overrides = {
        'tokenizer': 'character',
        'text': 'tiny_shakespeare',
        'optimizer.learning_rate': 2.5e-4,
        'optimizer.optimizer': 'Adam',
        'prompt': 'It is',
        'prompt_separator': '',

        'rnn_model': 'hyper_lstm',

        'train_loader': 'shuffled_train_loader',
        'valid_loader': 'shuffled_valid_loader',

        'seq_len': 512,
        'epochs': 128,
        'batch_size': 2,
        'inner_iterations': 25,
    }

    # Load the configurations with the overrides applied
    experiment.configs(conf, overrides)

    # Register models for checkpoint saving and loading
    experiment.add_pytorch_models(get_modules(conf))

    # Start the experiment and run the training loop
    with experiment.start():
        conf.run()
98
99
# Run the experiment when executed as a script
if __name__ == '__main__':
    main()