1import torch
2import torch.nn as nn
3from labml import experiment
4from labml.configs import option
5from labml.utils.pytorch import get_modules
6from labml_helpers.module import Module
7
8from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
9from labml_nn.hypernetworks.hyper_lstm import HyperLSTM
10from labml_nn.lstm import LSTM
class AutoregressiveModel(Module):
    """
    Auto-regressive language model: token embedding, a recurrent model
    (`LSTM` or `HyperLSTM`), and a linear projection to vocabulary logits.
    """

    def __init__(self, n_vocab: int, d_model: int, rnn_model: Module):
        """
        * `n_vocab`: vocabulary size
        * `d_model`: embedding / hidden size
        * `rnn_model`: the recurrent module to run over the embeddings
        """
        super().__init__()
        # Token embedding module
        self.src_embed = nn.Embedding(n_vocab, d_model)
        # The recurrent model (injected, so LSTM and HyperLSTM are interchangeable)
        self.lstm = rnn_model
        # Projects hidden states back to vocabulary logits
        self.generator = nn.Linear(d_model, n_vocab)

    def forward(self, x: torch.Tensor):
        # Embed the tokens
        embeddings = self.src_embed(x)
        # Run the embeddings through the recurrent model
        output, state = self.lstm(embeddings)
        # Logits for the next token, plus the recurrent state
        return self.generator(output), state
class Configs(NLPAutoRegressionConfigs):
    """
    Configurations for the auto-regressive HyperLSTM/LSTM experiment.

    Inherits the NLP auto-regression setup (dataset, tokenizer, training
    loop) and adds the model and its recurrent cell, which are populated
    by the `@option` functions below.
    """

    # The full auto-regressive model (embedding + RNN + generator)
    model: AutoregressiveModel
    # The recurrent module plugged into the model ('hyper_lstm' or 'lstm')
    rnn_model: Module

    # Embedding / hidden size
    d_model: int = 512
    # Hyper-network sizes passed to `HyperLSTM` — presumably the hyper-RNN
    # hidden size and the size of the `z` embeddings; confirm against
    # `labml_nn.hypernetworks.hyper_lstm.HyperLSTM`.
    n_rhn: int = 16
    n_z: int = 16
初始化自回归模型
@option(Configs.model)
def autoregressive_model(c: Configs):
    """Initialize the auto-regressive model and move it to the configured device."""
    return AutoregressiveModel(c.n_tokens, c.d_model, c.rnn_model).to(c.device)
@option(Configs.rnn_model)
def hyper_lstm(c: Configs):
    """Create a single-layer HyperLSTM to use as the recurrent model."""
    n_layers = 1
    return HyperLSTM(c.d_model, c.d_model, c.n_rhn, c.n_z, n_layers)
60
61
@option(Configs.rnn_model)
def lstm(c: Configs):
    """Create a plain single-layer LSTM to use as the recurrent model."""
    n_layers = 1
    return LSTM(c.d_model, c.d_model, n_layers)
65
66
def main():
    """Set up and run the HyperLSTM character-level language model experiment."""
    # Create the experiment
    experiment.create(name="hyper_lstm", comment='')
    # Create the configs
    conf = Configs()
    # Configuration overrides: character-level tiny Shakespeare with Adam
    overrides = {
        'tokenizer': 'character',
        'text': 'tiny_shakespeare',
        'optimizer.learning_rate': 2.5e-4,
        'optimizer.optimizer': 'Adam',
        'prompt': 'It is',
        'prompt_separator': '',

        # Select the HyperLSTM recurrent model (swap to 'lstm' for the baseline)
        'rnn_model': 'hyper_lstm',

        'train_loader': 'shuffled_train_loader',
        'valid_loader': 'shuffled_valid_loader',

        'seq_len': 512,
        'epochs': 128,
        'batch_size': 2,
        'inner_iterations': 25,
    }
    # Load the configurations
    experiment.configs(conf, overrides)
    # Register the models for saving and loading
    experiment.add_pytorch_models(get_modules(conf))
    # Start the experiment and run the training loop
    with experiment.start():
        # `TrainValidConfigs.run`
        conf.run()
99
100
# Run the experiment when this file is executed as a script.
if __name__ == '__main__':
    main()