这可以训练一个简单的变压器模型进行自动回归。我们为位置前馈网络尝试不同的变体。
这是一个不使用labml.configs
模块的更简单的实现。我们决定编写一个更简单的实现,让不熟悉的读者更容易使用。
19import dataclasses
20
21import torch
22from labml_helpers.module import Module
23from torch import nn
24from torch.utils.data import Dataset, DataLoader
25
26from labml import experiment, lab, tracker, monit, logger
27from labml.logger import Text
28from labml.utils.download import download_file
29from labml_nn.experiments.nlp_autoregression import transpose_batch
30from labml_nn.optimizers.noam import Noam
31from labml_nn.transformers import Encoder, MultiHeadAttention
32from labml_nn.transformers.feed_forward import FeedForward
33from labml_nn.transformers.models import EmbeddingsWithPositionalEncoding, TransformerLayer
34from labml_nn.transformers.utils import subsequent_mask
37class AutoregressiveModel(Module):
42 def __init__(self, src_embed: Module, encoder: Encoder, generator: Module):
43 super().__init__()
令牌嵌入模块
45 self.src_embed = src_embed
基于变压器的编码器
47 self.encoder = encoder
下一代币生成层;这给出了下一个令牌的日志
50 self.generator = generator
这将在第一次调用时初始化
52 self.src_mask = None
54 def forward(self, src: torch.Tensor):
创建后续掩码,以便变压器只能关注过去的令牌。
56 if self.src_mask is None or self.src_mask.size(0) != len(src):
57 self.src_mask = subsequent_mask(len(src)).to(src.device)
嵌入令牌 (src
) 并通过变压器运行它
59 res = self.encoder(self.src_embed(src), self.src_mask)
生成下一个令牌的日志
61 return self.generator(res)
64@dataclasses.dataclass
65class Configs:
69 d_model: int = 512
70 seq_len: int = 128
71 batch_size: int = 32
72 n_layers: int = 6
73 n_heads: int = 8
74 dropout: float = 0.1
75 d_ff: int = 2048
76 glu_variant: str = 'GLU'
77 epochs: int = 5
78 grad_norm_clip: float = 0.5
81class TinyShakespeareDataset(Dataset):
86 def __init__(self, seq_len: int):
文本文件的位置
88 path = lab.get_data_path() / 'tiny_shakespeare.txt'
下载该文件
90 download_file('https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt', path)
读取下载的文件
92 with open(str(path), 'r') as f:
93 text = f.read()
提取字符
96 chars = list(set(text))
字符到 id(整数)映射
98 self.stoi = {c: i for i, c in enumerate(chars)}
角色映射的 ID
100 self.itos = {i: c for i, c in enumerate(chars)}
训练样本的长度
102 self.seq_len = seq_len
以 id 张量形式显示的数据
104 self.data = self.text_to_i(text)
将文本转换为 id 张量
106 def text_to_i(self, text: str):
110 return torch.tensor([self.stoi[c] for c in text], dtype=torch.long)
112 def __len__(self):
118 return len(self.data) - self.seq_len - 1
返回样品
120 def __getitem__(self, idx):
124 return self.data[idx:idx + self.seq_len], self.data[idx + 1:idx + self.seq_len + 1]
127class Trainer:
132 def __init__(self, configs: Configs):
拿到设备
134 self.device = torch.device('cpu')
135 if torch.cuda.is_available():
136 self.device = torch.device('cuda:0')
初始化数据集
138 self.dataset = TinyShakespeareDataset(configs.seq_len)
初始化数据加载器
140 self.dataloader = DataLoader(self.dataset,
141 batch_size=configs.batch_size,
142 collate_fn=transpose_batch,
143 shuffle=True)
带门控线性单元的 FFN
147 if configs.glu_variant == 'GLU':
148 ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Sigmoid(), True, False, False, False)
带双线性隐藏层的 FFN
151 elif configs.glu_variant == 'Bilinear':
152 ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Identity(), True, False, False, False)
带有 ReLU 门的 FFN
155 elif configs.glu_variant == 'ReGLU':
156 ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU(), True, False, False, False)
带有 GELU 门的 FFN
159 elif configs.glu_variant == 'GEGLU':
160 ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU(), True, False, False, False)
FFN 有 Swish gate 在哪里
164 elif configs.glu_variant == 'SwiGLU':
165 ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.SiLU(), True, False, False, False)
激活 ReLU 的 FFN
168 elif configs.glu_variant == 'ReLU':
169 ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU())
激活 ReLU 的 FFN
172 elif configs.glu_variant == 'GELU':
173 ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU())
174 else:
175 raise ValueError(f'Unknown variant {configs.glu_variant}')
不同字符的数量
178 n_chars = len(self.dataset.stoi)
183 transformer_layer = TransformerLayer(d_model=configs.d_model, self_attn=mha, src_attn=None,
184 feed_forward=ffn, dropout_prob=configs.dropout)
190 self.model = AutoregressiveModel(EmbeddingsWithPositionalEncoding(configs.d_model, n_chars),
191 Encoder(transformer_layer, configs.n_layers),
192 nn.Linear(configs.d_model, n_chars))
将模型移至当前设备
195 self.model.to(self.device)
198 self.optimizer = Noam(self.model.parameters(), lr=1.0, warmup=2_000, d_model=configs.d_model)
交叉熵损失
201 self.loss_func = nn.CrossEntropyLoss()
训练周期的数量;请注意,我们的数据集定义在单个纪元中重复数据seq_len
时间
204 self.epochs = configs.epochs
渐变剪切规范
206 self.grad_norm_clip = configs.grad_norm_clip
设置跟踪器配置
209 tracker.set_scalar("loss.*", True)
211 def sample(self):
启动提示
217 prompt = 'It is'
收集输出以进行打印
219 log = [(prompt, Text.subtle)]
样本 25 个代币
221 for i in monit.iterate('Sample', 25):
将提示符号化
223 data = self.dataset.text_to_i(prompt).unsqueeze(-1)
224 data = data.to(self.device)
获取模型输出
226 output = self.model(data)
获取模型预测(贪婪)
228 output = output.argmax(dim=-1).squeeze()
将预测添加到提示符中
230 prompt += self.dataset.itos[output[-1].item()]
添加日志记录的预测
232 log += [(self.dataset.itos[output[-1].item()], Text.value)]
打印采样输出
235 logger.log(log)
237 def train(self):
循环使用给定数量的周期
243 for _ in monit.loop(self.epochs):
遍历迷你批次
245 for i, batch in monit.enum('Train', self.dataloader):
将数据移动到设备
247 data, target = batch[0].to(self.device), batch[1].to(self.device)
将跟踪器步长设置为训练的字符数
250 tracker.add_global_step(data.shape[0] * data.shape[1])
将模型状态设置为训练
253 self.model.train()
评估模型
255 output = self.model(data)
计算损失
258 loss = self.loss_func(output.view(-1, output.shape[-1]), target.view(-1))
记录损失
260 tracker.add("loss.train", loss)
计算梯度
263 loss.backward()
剪辑渐变
265 torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
采取优化器步骤
267 self.optimizer.step()
记录模型参数和梯度
269 if (i + 1) % 100 == 0:
270 tracker.add('model', self.model)
清除渐变
272 self.optimizer.zero_grad()
生成样本
275 if (i + 1) % 100 == 0:
276 self.model.eval()
277 with torch.no_grad():
278 self.sample()
保存跟踪的指标
281 if (i + 1) % 10 == 0:
282 tracker.save()
保存模型
285 experiment.save_checkpoint()
288def main():
创建实验
290 experiment.create(name="glu_variants")
创建配置
292 configs = Configs()
装载配置
294 experiment.configs(dataclasses.asdict(configs))
创建训练器
297 trainer = Trainer(configs)
设置用于训练和加载的模型
299 experiment.add_pytorch_models({'model': trainer.model})
开始实验
302 with experiment.start():
训练模型
304 trainer.train()
305
306
307if __name__ == '__main__':
308 main()