1import torch
2import torch.nn as nn
3from labml import experiment
4from labml.configs import option
5from labml.utils.pytorch import get_modules
6from labml_helpers.module import Module
7
8from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
9from labml_nn.hypernetworks.hyper_lstm import HyperLSTM
10from labml_nn.lstm import LSTM
13class AutoregressiveModel(Module):
18 def __init__(self, n_vocab: int, d_model: int, rnn_model: Module):
19 super().__init__()
トークン埋め込みモジュール
21 self.src_embed = nn.Embedding(n_vocab, d_model)
22 self.lstm = rnn_model
23 self.generator = nn.Linear(d_model, n_vocab)
25 def forward(self, x: torch.Tensor):
26 x = self.src_embed(x)
トークン (src
) を埋め込み、トランスフォーマーに通します
28 res, state = self.lstm(x)
次のトークンのロジットを生成
30 return self.generator(res), state
33class Configs(NLPAutoRegressionConfigs):
40 model: AutoregressiveModel
41 rnn_model: Module
42
43 d_model: int = 512
44 n_rhn: int = 16
45 n_z: int = 16
自己回帰モデルを初期化
48@option(Configs.model)
49def autoregressive_model(c: Configs):
53 m = AutoregressiveModel(c.n_tokens, c.d_model, c.rnn_model)
54 return m.to(c.device)
57@option(Configs.rnn_model)
58def hyper_lstm(c: Configs):
59 return HyperLSTM(c.d_model, c.d_model, c.n_rhn, c.n_z, 1)
60
61
62@option(Configs.rnn_model)
63def lstm(c: Configs):
64 return LSTM(c.d_model, c.d_model, 1)
65
66
67def main():
実験を作成
69 experiment.create(name="hyper_lstm", comment='')
コンフィグの作成
71 conf = Configs()
構成をロード
73 experiment.configs(conf,
オーバーライドする設定の辞書
75 {'tokenizer': 'character',
76 'text': 'tiny_shakespeare',
77 'optimizer.learning_rate': 2.5e-4,
78 'optimizer.optimizer': 'Adam',
79 'prompt': 'It is',
80 'prompt_separator': '',
81
82 'rnn_model': 'hyper_lstm',
83
84 'train_loader': 'shuffled_train_loader',
85 'valid_loader': 'shuffled_valid_loader',
86
87 'seq_len': 512,
88 'epochs': 128,
89 'batch_size': 2,
90 'inner_iterations': 25})
保存および読み込み用のモデルを設定する
93 experiment.add_pytorch_models(get_modules(conf))
実験を始める
96 with experiment.start():
TrainValidConfigs.run
98 conf.run()
99
100
101if __name__ == '__main__':
102 main()