Train a Graph Attention Network (GAT) on the Cora dataset

11from typing import Dict
12
13import numpy as np
14import torch
15from torch import nn
16
17from labml import lab, monit, tracker, experiment
18from labml.configs import BaseConfigs, option, calculate
19from labml.utils import download
20from labml_nn.helpers.device import DeviceConfigs
21from labml_nn.graphs.gat import GraphAttentionLayer
22from labml_nn.optimizers.configs import OptimizerConfigs

Cora Dataset

Cora dataset is a dataset of research papers. For each paper we are given a binary feature vector that indicates the presence of words. Each paper is classified into one of 7 classes. The dataset also has the citation network.

The papers are the nodes of the graph and the edges are the citations.

The task is to classify the nodes into the 7 classes, using the feature vectors and the citation network as input.

class CoraDataset:
    """
    ## Cora Dataset

    Loads the Cora research-paper dataset: a feature vector and a class label
    per paper, plus the citation network as a boolean adjacency matrix.

    The papers are the nodes of the graph and the citations are the edges.
    """

    # Labels for each node
    labels: torch.Tensor
    # Map from class name to a unique integer index
    classes: Dict[str, int]
    # Feature vectors for all nodes
    features: torch.Tensor
    # Adjacency matrix with the edge information.
    # `adj_mat[i][j]` is `True` if there is an edge from `i` to `j`.
    adj_mat: torch.Tensor

    @staticmethod
    def _download():
        """
        Download and extract the dataset, unless a `cora` folder already
        exists in the lab data path.
        """
        data_path = lab.get_data_path()
        if not (data_path / 'cora').exists():
            download.download_file('https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz',
                                   data_path / 'cora.tgz')
            download.extract_tar(data_path / 'cora.tgz', data_path)

    def __init__(self, include_edges: bool = True):
        """
        Load the dataset.

        * `include_edges` is whether to include the citation edges.
          Setting it to `False` is a way to test how much accuracy is lost
          if we ignore the citation network.
        """
        self.include_edges = include_edges

        # Download dataset
        self._download()

        # Read the paper ids, feature vectors, and labels
        with monit.section('Read content file'):
            content = np.genfromtxt(str(lab.get_data_path() / 'cora/cora.content'), dtype=np.dtype(str))
        # Load the citations, it's a list of pairs of integers.
        with monit.section('Read citations file'):
            citations = np.genfromtxt(str(lab.get_data_path() / 'cora/cora.cites'), dtype=np.int32)

        # Get the feature vectors (every column except the first and the last)
        features = torch.tensor(np.array(content[:, 1:-1], dtype=np.float32))
        # Normalize the feature vectors so that each row sums to one.
        # A row that sums to zero is divided by one instead, so it stays
        # all-zero rather than turning into NaNs.
        row_sums = features.sum(dim=1, keepdim=True)
        self.features = features / row_sums.masked_fill(row_sums == 0, 1.0)

        # Get the class names and assign a unique integer to each of them.
        # The names are sorted so that the mapping is the same on every run;
        # iterating a bare `set` of strings gives an order that can change
        # between Python processes.
        self.classes = {s: i for i, s in enumerate(sorted(set(content[:, -1])))}
        # Get the labels as those integers
        self.labels = torch.tensor([self.classes[i] for i in content[:, -1]], dtype=torch.long)

        # Get the paper ids
        paper_ids = np.array(content[:, 0], dtype=np.int32)
        # Map of paper id to index
        ids_to_idx = {id_: i for i, id_ in enumerate(paper_ids)}

        # Empty adjacency matrix - an identity matrix
        # (every node is connected to itself)
        self.adj_mat = torch.eye(len(self.labels), dtype=torch.bool)

        # Mark the citations in the adjacency matrix
        if self.include_edges:
            for e in citations:
                # The pair of paper indexes
                e1, e2 = ids_to_idx[e[0]], ids_to_idx[e[1]]
                # We build a symmetrical graph, where if paper `i` referenced
                # paper `j` we place an edge from `i` to `j` as well as an
                # edge from `j` to `i`.
                self.adj_mat[e1, e2] = True
                self.adj_mat[e2, e1] = True

Graph Attention Network (GAT)

This graph attention network has two graph attention layers.

class GAT(nn.Module):
    """
    ## Graph Attention Network (GAT)

    A network made of two graph attention layers, with an ELU activation
    and dropout between them.
    """

    def __init__(self, in_features: int, n_hidden: int, n_classes: int, n_heads: int, dropout: float):
        """
        * `in_features` is the number of features per node
        * `n_hidden` is the number of features in the first graph attention layer
        * `n_classes` is the number of classes
        * `n_heads` is the number of heads in the graph attention layers
        * `dropout` is the dropout probability
        """
        super().__init__()

        # First graph attention layer; its heads are concatenated
        self.layer1 = GraphAttentionLayer(in_features, n_hidden, n_heads, is_concat=True, dropout=dropout)
        # Non-linearity applied to the output of the first layer
        self.activation = nn.ELU()
        # Final graph attention layer (a single head, `is_concat=False`)
        # that produces the per-class outputs
        self.output = GraphAttentionLayer(n_hidden, n_classes, 1, is_concat=False, dropout=dropout)
        # Dropout shared by the input and the hidden representation
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, adj_mat: torch.Tensor):
        """
        * `x` is the features vectors of shape `[n_nodes, in_features]`
        * `adj_mat` is the adjacency matrix of the form
          `[n_nodes, n_nodes, n_heads]` or `[n_nodes, n_nodes, 1]`

        Returns the logits produced by the output layer (no activation).
        """
        # Dropout on the input, then the first graph attention layer
        hidden = self.layer1(self.dropout(x), adj_mat)
        # Activation followed by dropout
        hidden = self.dropout(self.activation(hidden))
        # Output layer (without activation) for logits
        logits = self.output(hidden, adj_mat)
        return logits

A simple function to calculate the accuracy

def accuracy(output: torch.Tensor, labels: torch.Tensor):
    """
    Calculate the fraction of samples whose highest-scoring class matches the label.

    * `output` holds the per-class scores; the prediction is the `argmax`
      over its last dimension
    * `labels` holds the target class indexes

    Returns a Python float in `[0, 1]`. An empty `labels` gives `0.0`
    instead of raising `ZeroDivisionError`, so an empty split (for example
    when every node is used for training) does not crash the caller.
    """
    n_samples = len(labels)
    if n_samples == 0:
        return 0.0
    n_correct = output.argmax(dim=-1).eq(labels).sum().item()
    return n_correct / n_samples

Configurations

class Configs(BaseConfigs):
    """
    ## Configurations

    Holds the hyper-parameters, the dataset, the model and the optimizer,
    and runs the training loop.
    """

    # Model
    model: GAT
    # Number of nodes to train on
    training_samples: int = 500
    # Number of features per node in the input
    in_features: int
    # Number of features in the first graph attention layer
    n_hidden: int = 64
    # Number of heads
    n_heads: int = 8
    # Number of classes for classification
    n_classes: int
    # Dropout probability
    dropout: float = 0.6
    # Whether to include the citation network
    include_edges: bool = True
    # Dataset
    dataset: CoraDataset
    # Number of training iterations
    epochs: int = 1_000
    # Loss function
    loss_func = nn.CrossEntropyLoss()
    # Device to train on
    #
    # This creates configs for device, so that
    # we can change the device by passing a config value
    device: torch.device = DeviceConfigs()
    # Optimizer
    optimizer: torch.optim.Adam

    def run(self):
        """
        ### Training loop

        We do full batch training since the dataset is small.
        If we were to sample and train we will have to sample a set of
        nodes for each training step along with the edges that span
        across those selected nodes.
        """
        # Move the feature vectors and the labels to the device
        node_features = self.dataset.features.to(self.device)
        node_labels = self.dataset.labels.to(self.device)
        # Move the adjacency matrix to the device and
        # add an empty third dimension for the heads
        adjacency = self.dataset.adj_mat.to(self.device).unsqueeze(-1)

        # Random permutation of the node indexes; the first
        # `training_samples` are for training, the rest for validation
        shuffled = torch.randperm(len(node_labels))
        train_nodes = shuffled[:self.training_samples]
        valid_nodes = shuffled[self.training_samples:]

        for _ in monit.loop(self.epochs):
            # --- Training step ---
            self.model.train()
            self.optimizer.zero_grad()
            # Forward pass over the whole graph
            train_logits = self.model(node_features, adjacency)
            # The loss is calculated on the training nodes only
            train_loss = self.loss_func(train_logits[train_nodes], node_labels[train_nodes])
            train_loss.backward()
            self.optimizer.step()
            # Log the training loss and accuracy
            tracker.add('loss.train', train_loss)
            tracker.add('accuracy.train', accuracy(train_logits[train_nodes], node_labels[train_nodes]))

            # --- Validation ---
            self.model.eval()
            # No need to compute gradients
            with torch.no_grad():
                # Evaluate the model again, now in evaluation mode
                valid_logits = self.model(node_features, adjacency)
                valid_loss = self.loss_func(valid_logits[valid_nodes], node_labels[valid_nodes])
                # Log the validation loss and accuracy
                tracker.add('loss.valid', valid_loss)
                tracker.add('accuracy.valid', accuracy(valid_logits[valid_nodes], node_labels[valid_nodes]))

            # Save logs
            tracker.save()

Create Cora dataset

@option(Configs.dataset)
def cora_dataset(c: Configs):
    """
    Create Cora dataset, with or without the citation edges
    depending on `c.include_edges`.
    """
    dataset = CoraDataset(include_edges=c.include_edges)
    return dataset

Get the number of classes

# `n_classes` is the number of entries in the dataset's class-name map
calculate(Configs.n_classes, lambda c: len(c.dataset.classes))

# Number of features in the input:
# the size of the second dimension of the dataset's feature matrix
calculate(Configs.in_features, lambda c: c.dataset.features.shape[1])

Create GAT model

@option(Configs.model)
def gat_model(c: Configs):
    """
    Create GAT model from the configured sizes and move it to the configured device.
    """
    model = GAT(c.in_features, c.n_hidden, c.n_classes, c.n_heads, c.dropout)
    return model.to(c.device)

Create configurable optimizer

@option(Configs.optimizer)
def _optimizer(c: Configs):
    """
    Create configurable optimizer: an `OptimizerConfigs` whose
    parameters are the parameters of the model.
    """
    optimizer_configs = OptimizerConfigs()
    optimizer_configs.parameters = c.model.parameters()
    return optimizer_configs
def main():
    """
    Create the configurations and the experiment, then run the training.
    """
    # Create configurations
    conf = Configs()
    # Create an experiment
    experiment.create(name='gat')

    # Configuration overrides: Adam optimizer with its
    # learning rate and weight decay
    overrides = {
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 5e-3,
        'optimizer.weight_decay': 5e-4,
    }
    # Calculate configurations.
    experiment.configs(conf, overrides)

    # Start and watch the experiment
    with experiment.start():
        # Run the training
        conf.run()

# Run the experiment only when this file is executed as a script
if __name__ == '__main__':
    main()