Train a Graph Attention Network (GAT) on Cora dataset

from typing import Dict

import numpy as np
import torch
from torch import nn

from labml import lab, monit, tracker, experiment
from labml.configs import BaseConfigs, option, calculate
from labml.utils import download
from labml_helpers.device import DeviceConfigs
from labml_helpers.module import Module
from labml_nn.graphs.gat import GraphAttentionLayer
from labml_nn.optimizers.configs import OptimizerConfigs

Cora Dataset

The Cora dataset is a dataset of research papers. For each paper, we are given a binary feature vector that indicates the presence of words. Each paper is classified into one of 7 classes. The dataset also includes the citation network.

The papers are the nodes of the graph and the edges are the citations.

The task is to classify the nodes into the 7 classes, using the feature vectors and the citation network as input.

class CoraDataset:

Labels for each node

    labels: torch.Tensor

Class names and a unique integer index for each of them

    classes: Dict[str, int]

Feature vectors for all nodes

    features: torch.Tensor

Adjacency matrix with the edge information. adj_mat[i][j] is True if there is an edge from i to j.

    adj_mat: torch.Tensor

Download the dataset

    @staticmethod
    def _download():
        if not (lab.get_data_path() / 'cora').exists():
            download.download_file('https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz',
                                   lab.get_data_path() / 'cora.tgz')
            download.extract_tar(lab.get_data_path() / 'cora.tgz', lab.get_data_path())

Load the dataset

    def __init__(self, include_edges: bool = True):

Whether to include edges. This is to test how much accuracy is lost if we ignore the citation network.

        self.include_edges = include_edges

Download dataset

        self._download()

Read the paper ids, feature vectors, and labels

        with monit.section('Read content file'):
            content = np.genfromtxt(str(lab.get_data_path() / 'cora/cora.content'), dtype=np.dtype(str))

Load the citations; it's a list of pairs of paper ids (integers).

        with monit.section('Read citations file'):
            citations = np.genfromtxt(str(lab.get_data_path() / 'cora/cora.cites'), dtype=np.int32)

Get the feature vectors

        features = torch.tensor(np.array(content[:, 1:-1], dtype=np.float32))

Normalize the feature vectors

        self.features = features / features.sum(dim=1, keepdim=True)
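This divides each feature vector by the sum of its elements, so each row sums to 1. A small illustrative check of the broadcasting (not part of the loader):

f = torch.tensor([[1., 3.], [2., 2.]])
# Each row is divided by its own sum
print(f / f.sum(dim=1, keepdim=True))  # tensor([[0.2500, 0.7500], [0.5000, 0.5000]])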

Get the class names and assign a unique integer to each of them

        self.classes = {s: i for i, s in enumerate(set(content[:, -1]))}

Get the labels as those integers

        self.labels = torch.tensor([self.classes[i] for i in content[:, -1]], dtype=torch.long)

Get the paper ids

        paper_ids = np.array(content[:, 0], dtype=np.int32)

Map of paper id to index

        ids_to_idx = {id_: i for i, id_ in enumerate(paper_ids)}

Initialize the adjacency matrix to an identity matrix, so every node starts with an edge to itself

        self.adj_mat = torch.eye(len(self.labels), dtype=torch.bool)

Mark the citations in the adjacency matrix

        if self.include_edges:
            for e in citations:

The pair of paper indexes

                e1, e2 = ids_to_idx[e[0]], ids_to_idx[e[1]]

We build a symmetric graph: if paper i cites paper j, we place an edge from i to j as well as an edge from j to i.

                self.adj_mat[e1][e2] = True
                self.adj_mat[e2][e1] = True
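Here is a minimal usage sketch of the loader. It assumes the labml data path is writable and the download URL above is reachable; the shapes shown are the standard Cora sizes (2,708 nodes, 1,433 word features, 7 classes):

# Load the dataset with the citation network
dataset = CoraDataset(include_edges=True)
print(dataset.features.shape)  # e.g. torch.Size([2708, 1433])
print(dataset.labels.shape)    # e.g. torch.Size([2708])
print(len(dataset.classes))    # 7
print(dataset.adj_mat.shape)   # e.g. torch.Size([2708, 2708])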

Graph Attention Network (GAT)

This graph attention network has two graph attention layers.

class GAT(Module):
  • in_features is the number of features per node
  • n_hidden is the number of features in the first graph attention layer
  • n_classes is the number of classes
  • n_heads is the number of heads in the graph attention layers
  • dropout is the dropout probability
    def __init__(self, in_features: int, n_hidden: int, n_classes: int, n_heads: int, dropout: float):
        super().__init__()

First graph attention layer where we concatenate the heads

        self.layer1 = GraphAttentionLayer(in_features, n_hidden, n_heads, is_concat=True, dropout=dropout)

Activation function after first graph attention layer

        self.activation = nn.ELU()

Final graph attention layer where we average the heads

        self.output = GraphAttentionLayer(n_hidden, n_classes, 1, is_concat=False, dropout=dropout)

Dropout

        self.dropout = nn.Dropout(dropout)
  • x is the feature vectors of shape [n_nodes, in_features]
  • adj_mat is the adjacency matrix of shape [n_nodes, n_nodes, n_heads] or [n_nodes, n_nodes, 1]
    def forward(self, x: torch.Tensor, adj_mat: torch.Tensor):

Apply dropout to the input

        x = self.dropout(x)

First graph attention layer

        x = self.layer1(x, adj_mat)

Activation function

        x = self.activation(x)

Dropout

        x = self.dropout(x)

Output layer (without activation) for logits

        return self.output(x, adj_mat)
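As a quick sanity check, here is a sketch of a forward pass on random inputs. It assumes the standard Cora dimensions and, as the forward parameters above state, that a [n_nodes, n_nodes, 1] adjacency mask is broadcast across the heads:

# Build a model with the standard Cora dimensions
model = GAT(in_features=1433, n_hidden=64, n_classes=7, n_heads=8, dropout=0.6)
# Random features and a self-loops-only adjacency mask
x = torch.rand(2708, 1433)
adj_mat = torch.eye(2708, dtype=torch.bool).unsqueeze(-1)
# Logits have one row per node and one column per class
logits = model(x, adj_mat)
assert logits.shape == (2708, 7)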

A simple function to calculate the accuracy

def accuracy(output: torch.Tensor, labels: torch.Tensor):
    return output.argmax(dim=-1).eq(labels).sum().item() / len(labels)
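For instance, if only the first of two predictions matches its label, the function returns 0.5:

output = torch.tensor([[0.1, 0.9], [0.8, 0.2]])
labels = torch.tensor([1, 1])
assert accuracy(output, labels) == 0.5  # argmax gives [1, 0]; one of two is correct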

Configurations

class Configs(BaseConfigs):

Model

    model: GAT

Number of nodes to train on

    training_samples: int = 500

Number of features per node in the input

    in_features: int

Number of features in the first graph attention layer

    n_hidden: int = 64

Number of heads

    n_heads: int = 8

Number of classes for classification

    n_classes: int

Dropout probability

    dropout: float = 0.6

Whether to include the citation network

    include_edges: bool = True

Dataset

    dataset: CoraDataset

Number of training iterations

    epochs: int = 1_000

Loss function

    loss_func = nn.CrossEntropyLoss()

Device to train on

This creates the configurations for the device, so that we can change the device by passing a config value

    device: torch.device = DeviceConfigs()

Optimizer

    optimizer: torch.optim.Adam

Training loop

We do full-batch training since the dataset is small. If we were to sample and train, we would have to sample a set of nodes for each training step, along with the edges that span those selected nodes; a sketch of that alternative appears after this method's code.

    def run(self):

Move the feature vectors to the device

        features = self.dataset.features.to(self.device)

Move the labels to the device

        labels = self.dataset.labels.to(self.device)

Move the adjacency matrix to the device

        edges_adj = self.dataset.adj_mat.to(self.device)

Add an empty third dimension for the heads; this [n_nodes, n_nodes, 1] mask is shared across all heads

        edges_adj = edges_adj.unsqueeze(-1)

Random indexes

        idx_rand = torch.randperm(len(labels))

Nodes for training

        idx_train = idx_rand[:self.training_samples]

Nodes for validation

        idx_valid = idx_rand[self.training_samples:]

Training loop

        for epoch in monit.loop(self.epochs):

Set the model to training mode

            self.model.train()

Make all the gradients zero

            self.optimizer.zero_grad()

Evaluate the model

            output = self.model(features, edges_adj)

Get the loss for training nodes

            loss = self.loss_func(output[idx_train], labels[idx_train])

Calculate gradients

            loss.backward()

Take optimization step

            self.optimizer.step()

Log the loss

            tracker.add('loss.train', loss)

Log the accuracy

            tracker.add('accuracy.train', accuracy(output[idx_train], labels[idx_train]))

Set the model to evaluation mode for validation

            self.model.eval()

No need to compute gradients

            with torch.no_grad():

Evaluate the model again

                output = self.model(features, edges_adj)

Calculate the loss for validation nodes

                loss = self.loss_func(output[idx_valid], labels[idx_valid])

Log the loss

                tracker.add('loss.valid', loss)

Log the accuracy

                tracker.add('accuracy.valid', accuracy(output[idx_valid], labels[idx_valid]))

Save logs

            tracker.save()
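For reference, here is a hypothetical sketch of the sampling alternative mentioned above. sample_subgraph is not part of this implementation; it only illustrates picking a node subset together with the edges whose endpoints both fall inside it:

def sample_subgraph(features: torch.Tensor, labels: torch.Tensor,
                    adj_mat: torch.Tensor, n_sample: int):
    # Pick a random subset of nodes
    idx = torch.randperm(len(labels))[:n_sample]
    # Induced sub-adjacency-matrix: keep only edges within the sampled nodes
    sub_adj = adj_mat[idx][:, idx]
    return features[idx], labels[idx], sub_adj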

Create Cora dataset

@option(Configs.dataset)
def cora_dataset(c: Configs):
    return CoraDataset(c.include_edges)

Get the number of classes

calculate(Configs.n_classes, lambda c: len(c.dataset.classes))

Number of features in the input

calculate(Configs.in_features, lambda c: c.dataset.features.shape[1])

Create GAT model

@option(Configs.model)
def gat_model(c: Configs):
    return GAT(c.in_features, c.n_hidden, c.n_classes, c.n_heads, c.dropout).to(c.device)

Create configurable optimizer

@option(Configs.optimizer)
def _optimizer(c: Configs):
    opt_conf = OptimizerConfigs()
    opt_conf.parameters = c.model.parameters()
    return opt_conf

def main():

Create configurations

    conf = Configs()

Create an experiment

    experiment.create(name='gat')

Calculate configurations.

    experiment.configs(conf, {

Adam optimizer

        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 5e-3,
        'optimizer.weight_decay': 5e-4,
    })

Start and watch the experiment

    with experiment.start():

Run the training

        conf.run()

if __name__ == '__main__':
    main()