11from typing import Dict
12
13import numpy as np
14import torch
15from torch import nn
16
17from labml import lab, monit, tracker, experiment
18from labml.configs import BaseConfigs, option, calculate
19from labml.utils import download
20from labml_helpers.device import DeviceConfigs
21from labml_helpers.module import Module
22from labml_nn.graphs.gat import GraphAttentionLayer
23from labml_nn.optimizers.configs import OptimizerConfigs
Cora 数据集是研究论文的数据集。对于每篇论文,我们都得到一个二进制特征向量,该向量表示单词的存在。每篇论文分为 7 个类别之一。该数据集还具有引文网络。
论文是图的节点,边缘是引文。
任务是使用特征向量和引文网络作为输入,将节点分类为 7 类。
26class CoraDataset:
每个节点的标签
41 labels: torch.Tensor
一组类名和一个唯一的整数索引
43 classes: Dict[str, int]
所有节点的特征向量
45 features: torch.Tensor
包含边信息的邻接矩阵。adj_mat[i][j]
True
如果存在从i
到的边缘j
。
48 adj_mat: torch.Tensor
下载数据集
50 @staticmethod
51 def _download():
55 if not (lab.get_data_path() / 'cora').exists():
56 download.download_file('https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz',
57 lab.get_data_path() / 'cora.tgz')
58 download.extract_tar(lab.get_data_path() / 'cora.tgz', lab.get_data_path())
加载数据集
60 def __init__(self, include_edges: bool = True):
是否包括边缘。这是测试如果我们忽略引文网络会损失多少准确性。
67 self.include_edges = include_edges
下载数据集
70 self._download()
阅读纸张 ID、特征矢量和标签
73 with monit.section('Read content file'):
74 content = np.genfromtxt(str(lab.get_data_path() / 'cora/cora.content'), dtype=np.dtype(str))
加载引文,这是一个整数对的列表。
76 with monit.section('Read citations file'):
77 citations = np.genfromtxt(str(lab.get_data_path() / 'cora/cora.cites'), dtype=np.int32)
获取特征向量
80 features = torch.tensor(np.array(content[:, 1:-1], dtype=np.float32))
归一化特征向量
82 self.features = features / features.sum(dim=1, keepdim=True)
获取类名并为每个类分配一个唯一的整数
85 self.classes = {s: i for i, s in enumerate(set(content[:, -1]))}
获取这些整数的标签
87 self.labels = torch.tensor([self.classes[i] for i in content[:, -1]], dtype=torch.long)
获取纸质证件
90 paper_ids = np.array(content[:, 0], dtype=np.int32)
纸张 ID 到索引的映射
92 ids_to_idx = {id_: i for i, id_ in enumerate(paper_ids)}
空邻接矩阵-恒等矩阵
95 self.adj_mat = torch.eye(len(self.labels), dtype=torch.bool)
在邻接矩阵中标记引用
98 if self.include_edges:
99 for e in citations:
一对纸质索引
101 e1, e2 = ids_to_idx[e[0]], ids_to_idx[e[1]]
我们构建一个对称的图形,如果纸张引用了纸张,我们会在其中放置一个从到的徽章以及从到。
105 self.adj_mat[e1][e2] = True
106 self.adj_mat[e2][e1] = True
in_features
是每个节点的要素数n_hidden
是第一个图形关注层中的要素数n_classes
是类的数量n_heads
是图表关注层中的头部数量dropout
是辍学概率116 def __init__(self, in_features: int, n_hidden: int, n_classes: int, n_heads: int, dropout: float):
124 super().__init__()
我们连接头部的第一个图形注意层
127 self.layer1 = GraphAttentionLayer(in_features, n_hidden, n_heads, is_concat=True, dropout=dropout)
第一个图形关注层之后的激活功能
129 self.activation = nn.ELU()
最后一张图关注层,我们平均头部
131 self.output = GraphAttentionLayer(n_hidden, n_classes, 1, is_concat=False, dropout=dropout)
辍学
133 self.dropout = nn.Dropout(dropout)
x
是形状的特征向量[n_nodes, in_features]
adj_mat
是形式的邻接矩阵[n_nodes, n_nodes, n_heads]
或[n_nodes, n_nodes, 1]
135 def forward(self, x: torch.Tensor, adj_mat: torch.Tensor):
将丢失应用于输入
142 x = self.dropout(x)
第一个图形关注层
144 x = self.layer1(x, adj_mat)
激活功能
146 x = self.activation(x)
辍学
148 x = self.dropout(x)
logits 的输出层(未激活)
150 return self.output(x, adj_mat)
计算精度的简单函数
153def accuracy(output: torch.Tensor, labels: torch.Tensor):
157 return output.argmax(dim=-1).eq(labels).sum().item() / len(labels)
160class Configs(BaseConfigs):
型号
166 model: GAT
要训练的节点数
168 training_samples: int = 500
输入中每个节点的要素数
170 in_features: int
第一个图形关注图层中的要素数
172 n_hidden: int = 64
头数
174 n_heads: int = 8
用于分类的类数
176 n_classes: int
辍学概率
178 dropout: float = 0.6
是否包括引文网络
180 include_edges: bool = True
数据集
182 dataset: CoraDataset
训练迭代次数
184 epochs: int = 1_000
亏损函数
186 loss_func = nn.CrossEntropyLoss()
191 device: torch.device = DeviceConfigs()
优化器
193 optimizer: torch.optim.Adam
195 def run(self):
将特征向量移动到设备
205 features = self.dataset.features.to(self.device)
将标签移到设备上
207 labels = self.dataset.labels.to(self.device)
将邻接矩阵移至设备
209 edges_adj = self.dataset.adj_mat.to(self.device)
为头部添加一个空的第三个维度
211 edges_adj = edges_adj.unsqueeze(-1)
随机索引
214 idx_rand = torch.randperm(len(labels))
训练节点
216 idx_train = idx_rand[:self.training_samples]
用于验证的节点
218 idx_valid = idx_rand[self.training_samples:]
训练循环
221 for epoch in monit.loop(self.epochs):
将模型设置为训练模式
223 self.model.train()
将所有渐变设为零
225 self.optimizer.zero_grad()
评估模型
227 output = self.model(features, edges_adj)
获得训练节点的损失
229 loss = self.loss_func(output[idx_train], labels[idx_train])
计算梯度
231 loss.backward()
采取优化步骤
233 self.optimizer.step()
记录损失
235 tracker.add('loss.train', loss)
记录准确性
237 tracker.add('accuracy.train', accuracy(output[idx_train], labels[idx_train]))
将模式设置为评估模式以进行验证
240 self.model.eval()
无需计算梯度
243 with torch.no_grad():
再次评估模型
245 output = self.model(features, edges_adj)
计算验证节点的损失
247 loss = self.loss_func(output[idx_valid], labels[idx_valid])
记录损失
249 tracker.add('loss.valid', loss)
记录准确性
251 tracker.add('accuracy.valid', accuracy(output[idx_valid], labels[idx_valid]))
保存日志
254 tracker.save()
创建 Cora 数据集
257@option(Configs.dataset)
258def cora_dataset(c: Configs):
262 return CoraDataset(c.include_edges)
获取班级数
266calculate(Configs.n_classes, lambda c: len(c.dataset.classes))
输入中的要素数量
268calculate(Configs.in_features, lambda c: c.dataset.features.shape[1])
创建 GAT 模型
271@option(Configs.model)
272def gat_model(c: Configs):
276 return GAT(c.in_features, c.n_hidden, c.n_classes, c.n_heads, c.dropout).to(c.device)
创建可配置的优化器
279@option(Configs.optimizer)
280def _optimizer(c: Configs):
284 opt_conf = OptimizerConfigs()
285 opt_conf.parameters = c.model.parameters()
286 return opt_conf
289def main():
创建配置
291 conf = Configs()
创建实验
293 experiment.create(name='gat')
计算配置。
295 experiment.configs(conf, {
Adam 优化器
297 'optimizer.optimizer': 'Adam',
298 'optimizer.learning_rate': 5e-3,
299 'optimizer.weight_decay': 5e-4,
300 })
开始观看实验
303 with experiment.start():
运行训练
305 conf.run()
309if __name__ == '__main__':
310 main()