在 Cora 数据集上训练图注意力网络 (GAT)

11from typing import Dict
12
13import numpy as np
14import torch
15from torch import nn
16
17from labml import lab, monit, tracker, experiment
18from labml.configs import BaseConfigs, option, calculate
19from labml.utils import download
20from labml_helpers.device import DeviceConfigs
21from labml_helpers.module import Module
22from labml_nn.graphs.gat import GraphAttentionLayer
23from labml_nn.optimizers.configs import OptimizerConfigs

Cora 数据集

Cora 数据集是研究论文的数据集。对于每篇论文,我们都得到一个二进制特征向量,该向量表示单词的存在。每篇论文分为 7 个类别之一。该数据集还具有引文网络。

论文是图的节点,边缘是引文。

任务是使用特征向量和引文网络作为输入,将节点分类为 7 类。

26class CoraDataset:

每个节点的标签

41    labels: torch.Tensor

一组类名和一个唯一的整数索引

43    classes: Dict[str, int]

所有节点的特征向量

45    features: torch.Tensor

包含边信息的邻接矩阵。adj_mat[i][j] True 如果存在从i 到的边缘j

48    adj_mat: torch.Tensor

下载数据集

50    @staticmethod
51    def _download():
55        if not (lab.get_data_path() / 'cora').exists():
56            download.download_file('https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz',
57                                   lab.get_data_path() / 'cora.tgz')
58            download.extract_tar(lab.get_data_path() / 'cora.tgz', lab.get_data_path())

加载数据集

60    def __init__(self, include_edges: bool = True):

是否包括边缘。这是测试如果我们忽略引文网络会损失多少准确性。

67        self.include_edges = include_edges

下载数据集

70        self._download()

阅读纸张 ID、特征矢量和标签

73        with monit.section('Read content file'):
74            content = np.genfromtxt(str(lab.get_data_path() / 'cora/cora.content'), dtype=np.dtype(str))

加载引文,这是一个整数对的列表。

76        with monit.section('Read citations file'):
77            citations = np.genfromtxt(str(lab.get_data_path() / 'cora/cora.cites'), dtype=np.int32)

获取特征向量

80        features = torch.tensor(np.array(content[:, 1:-1], dtype=np.float32))

归一化特征向量

82        self.features = features / features.sum(dim=1, keepdim=True)

获取类名并为每个类分配一个唯一的整数

85        self.classes = {s: i for i, s in enumerate(set(content[:, -1]))}

获取这些整数的标签

87        self.labels = torch.tensor([self.classes[i] for i in content[:, -1]], dtype=torch.long)

获取纸质证件

90        paper_ids = np.array(content[:, 0], dtype=np.int32)

纸张 ID 到索引的映射

92        ids_to_idx = {id_: i for i, id_ in enumerate(paper_ids)}

空邻接矩阵-恒等矩阵

95        self.adj_mat = torch.eye(len(self.labels), dtype=torch.bool)

在邻接矩阵中标记引用

98        if self.include_edges:
99            for e in citations:

一对纸质索引

101                e1, e2 = ids_to_idx[e[0]], ids_to_idx[e[1]]

我们构建一个对称的图形,如果纸张引用了纸张,我们会在其中放置一个从到的徽章以及从

105                self.adj_mat[e1][e2] = True
106                self.adj_mat[e2][e1] = True

Graph 注意力网络 (GAT)

这个图形关注网络有两个图形关注层

109class GAT(Module):
  • in_features 是每个节点的要素数
  • n_hidden 是第一个图形关注层中的要素数
  • n_classes 是类的数量
  • n_heads 是图表关注层中的头部数量
  • dropout 是辍学概率
116    def __init__(self, in_features: int, n_hidden: int, n_classes: int, n_heads: int, dropout: float):
124        super().__init__()

我们连接头部的第一个图形注意层

127        self.layer1 = GraphAttentionLayer(in_features, n_hidden, n_heads, is_concat=True, dropout=dropout)

第一个图形关注层之后的激活功能

129        self.activation = nn.ELU()

最后一张图关注层,我们平均头部

131        self.output = GraphAttentionLayer(n_hidden, n_classes, 1, is_concat=False, dropout=dropout)

辍学

133        self.dropout = nn.Dropout(dropout)
  • x 是形状的特征向量[n_nodes, in_features]
  • adj_mat 是形式的邻接矩阵[n_nodes, n_nodes, n_heads][n_nodes, n_nodes, 1]
135    def forward(self, x: torch.Tensor, adj_mat: torch.Tensor):

将丢失应用于输入

142        x = self.dropout(x)

第一个图形关注层

144        x = self.layer1(x, adj_mat)

激活功能

146        x = self.activation(x)

辍学

148        x = self.dropout(x)

logits 的输出层(未激活)

150        return self.output(x, adj_mat)

计算精度的简单函数

153def accuracy(output: torch.Tensor, labels: torch.Tensor):
157    return output.argmax(dim=-1).eq(labels).sum().item() / len(labels)

配置

160class Configs(BaseConfigs):

型号

166    model: GAT

要训练的节点数

168    training_samples: int = 500

输入中每个节点的要素数

170    in_features: int

第一个图形关注图层中的要素数

172    n_hidden: int = 64

头数

174    n_heads: int = 8

用于分类的类数

176    n_classes: int

辍学概率

178    dropout: float = 0.6

是否包括引文网络

180    include_edges: bool = True

数据集

182    dataset: CoraDataset

训练迭代次数

184    epochs: int = 1_000

亏损函数

186    loss_func = nn.CrossEntropyLoss()

用于训练的设备

这将为设备创建配置,以便我们可以通过传递配置值来更改设备

191    device: torch.device = DeviceConfigs()

优化器

193    optimizer: torch.optim.Adam

训练循环

由于数据集很小,我们进行全批量训练。如果要进行采样和训练,我们将不得不为每个训练步骤对一组节点以及跨越这些选定节点的边进行采样。

195    def run(self):

将特征向量移动到设备

205        features = self.dataset.features.to(self.device)

将标签移到设备上

207        labels = self.dataset.labels.to(self.device)

将邻接矩阵移至设备

209        edges_adj = self.dataset.adj_mat.to(self.device)

为头部添加一个空的第三个维度

211        edges_adj = edges_adj.unsqueeze(-1)

随机索引

214        idx_rand = torch.randperm(len(labels))

训练节点

216        idx_train = idx_rand[:self.training_samples]

用于验证的节点

218        idx_valid = idx_rand[self.training_samples:]

训练循环

221        for epoch in monit.loop(self.epochs):

将模型设置为训练模式

223            self.model.train()

将所有渐变设为零

225            self.optimizer.zero_grad()

评估模型

227            output = self.model(features, edges_adj)

获得训练节点的损失

229            loss = self.loss_func(output[idx_train], labels[idx_train])

计算梯度

231            loss.backward()

采取优化步骤

233            self.optimizer.step()

记录损失

235            tracker.add('loss.train', loss)

记录准确性

237            tracker.add('accuracy.train', accuracy(output[idx_train], labels[idx_train]))

将模式设置为评估模式以进行验证

240            self.model.eval()

无需计算梯度

243            with torch.no_grad():

再次评估模型

245                output = self.model(features, edges_adj)

计算验证节点的损失

247                loss = self.loss_func(output[idx_valid], labels[idx_valid])

记录损失

249                tracker.add('loss.valid', loss)

记录准确性

251                tracker.add('accuracy.valid', accuracy(output[idx_valid], labels[idx_valid]))

保存日志

254            tracker.save()

创建 Cora 数据集

257@option(Configs.dataset)
258def cora_dataset(c: Configs):
262    return CoraDataset(c.include_edges)

获取班级数

266calculate(Configs.n_classes, lambda c: len(c.dataset.classes))

输入中的要素数量

268calculate(Configs.in_features, lambda c: c.dataset.features.shape[1])

创建 GAT 模型

271@option(Configs.model)
272def gat_model(c: Configs):
276    return GAT(c.in_features, c.n_hidden, c.n_classes, c.n_heads, c.dropout).to(c.device)

创建可配置的优化器

279@option(Configs.optimizer)
280def _optimizer(c: Configs):
284    opt_conf = OptimizerConfigs()
285    opt_conf.parameters = c.model.parameters()
286    return opt_conf
289def main():

创建配置

291    conf = Configs()

创建实验

293    experiment.create(name='gat')

计算配置。

295    experiment.configs(conf, {

Adam 优化器

297        'optimizer.optimizer': 'Adam',
298        'optimizer.learning_rate': 5e-3,
299        'optimizer.weight_decay': 5e-4,
300    })

开始观看实验

303    with experiment.start():

运行训练

305        conf.run()

309if __name__ == '__main__':
310    main()