11from typing import Dict
12
13import numpy as np
14import torch
15from torch import nn
16
17from labml import lab, monit, tracker, experiment
18from labml.configs import BaseConfigs, option, calculate
19from labml.utils import download
20from labml_helpers.device import DeviceConfigs
21from labml_helpers.module import Module
22from labml_nn.graphs.gat import GraphAttentionLayer
23from labml_nn.optimizers.configs import OptimizerConfigs
The Cora dataset is a dataset of research papers. For each paper we are given a binary feature vector that indicates the presence of words. Each paper is classified into one of 7 classes. The dataset also has the citation network.
The papers are the nodes of the graph and the edges are the citations.
The task is to classify the nodes into the 7 classes, using the feature vectors and the citation network as input.
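For orientation, here is a hedged sketch of the raw files this class reads (the column layout is assumed from the LinQS distribution downloaded below): each line of cora.content holds a paper id, binary word-presence features, and a class label, and each line of cora.cites holds a pair of paper ids. The values below are a toy row, not real data.

line = "31336 0 0 1 0 1 Neural_Networks"  # toy row; real rows have 1,433 word columns
fields = line.split()
paper_id = fields[0]  # paper identifier
word_features = [int(x) for x in fields[1:-1]]  # binary word-presence features
label = fields[-1]  # one of the 7 class names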
26class CoraDataset:
Labels for each node
41 labels: torch.Tensor
Set of class names and a unique integer index for each of them
43 classes: Dict[str, int]
Feature vectors for all nodes
45 features: torch.Tensor
Adjacency matrix with the edge information. adj_mat[i][j] is True if there is an edge from i to j.
48 adj_mat: torch.Tensor
Download the dataset
50 @staticmethod
51 def _download():
55 if not (lab.get_data_path() / 'cora').exists():
56 download.download_file('https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz',
57 lab.get_data_path() / 'cora.tgz')
58 download.extract_tar(lab.get_data_path() / 'cora.tgz', lab.get_data_path())
Load the dataset
60 def __init__(self, include_edges: bool = True):
Whether to include edges. This is to test how much accuracy is lost if we ignore the citation network.
67 self.include_edges = include_edges
Download dataset
70 self._download()
Read the paper ids, feature vectors, and labels
73 with monit.section('Read content file'):
74 content = np.genfromtxt(str(lab.get_data_path() / 'cora/cora.content'), dtype=np.dtype(str))
Load the citations; each row is a pair of paper ids (integers).
76 with monit.section('Read citations file'):
77 citations = np.genfromtxt(str(lab.get_data_path() / 'cora/cora.cites'), dtype=np.int32)
Get the feature vectors
80 features = torch.tensor(np.array(content[:, 1:-1], dtype=np.float32))
Normalize the feature vectors
82 self.features = features / features.sum(dim=1, keepdim=True)
Get the class names and assign a unique integer to each of them
85 self.classes = {s: i for i, s in enumerate(set(content[:, -1]))}
Get the labels as those integers
87 self.labels = torch.tensor([self.classes[i] for i in content[:, -1]], dtype=torch.long)
Get the paper ids
90 paper_ids = np.array(content[:, 0], dtype=np.int32)
Map of paper id to index
92 ids_to_idx = {id_: i for i, id_ in enumerate(paper_ids)}
Empty adjacency matrix - an identity matrix
95 self.adj_mat = torch.eye(len(self.labels), dtype=torch.bool)
Mark the citations in the adjacency matrix
98 if self.include_edges:
99 for e in citations:
The pair of paper indexes
101 e1, e2 = ids_to_idx[e[0]], ids_to_idx[e[1]]
We build a symmetric graph: if paper i cites paper j, we place an edge from i to j as well as an edge from j to i.
105 self.adj_mat[e1][e2] = True
106 self.adj_mat[e2][e1] = True
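A quick, hedged way to sanity-check the loaded dataset; the shapes and counts are properties of the standard Cora distribution (roughly 2,708 papers, 1,433 word features, and 7 classes):

dataset = CoraDataset(include_edges=True)
print(dataset.features.shape)  # [number of papers, number of word features]
print(len(dataset.classes))  # number of classes
print(dataset.adj_mat.shape)  # square boolean matrix over the papers
print(dataset.features.sum(dim=1)[:5])  # rows were normalized to sum to 1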
109class GAT(Module):
in_features is the number of features per node
n_hidden is the number of features in the first graph attention layer
n_classes is the number of classes
n_heads is the number of heads in the graph attention layers
dropout is the dropout probability
116 def __init__(self, in_features: int, n_hidden: int, n_classes: int, n_heads: int, dropout: float):
124 super().__init__()
First graph attention layer where we concatenate the heads
127 self.layer1 = GraphAttentionLayer(in_features, n_hidden, n_heads, is_concat=True, dropout=dropout)
Activation function after first graph attention layer
129 self.activation = nn.ELU()
Final graph attention layer where we average the heads
131 self.output = GraphAttentionLayer(n_hidden, n_classes, 1, is_concat=False, dropout=dropout)
Dropout
133 self.dropout = nn.Dropout(dropout)
x is the feature vectors of shape [n_nodes, in_features]
adj_mat is the adjacency matrix of shape [n_nodes, n_nodes, n_heads] or [n_nodes, n_nodes, 1]
135 def forward(self, x: torch.Tensor, adj_mat: torch.Tensor):
Apply dropout to the input
142 x = self.dropout(x)
First graph attention layer
144 x = self.layer1(x, adj_mat)
Activation function
146 x = self.activation(x)
Dropout
148 x = self.dropout(x)
Output layer (without activation) for logits
150 return self.output(x, adj_mat)
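As a quick shape check, here is a minimal sketch of running the model on random inputs with made-up sizes (n_hidden should be divisible by n_heads for the concatenating first layer); the adjacency matrix gets the trailing singleton head dimension that forward expects:

n_nodes = 10
model = GAT(in_features=32, n_hidden=64, n_classes=7, n_heads=8, dropout=0.6)
x = torch.rand(n_nodes, 32)  # random node features
adj_mat = torch.eye(n_nodes, dtype=torch.bool).unsqueeze(-1)  # self-loops only, with a head dimension
logits = model(x, adj_mat)  # shape [n_nodes, n_classes]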
A simple function to calculate the accuracy
153def accuracy(output: torch.Tensor, labels: torch.Tensor):
157 return output.argmax(dim=-1).eq(labels).sum().item() / len(labels)
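For example, a toy check of this helper:

output = torch.tensor([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])  # logits for 3 nodes and 2 classes
labels = torch.tensor([1, 0, 0])
accuracy(output, labels)  # predictions are [1, 0, 1], so 2 of 3 are correct, giving 0.666...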
160class Configs(BaseConfigs):
Model
166 model: GAT
Number of nodes to train on
168 training_samples: int = 500
Number of features per node in the input
170 in_features: int
Number of features in the first graph attention layer
172 n_hidden: int = 64
Number of heads
174 n_heads: int = 8
Number of classes for classification
176 n_classes: int
Dropout probability
178 dropout: float = 0.6
Whether to include the citation network
180 include_edges: bool = True
Dataset
182 dataset: CoraDataset
Number of training iterations
184 epochs: int = 1_000
Loss function
186 loss_func = nn.CrossEntropyLoss()
Device to train on
This creates the configurations for the device, so that we can change the device by passing a config value
191 device: torch.device = DeviceConfigs()
Optimizer
193 optimizer: torch.optim.Adam
We do full-batch training since the dataset is small. If we were to sample and train, we would have to sample a set of nodes for each training step, along with the edges that span those selected nodes (a rough sketch of such sampling appears after run(), below).
195 def run(self):
Move the feature vectors to the device
205 features = self.dataset.features.to(self.device)
Move the labels to the device
207 labels = self.dataset.labels.to(self.device)
Move the adjacency matrix to the device
209 edges_adj = self.dataset.adj_mat.to(self.device)
Add an empty third dimension for the heads
211 edges_adj = edges_adj.unsqueeze(-1)
Random indexes
214 idx_rand = torch.randperm(len(labels))
Nodes for training
216 idx_train = idx_rand[:self.training_samples]
Nodes for validation
218 idx_valid = idx_rand[self.training_samples:]
Training loop
221 for epoch in monit.loop(self.epochs):
Set the model to training mode
223 self.model.train()
Make all the gradients zero
225 self.optimizer.zero_grad()
Evaluate the model
227 output = self.model(features, edges_adj)
Get the loss for training nodes
229 loss = self.loss_func(output[idx_train], labels[idx_train])
Calculate gradients
231 loss.backward()
Take optimization step
233 self.optimizer.step()
Log the loss
235 tracker.add('loss.train', loss)
Log the accuracy
237 tracker.add('accuracy.train', accuracy(output[idx_train], labels[idx_train]))
Set mode to evaluation mode for validation
240 self.model.eval()
No need to compute gradients
243 with torch.no_grad():
Evaluate the model again
245 output = self.model(features, edges_adj)
Calculate the loss for validation nodes
247 loss = self.loss_func(output[idx_valid], labels[idx_valid])
Log the loss
249 tracker.add('loss.valid', loss)
Log the accuracy
251 tracker.add('accuracy.valid', accuracy(output[idx_valid], labels[idx_valid]))
Save logs
254 tracker.save()
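As noted before run(), sampling-based training would need a node subset and the edges induced by it at every step. A rough, hypothetical sketch of that slicing (not part of this experiment; the sample size is made up):

dataset = CoraDataset(include_edges=True)
n_sample = 256  # hypothetical mini-batch size
idx = torch.randperm(len(dataset.labels))[:n_sample]  # random subset of node indices
sub_features = dataset.features[idx]  # features of the sampled nodes
sub_labels = dataset.labels[idx]  # their labels
sub_adj = dataset.adj_mat[idx][:, idx].unsqueeze(-1)  # only edges within the subset, plus the head dimension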
Create Cora dataset
257@option(Configs.dataset)
258def cora_dataset(c: Configs):
262 return CoraDataset(c.include_edges)
Get the number of classes
266calculate(Configs.n_classes, lambda c: len(c.dataset.classes))
Number of features in the input
268calculate(Configs.in_features, lambda c: c.dataset.features.shape[1])
Create GAT model
271@option(Configs.model)
272def gat_model(c: Configs):
276 return GAT(c.in_features, c.n_hidden, c.n_classes, c.n_heads, c.dropout).to(c.device)
Create configurable optimizer
279@option(Configs.optimizer)
280def _optimizer(c: Configs):
284 opt_conf = OptimizerConfigs()
285 opt_conf.parameters = c.model.parameters()
286 return opt_conf
289def main():
Create configurations
291 conf = Configs()
Create an experiment
293 experiment.create(name='gat')
Calculate configurations.
295 experiment.configs(conf, {
Adam optimizer
297 'optimizer.optimizer': 'Adam',
298 'optimizer.learning_rate': 5e-3,
299 'optimizer.weight_decay': 5e-4,
300 })
Start and watch the experiment
303 with experiment.start():
Run the training
305 conf.run()
309if __name__ == '__main__':
310 main()