from typing import Dict

import numpy as np
import torch
from torch import nn

from labml import lab, monit, tracker, experiment
from labml.configs import BaseConfigs, option, calculate
from labml.utils import download
from labml_helpers.device import DeviceConfigs
from labml_helpers.module import Module
from labml_nn.graphs.gat import GraphAttentionLayer
from labml_nn.optimizers.configs import OptimizerConfigs
The Cora dataset is a dataset of research papers. For each paper we are given a binary feature vector that indicates the presence of words. Each paper is classified into one of seven classes. The dataset also has the citation network.
The papers are the nodes of the graph and the edges are the citations.
The task is to classify the nodes into the seven classes using the feature vectors and the citation network as input.
class CoraDataset:
Labels for each node
    labels: torch.Tensor
Set of class names and a unique integer index for each of them
    classes: Dict[str, int]
Feature vectors for all nodes
    features: torch.Tensor
Adjacency matrix with the edge information. adj_mat[i][j] is True if there is an edge from i to j.
    adj_mat: torch.Tensor
Download the dataset
    @staticmethod
    def _download():
        if not (lab.get_data_path() / 'cora').exists():
            download.download_file('https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz',
                                   lab.get_data_path() / 'cora.tgz')
            download.extract_tar(lab.get_data_path() / 'cora.tgz', lab.get_data_path())
Load the dataset
    def __init__(self, include_edges: bool = True):
Whether to include edges. This is to test how much accuracy is lost if we ignore the citation network.
        self.include_edges = include_edges
Download dataset
        self._download()
Read the paper ids, feature vectors, and labels
        with monit.section('Read content file'):
            content = np.genfromtxt(str(lab.get_data_path() / 'cora/cora.content'), dtype=np.dtype(str))
Load the citations; it's a list of pairs of paper IDs (integers).
        with monit.section('Read citations file'):
            citations = np.genfromtxt(str(lab.get_data_path() / 'cora/cora.cites'), dtype=np.int32)
Get the feature vectors
        features = torch.tensor(np.array(content[:, 1:-1], dtype=np.float32))
Normalize the feature vectors
        self.features = features / features.sum(dim=1, keepdim=True)
Get the class names and assign a unique integer to each of them
        self.classes = {s: i for i, s in enumerate(set(content[:, -1]))}
Get the labels as those integers
        self.labels = torch.tensor([self.classes[i] for i in content[:, -1]], dtype=torch.long)
Get the paper ids
        paper_ids = np.array(content[:, 0], dtype=np.int32)
Map of paper id to index
        ids_to_idx = {id_: i for i, id_ in enumerate(paper_ids)}
Empty adjacency matrix - an identity matrix
        self.adj_mat = torch.eye(len(self.labels), dtype=torch.bool)
Mark the citations in the adjacency matrix
        if self.include_edges:
            for e in citations:
The pair of paper indexes
                e1, e2 = ids_to_idx[e[0]], ids_to_idx[e[1]]
We build a symmetric graph: if paper i cites paper j, we place an edge from i to j as well as an edge from j to i.
                self.adj_mat[e1][e2] = True
                self.adj_mat[e2][e1] = True
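As a quick sanity check (an illustrative sketch, not part of the original experiment), the dataset can be loaded and its tensors inspected. The numbers in the comments are the standard Cora statistics (2,708 papers, 1,433 word features, 7 classes), and we assume no paper has an all-zero feature vector, so the row normalization is well defined.
dataset = CoraDataset(include_edges=True)
print(dataset.features.shape)  # expected: torch.Size([2708, 1433])
print(len(dataset.classes))    # expected: 7
print(dataset.adj_mat.shape)   # expected: torch.Size([2708, 2708])
# Each normalized feature vector should sum to one
assert torch.allclose(dataset.features.sum(dim=1), torch.ones(len(dataset.labels)))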
This graph attention network has two graph attention layers.
class GAT(Module):
in_features is the number of features per node
n_hidden is the number of features in the first graph attention layer
n_classes is the number of classes
n_heads is the number of heads in the graph attention layers
dropout is the dropout probability
    def __init__(self, in_features: int, n_hidden: int, n_classes: int, n_heads: int, dropout: float):
        super().__init__()
First graph attention layer where we concatenate the heads
        self.layer1 = GraphAttentionLayer(in_features, n_hidden, n_heads, is_concat=True, dropout=dropout)
Activation function after first graph attention layer
        self.activation = nn.ELU()
Final graph attention layer where we average the heads
        self.output = GraphAttentionLayer(n_hidden, n_classes, 1, is_concat=False, dropout=dropout)
Dropout
        self.dropout = nn.Dropout(dropout)
x is the feature vectors of shape [n_nodes, in_features]
adj_mat is the adjacency matrix of the form [n_nodes, n_nodes, n_heads] or [n_nodes, n_nodes, 1]
    def forward(self, x: torch.Tensor, adj_mat: torch.Tensor):
Apply dropout to the input
        x = self.dropout(x)
First graph attention layer
        x = self.layer1(x, adj_mat)
Activation function
        x = self.activation(x)
Dropout
        x = self.dropout(x)
Output layer (without activation) for logits
        return self.output(x, adj_mat)
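A minimal shape check of the model (an illustrative sketch with toy dimensions, not part of the experiment). Since the first layer concatenates the heads, n_hidden is chosen divisible by n_heads, and the output has one score per class for every node.
# Toy dimensions chosen only for this sketch
model = GAT(in_features=16, n_hidden=64, n_classes=7, n_heads=8, dropout=0.6)
x = torch.rand(10, 16)
# Adjacency with self-loops only; the trailing dimension of 1 is the heads dimension
adj = torch.eye(10, dtype=torch.bool).unsqueeze(-1)
out = model(x, adj)
assert out.shape == (10, 7)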
A simple function to calculate the accuracy
def accuracy(output: torch.Tensor, labels: torch.Tensor):
    return output.argmax(dim=-1).eq(labels).sum().item() / len(labels)
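For example (a tiny illustrative check): with three nodes where two predictions match the labels, the accuracy is 2/3.
logits = torch.tensor([[2.0, 0.1], [0.2, 1.5], [3.0, 0.0]])
targets = torch.tensor([0, 1, 1])
assert accuracy(logits, targets) == 2 / 3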
class Configs(BaseConfigs):
Model
    model: GAT
Number of nodes to train on
    training_samples: int = 500
Number of features per node in the input
    in_features: int
Number of features in the first graph attention layer
    n_hidden: int = 64
Number of heads
    n_heads: int = 8
Number of classes for classification
    n_classes: int
Dropout probability
    dropout: float = 0.6
Whether to include the citation network
    include_edges: bool = True
Dataset
    dataset: CoraDataset
Number of training iterations
    epochs: int = 1_000
Loss function
    loss_func = nn.CrossEntropyLoss()
Device to train on
This creates configs for device, so that we can change the device by passing a config value
    device: torch.device = DeviceConfigs()
Optimizer
    optimizer: torch.optim.Adam
We do full-batch training since the dataset is small. If we were to sample and train, we would have to sample a set of nodes for each training step, along with the edges that span those selected nodes; a rough sketch of such sampling appears after this method.
    def run(self):
Move the feature vectors to the device
        features = self.dataset.features.to(self.device)
Move the labels to the device
        labels = self.dataset.labels.to(self.device)
Move the adjacency matrix to the device
        edges_adj = self.dataset.adj_mat.to(self.device)
Add an empty third dimension for the heads
        edges_adj = edges_adj.unsqueeze(-1)
Random indexes
        idx_rand = torch.randperm(len(labels))
Nodes for training
        idx_train = idx_rand[:self.training_samples]
Nodes for validation
        idx_valid = idx_rand[self.training_samples:]
Training loop
        for epoch in monit.loop(self.epochs):
Set the model to training mode
            self.model.train()
Make all the gradients zero
            self.optimizer.zero_grad()
Evaluate the model
            output = self.model(features, edges_adj)
Get the loss for training nodes
            loss = self.loss_func(output[idx_train], labels[idx_train])
Calculate gradients
            loss.backward()
Take an optimization step
            self.optimizer.step()
Log the loss
            tracker.add('loss.train', loss)
Log the accuracy
            tracker.add('accuracy.train', accuracy(output[idx_train], labels[idx_train]))
Set the model to evaluation mode for validation
            self.model.eval()
No need to compute gradients
            with torch.no_grad():
Evaluate the model again
                output = self.model(features, edges_adj)
Calculate the loss for validation nodes
                loss = self.loss_func(output[idx_valid], labels[idx_valid])
Log the loss
                tracker.add('loss.valid', loss)
Log the accuracy
                tracker.add('accuracy.valid', accuracy(output[idx_valid], labels[idx_valid]))
Save logs
            tracker.save()
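As noted before run, this experiment trains on the full graph. A hypothetical helper (not part of labml or this experiment) that samples a subgraph for mini-batch training might look like the sketch below: pick a random subset of nodes and keep only the edges whose endpoints are both in the subset.
def sample_subgraph(adj_mat: torch.Tensor, labels: torch.Tensor, n_sample: int):
    # Randomly pick a subset of node indices
    idx = torch.randperm(adj_mat.shape[0])[:n_sample]
    # Keep only the edges between the sampled nodes
    sub_adj = adj_mat[idx][:, idx]
    return idx, sub_adj, labels[idx]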
Create Cora dataset
@option(Configs.dataset)
def cora_dataset(c: Configs):
    return CoraDataset(c.include_edges)
Get the number of classes
calculate(Configs.n_classes, lambda c: len(c.dataset.classes))
Number of features in the input
calculate(Configs.in_features, lambda c: c.dataset.features.shape[1])
Create GAT model
@option(Configs.model)
def gat_model(c: Configs):
    return GAT(c.in_features, c.n_hidden, c.n_classes, c.n_heads, c.dropout).to(c.device)
Create configurable optimizer
@option(Configs.optimizer)
def _optimizer(c: Configs):
    opt_conf = OptimizerConfigs()
    opt_conf.parameters = c.model.parameters()
    return opt_conf
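OptimizerConfigs lets the optimizer type and hyper-parameters be chosen from the configuration dictionary passed in main below. With the values used there, a rough plain-PyTorch equivalent (a sketch, assuming model is a constructed GAT instance; this is not how labml resolves the configs) would be:
optimizer = torch.optim.Adam(model.parameters(), lr=5e-3, weight_decay=5e-4)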
def main():
Create configurations
    conf = Configs()
Create an experiment
    experiment.create(name='gat')
Calculate configurations.
    experiment.configs(conf, {
Adam optimizer
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 5e-3,
        'optimizer.weight_decay': 5e-4,
    })
Start and watch the experiment
    with experiment.start():
Run the training
        conf.run()

if __name__ == '__main__':
    main()
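To try the ablation mentioned in CoraDataset (training without the citation network), the same configuration dictionary can carry the override; this is a sketch, assuming include_edges can be overridden like the other config values.
experiment.configs(conf, {
    'optimizer.optimizer': 'Adam',
    'optimizer.learning_rate': 5e-3,
    'optimizer.weight_decay': 5e-4,
    # Hypothetical override to drop the citation edges
    'include_edges': False,
})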