from typing import Dict

import numpy as np
import torch
from torch import nn

from labml import lab, monit, tracker, experiment
from labml.configs import BaseConfigs, option, calculate
from labml.utils import download
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.graphs.gat import GraphAttentionLayer
from labml_nn.optimizers.configs import OptimizerConfigs
The Cora dataset is a dataset of research papers. For each paper we are given a binary feature vector that indicates the presence of words. Each paper is classified into one of 7 classes. The dataset also has the citation network.
The papers are the nodes of the graph and the edges are the citations.
The task is to classify the nodes into the 7 classes, using the feature vectors and the citation network as input.
class CoraDataset:
Labels for each node
    labels: torch.Tensor
Set of class names, each assigned a unique integer index
    classes: Dict[str, int]
Feature vectors for all nodes
    features: torch.Tensor
Adjacency matrix with the edge information. adj_mat[i][j] is True if there is an edge from node i to node j.
    adj_mat: torch.Tensor
Download the dataset
    @staticmethod
    def _download():
        if not (lab.get_data_path() / 'cora').exists():
            download.download_file('https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz',
                                   lab.get_data_path() / 'cora.tgz')
            download.extract_tar(lab.get_data_path() / 'cora.tgz', lab.get_data_path())
Load the dataset
    def __init__(self, include_edges: bool = True):
Whether to include edges. This is to test how much accuracy is lost if we ignore the citation network.
        self.include_edges = include_edges
Download dataset
        self._download()
Read the paper ids, feature vectors, and labels
        with monit.section('Read content file'):
            content = np.genfromtxt(str(lab.get_data_path() / 'cora/cora.content'), dtype=np.dtype(str))
Load the citations; it's a list of pairs of integers (paper ids).
        with monit.section('Read citations file'):
            citations = np.genfromtxt(str(lab.get_data_path() / 'cora/cora.cites'), dtype=np.int32)
Get the feature vectors
        features = torch.tensor(np.array(content[:, 1:-1], dtype=np.float32))
Normalize the feature vectors
        self.features = features / features.sum(dim=1, keepdim=True)
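As an aside, here is a toy illustration of this row normalization (the 3-node feature matrix below is made up; it only reuses the torch import at the top):
toy_features = torch.tensor([[1., 0., 1., 0.],
                             [0., 0., 1., 0.],
                             [1., 1., 1., 1.]])
# Divide each row by its sum so that every node's features sum to 1
toy_normalized = toy_features / toy_features.sum(dim=1, keepdim=True)
# tensor([[0.5000, 0.0000, 0.5000, 0.0000],
#         [0.0000, 0.0000, 1.0000, 0.0000],
#         [0.2500, 0.2500, 0.2500, 0.2500]])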
Get the class names and assign a unique integer to each of them
        self.classes = {s: i for i, s in enumerate(set(content[:, -1]))}
Get the labels as those integers
        self.labels = torch.tensor([self.classes[i] for i in content[:, -1]], dtype=torch.long)
Get the paper ids
        paper_ids = np.array(content[:, 0], dtype=np.int32)
Map of paper id to index
        ids_to_idx = {id_: i for i, id_ in enumerate(paper_ids)}
Start the adjacency matrix as an identity matrix; i.e. each node is connected only to itself
        self.adj_mat = torch.eye(len(self.labels), dtype=torch.bool)
Mark the citations in the adjacency matrix
        if self.include_edges:
            for e in citations:
The pair of paper indexes
                e1, e2 = ids_to_idx[e[0]], ids_to_idx[e[1]]
We build a symmetric graph: if paper i references paper j, we place an edge from i to j as well as an edge from j to i.
                self.adj_mat[e1][e2] = True
                self.adj_mat[e2][e1] = True
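With the class complete, here is a minimal sketch of how it can be inspected. Instantiating it downloads the dataset if it is not already present; the shapes in the comments are what the standard Cora files are expected to give and are indicative only.
dataset = CoraDataset(include_edges=True)
print(dataset.features.shape)    # expected: torch.Size([2708, 1433])
print(len(dataset.classes))      # expected: 7
# Every node has a self-loop and the citation edges are symmetric
assert bool(dataset.adj_mat.diagonal().all())
assert bool((dataset.adj_mat == dataset.adj_mat.t()).all())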
class GAT(nn.Module):
in_features is the number of features per node
n_hidden is the number of features in the first graph attention layer
n_classes is the number of classes
n_heads is the number of heads in the graph attention layers
dropout is the dropout probability
    def __init__(self, in_features: int, n_hidden: int, n_classes: int, n_heads: int, dropout: float):
        super().__init__()
First graph attention layer where we concatenate the heads
        self.layer1 = GraphAttentionLayer(in_features, n_hidden, n_heads, is_concat=True, dropout=dropout)
Activation function after first graph attention layer
        self.activation = nn.ELU()
Final graph attention layer where we average the heads
        self.output = GraphAttentionLayer(n_hidden, n_classes, 1, is_concat=False, dropout=dropout)
Dropout
        self.dropout = nn.Dropout(dropout)
x is the feature vectors of shape [n_nodes, in_features]
adj_mat is the adjacency matrix of shape [n_nodes, n_nodes, n_heads] or [n_nodes, n_nodes, 1]
    def forward(self, x: torch.Tensor, adj_mat: torch.Tensor):
Apply dropout to the input
        x = self.dropout(x)
First graph attention layer
        x = self.layer1(x, adj_mat)
Activation function
        x = self.activation(x)
Dropout
        x = self.dropout(x)
Output layer (without activation) for logits
        return self.output(x, adj_mat)
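A quick way to sanity-check the module is to push a random toy graph through it (a sketch with made-up sizes; only self-loop edges are used here):
toy_model = GAT(in_features=16, n_hidden=64, n_classes=7, n_heads=8, dropout=0.6)
toy_x = torch.rand(10, 16)                                 # 10 nodes with 16 features each
toy_adj = torch.eye(10, dtype=torch.bool).unsqueeze(-1)    # [n_nodes, n_nodes, 1]
assert toy_model(toy_x, toy_adj).shape == (10, 7)          # one logit vector per node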
A simple function to calculate the accuracy
def accuracy(output: torch.Tensor, labels: torch.Tensor):
    return output.argmax(dim=-1).eq(labels).sum().item() / len(labels)
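For example (a toy check with made-up logits for three nodes and two classes):
toy_output = torch.tensor([[2.0, 0.1],
                           [0.3, 1.5],
                           [0.9, 0.2]])
toy_labels = torch.tensor([0, 1, 1])
# The first two predictions match the labels, the third one does not
assert accuracy(toy_output, toy_labels) == 2 / 3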
class Configs(BaseConfigs):
Model
    model: GAT
Number of nodes to train on
    training_samples: int = 500
Number of features per node in the input
    in_features: int
Number of features in the first graph attention layer
    n_hidden: int = 64
Number of heads
    n_heads: int = 8
Number of classes for classification
    n_classes: int
Dropout probability
    dropout: float = 0.6
Whether to include the citation network
    include_edges: bool = True
Dataset
    dataset: CoraDataset
Number of training iterations
    epochs: int = 1_000
Loss function
    loss_func = nn.CrossEntropyLoss()
Device to train on
This creates configs for device, so that we can change the device by passing a config value
    device: torch.device = DeviceConfigs()
Optimizer
    optimizer: torch.optim.Adam
We do full-batch training since the dataset is small. If we were to sample and train, we would have to sample a set of nodes for each training step, along with the edges that span across those selected nodes (a minimal sketch of such sampling follows this method).
    def run(self):
Move the feature vectors to the device
        features = self.dataset.features.to(self.device)
Move the labels to the device
        labels = self.dataset.labels.to(self.device)
Move the adjacency matrix to the device
        edges_adj = self.dataset.adj_mat.to(self.device)
Add an empty third dimension for the heads
        edges_adj = edges_adj.unsqueeze(-1)
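For reference, a tiny shape check of that step (a toy sketch, not part of the training code):
toy_adj = torch.eye(4, dtype=torch.bool)    # [4, 4]
toy_adj = toy_adj.unsqueeze(-1)             # [4, 4, 1]; broadcast over the attention heads
assert toy_adj.shape == (4, 4, 1)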
Random indexes
        idx_rand = torch.randperm(len(labels))
Nodes for training
        idx_train = idx_rand[:self.training_samples]
Nodes for validation
        idx_valid = idx_rand[self.training_samples:]
Training loop
        for epoch in monit.loop(self.epochs):
Set the model to training mode
            self.model.train()
Make all the gradients zero
            self.optimizer.zero_grad()
Evaluate the model
            output = self.model(features, edges_adj)
Get the loss for training nodes
            loss = self.loss_func(output[idx_train], labels[idx_train])
Calculate gradients
            loss.backward()
Take optimization step
            self.optimizer.step()
Log the loss
            tracker.add('loss.train', loss)
Log the accuracy
            tracker.add('accuracy.train', accuracy(output[idx_train], labels[idx_train]))
Set mode to evaluation mode for validation
            self.model.eval()
No need to compute gradients
            with torch.no_grad():
Evaluate the model again
                output = self.model(features, edges_adj)
Calculate the loss for validation nodes
                loss = self.loss_func(output[idx_valid], labels[idx_valid])
Log the loss
                tracker.add('loss.valid', loss)
Log the accuracy
                tracker.add('accuracy.valid', accuracy(output[idx_valid], labels[idx_valid]))
Save logs
            tracker.save()
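The loop above trains on the whole graph every step. If the graph were too large for that, we would need the kind of sampling mentioned before this method: pick a subset of nodes and keep only the edges that span them. A minimal sketch of such a hypothetical helper, operating on the dataset-level tensors (it is not part of this experiment):
def sample_subgraph(features: torch.Tensor, adj_mat: torch.Tensor,
                    labels: torch.Tensor, n_sample: int):
    # Pick a random subset of nodes
    idx = torch.randperm(len(labels))[:n_sample]
    # Keep only the edges whose endpoints are both in the sample
    sub_adj = adj_mat[idx][:, idx]
    # Slice the features and labels to the same nodes
    return features[idx], sub_adj, labels[idx]
Each training step would then run the model on the sampled subgraph (after adding the head dimension to sub_adj) instead of on the full tensors.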
Create Cora dataset
@option(Configs.dataset)
def cora_dataset(c: Configs):
    return CoraDataset(c.include_edges)
Get the number of classes
calculate(Configs.n_classes, lambda c: len(c.dataset.classes))
Number of features in the input
calculate(Configs.in_features, lambda c: c.dataset.features.shape[1])
Create GAT model
@option(Configs.model)
def gat_model(c: Configs):
    return GAT(c.in_features, c.n_hidden, c.n_classes, c.n_heads, c.dropout).to(c.device)
Create configurable optimizer
@option(Configs.optimizer)
def _optimizer(c: Configs):
    opt_conf = OptimizerConfigs()
    opt_conf.parameters = c.model.parameters()
    return opt_conf
def main():
Create configurations
    conf = Configs()
Create an experiment
    experiment.create(name='gat')
Calculate configurations.
    experiment.configs(conf, {
Adam optimizer
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 5e-3,
        'optimizer.weight_decay': 5e-4,
    })
Start and watch the experiment
    with experiment.start():
Run the training
        conf.run()
if __name__ == '__main__':
    main()
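Finally, to see how much the citation network actually helps (the reason the include_edges flag exists), one could override it in the configuration dictionary passed to experiment.configs. A sketch of how the call in main might look with the edges disabled:
experiment.configs(conf, {
    # Drop the citation edges and classify from the word features alone
    'include_edges': False,
    'optimizer.optimizer': 'Adam',
    'optimizer.learning_rate': 5e-3,
    'optimizer.weight_decay': 5e-4,
})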