门控线性单元和变体

这可以训练一个简单的变压器模型进行自动回归。我们为位置前馈网络尝试不同的变体。

这是一个不使用labml.configs 模块的更简单的实现。我们决定编写一个更简单的实现,让不熟悉的读者更容易使用。

Open In Colab

19import dataclasses
20
21import torch
22from labml_helpers.module import Module
23from torch import nn
24from torch.utils.data import Dataset, DataLoader
25
26from labml import experiment, lab, tracker, monit, logger
27from labml.logger import Text
28from labml.utils.download import download_file
29from labml_nn.experiments.nlp_autoregression import transpose_batch
30from labml_nn.optimizers.noam import Noam
31from labml_nn.transformers import Encoder, MultiHeadAttention
32from labml_nn.transformers.feed_forward import FeedForward
33from labml_nn.transformers.models import EmbeddingsWithPositionalEncoding, TransformerLayer
34from labml_nn.transformers.utils import subsequent_mask

自动回归模型

37class AutoregressiveModel(Module):
42    def __init__(self, src_embed: Module, encoder: Encoder, generator: Module):
43        super().__init__()

令牌嵌入模块

45        self.src_embed = src_embed

基于变压器的编码器

47        self.encoder = encoder

下一代币生成层;这给出了下一个令牌的日志

50        self.generator = generator

这将在第一次调用时初始化

52        self.src_mask = None
54    def forward(self, src: torch.Tensor):

创建后续掩码,以便变压器只能关注过去的令牌。

56        if self.src_mask is None or self.src_mask.size(0) != len(src):
57            self.src_mask = subsequent_mask(len(src)).to(src.device)

嵌入令牌 (src ) 并通过变压器运行它

59        res = self.encoder(self.src_embed(src), self.src_mask)

生成下一个令牌的日志

61        return self.generator(res)

配置

64@dataclasses.dataclass
65class Configs:
69    d_model: int = 512
70    seq_len: int = 128
71    batch_size: int = 32
72    n_layers: int = 6
73    n_heads: int = 8
74    dropout: float = 0.1
75    d_ff: int = 2048
76    glu_variant: str = 'GLU'
77    epochs: int = 5
78    grad_norm_clip: float = 0.5

小莎士比亚数据集

81class TinyShakespeareDataset(Dataset):
86    def __init__(self, seq_len: int):

文本文件的位置

88        path = lab.get_data_path() / 'tiny_shakespeare.txt'

下载该文件

90        download_file('https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt', path)

读取下载的文件

92        with open(str(path), 'r') as f:
93            text = f.read()

提取字符

96        chars = list(set(text))

字符到 id(整数)映射

98        self.stoi = {c: i for i, c in enumerate(chars)}

角色映射的 ID

100        self.itos = {i: c for i, c in enumerate(chars)}

训练样本的长度

102        self.seq_len = seq_len

以 id 张量形式显示的数据

104        self.data = self.text_to_i(text)

将文本转换为 id 张量

106    def text_to_i(self, text: str):
110        return torch.tensor([self.stoi[c] for c in text], dtype=torch.long)

数据集中的样本数。

这将读取单个纪元中的数据集seq_len 时间。

112    def __len__(self):
118        return len(self.data) - self.seq_len - 1

返回样品

120    def __getitem__(self, idx):
124        return self.data[idx:idx + self.seq_len], self.data[idx + 1:idx + self.seq_len + 1]

训练师

127class Trainer:
132    def __init__(self, configs: Configs):

拿到设备

134        self.device = torch.device('cpu')
135        if torch.cuda.is_available():
136            self.device = torch.device('cuda:0')

初始化数据集

138        self.dataset = TinyShakespeareDataset(configs.seq_len)

初始化数据加载器

140        self.dataloader = DataLoader(self.dataset,
141                                     batch_size=configs.batch_size,
142                                     collate_fn=transpose_batch,
143                                     shuffle=True)

带门控线性单元的 FFN

147        if configs.glu_variant == 'GLU':
148            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Sigmoid(), True, False, False, False)

带双线性隐藏层的 FFN

151        elif configs.glu_variant == 'Bilinear':
152            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Identity(), True, False, False, False)

带有 ReLU 门的 FFN

155        elif configs.glu_variant == 'ReGLU':
156            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU(), True, False, False, False)

带有 GELU 门的 FFN

159        elif configs.glu_variant == 'GEGLU':
160            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU(), True, False, False, False)

FFN 有 Swish gate 在哪里

164        elif configs.glu_variant == 'SwiGLU':
165            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.SiLU(), True, False, False, False)

激活 ReLU 的 FFN

168        elif configs.glu_variant == 'ReLU':
169            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU())

激活 ReLU 的 FFN

172        elif configs.glu_variant == 'GELU':
173            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU())
174        else:
175            raise ValueError(f'Unknown variant {configs.glu_variant}')

不同字符的数量

178        n_chars = len(self.dataset.stoi)
181        mha = MultiHeadAttention(configs.n_heads, configs.d_model, configs.dropout)

初始化变压器模块

183        transformer_layer = TransformerLayer(d_model=configs.d_model, self_attn=mha, src_attn=None,
184                                             feed_forward=ffn, dropout_prob=configs.dropout)
使用@@

嵌入层(具有固定位置编码)变压器编码器和线性层来初始化模型以生成对数。

190        self.model = AutoregressiveModel(EmbeddingsWithPositionalEncoding(configs.d_model, n_chars),
191                                         Encoder(transformer_layer, configs.n_layers),
192                                         nn.Linear(configs.d_model, n_chars))

将模型移至当前设备

195        self.model.to(self.device)

初始化 Noam 优化器

198        self.optimizer = Noam(self.model.parameters(), lr=1.0, warmup=2_000, d_model=configs.d_model)

交叉熵损失

201        self.loss_func = nn.CrossEntropyLoss()

训练周期的数量;请注意,我们的数据集定义在单个纪元中重复数据seq_len 时间

204        self.epochs = configs.epochs

渐变剪切规范

206        self.grad_norm_clip = configs.grad_norm_clip

设置跟踪器配置

209        tracker.set_scalar("loss.*", True)

采样功能可在训练时定期生成样本

211    def sample(self):

启动提示

217        prompt = 'It is'

收集输出以进行打印

219        log = [(prompt, Text.subtle)]

样本 25 个代币

221        for i in monit.iterate('Sample', 25):

将提示符号化

223            data = self.dataset.text_to_i(prompt).unsqueeze(-1)
224            data = data.to(self.device)

获取模型输出

226            output = self.model(data)

获取模型预测(贪婪)

228            output = output.argmax(dim=-1).squeeze()

将预测添加到提示符中

230            prompt += self.dataset.itos[output[-1].item()]

添加日志记录的预测

232            log += [(self.dataset.itos[output[-1].item()], Text.value)]

打印采样输出

235        logger.log(log)

训练模型

237    def train(self):

循环使用给定数量的周期

243        for _ in monit.loop(self.epochs):

遍历迷你批次

245            for i, batch in monit.enum('Train', self.dataloader):

将数据移动到设备

247                data, target = batch[0].to(self.device), batch[1].to(self.device)

将跟踪器步长设置为训练的字符数

250                tracker.add_global_step(data.shape[0] * data.shape[1])

将模型状态设置为训练

253                self.model.train()

评估模型

255                output = self.model(data)

计算损失

258                loss = self.loss_func(output.view(-1, output.shape[-1]), target.view(-1))

记录损失

260                tracker.add("loss.train", loss)

计算梯度

263                loss.backward()

剪辑渐变

265                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)

采取优化器步骤

267                self.optimizer.step()

记录模型参数和梯度

269                if (i + 1) % 100 == 0:
270                    tracker.add('model', self.model)

清除渐变

272                self.optimizer.zero_grad()

生成样本

275                if (i + 1) % 100 == 0:
276                    self.model.eval()
277                    with torch.no_grad():
278                        self.sample()

保存跟踪的指标

281                if (i + 1) % 10 == 0:
282                    tracker.save()

保存模型

285            experiment.save_checkpoint()
288def main():

创建实验

290    experiment.create(name="glu_variants")

创建配置

292    configs = Configs()

装载配置

294    experiment.configs(dataclasses.asdict(configs))

创建训练器

297    trainer = Trainer(configs)

设置用于训练和加载的模型

299    experiment.add_pytorch_models({'model': trainer.model})

开始实验

302    with experiment.start():

训练模型

304        trainer.train()
305
306
307if __name__ == '__main__':
308    main()