Train a Graph Attention Network (GAT) on the Cora dataset

from typing import Dict

import numpy as np
import torch
from torch import nn

from labml import lab, monit, tracker, experiment
from labml.configs import BaseConfigs, option, calculate
from labml.utils import download
from labml_helpers.device import DeviceConfigs
from labml_helpers.module import Module
from labml_nn.graphs.gat import GraphAttentionLayer
from labml_nn.optimizers.configs import OptimizerConfigs

Cora Dataset

The Cora dataset is a dataset of research papers. For each paper, we are given a binary feature vector that indicates the presence of words. Each paper is classified into one of 7 classes. The dataset also includes the citation network.

The papers are the nodes of the graph and the edges are the citations.

The task is to classify the nodes into the 7 classes, using the feature vectors and the citation network as input.

class CoraDataset:

Labels for each node

    labels: torch.Tensor

Class names and a unique integer index for each of them

    classes: Dict[str, int]

Feature vectors for all nodes

    features: torch.Tensor

Adjacency matrix with the edge information. adj_mat[i][j] is True if there is an edge from node i to node j.

    adj_mat: torch.Tensor

Download the dataset

    @staticmethod
    def _download():
        if not (lab.get_data_path() / 'cora').exists():
            download.download_file('https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz',
                                   lab.get_data_path() / 'cora.tgz')
            download.extract_tar(lab.get_data_path() / 'cora.tgz', lab.get_data_path())

Load the dataset

    def __init__(self, include_edges: bool = True):

Whether to include edges. This is to test how much accuracy is lost if we ignore the citation network.

        self.include_edges = include_edges

Download dataset

        self._download()

Read the paper ids, feature vectors, and labels

        with monit.section('Read content file'):
            content = np.genfromtxt(str(lab.get_data_path() / 'cora/cora.content'), dtype=np.dtype(str))

Load the citations; it's a list of pairs of paper ids (integers).

        with monit.section('Read citations file'):
            citations = np.genfromtxt(str(lab.get_data_path() / 'cora/cora.cites'), dtype=np.int32)

Get the feature vectors

        features = torch.tensor(np.array(content[:, 1:-1], dtype=np.float32))

Normalize the feature vectors

        self.features = features / features.sum(dim=1, keepdim=True)

Get the class names and assign a unique integer to each of them

        self.classes = {s: i for i, s in enumerate(set(content[:, -1]))}

Get the labels as those integers

        self.labels = torch.tensor([self.classes[i] for i in content[:, -1]], dtype=torch.long)

Get the paper ids

        paper_ids = np.array(content[:, 0], dtype=np.int32)

Map of paper id to index

        ids_to_idx = {id_: i for i, id_ in enumerate(paper_ids)}

Empty adjacency matrix - an identity matrix

        self.adj_mat = torch.eye(len(self.labels), dtype=torch.bool)

Mark the citations in the adjacency matrix

        if self.include_edges:
            for e in citations:

The pair of paper indexes

                e1, e2 = ids_to_idx[e[0]], ids_to_idx[e[1]]

We build a symmetric graph: if paper i referenced paper j, we add an edge from i to j as well as an edge from j to i.

                self.adj_mat[e1][e2] = True
                self.adj_mat[e2][e1] = True
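As a quick sanity check, here is a minimal usage sketch of the class above. It is not part of the original experiment; the shapes in the comments assume the standard Cora release of 2,708 papers, 1,433 vocabulary words, and 7 classes.

# Hedged usage sketch; shapes are expected values for the standard Cora release
dataset = CoraDataset(include_edges=True)
print(dataset.features.shape)  # torch.Size([2708, 1433])
print(dataset.labels.shape)    # torch.Size([2708])
print(len(dataset.classes))    # 7
print(dataset.adj_mat.shape)   # torch.Size([2708, 2708]); boolean, symmetric, with self-loops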

Graph Attention Network (GAT)

This graph attention network has two graph attention layers.

class GAT(Module):
  • in_features is the number of features per node
  • n_hidden is the number of features in the first graph attention layer
  • n_classes is the number of classes
  • n_heads is the number of heads in the graph attention layers
  • dropout is the dropout probability
    def __init__(self, in_features: int, n_hidden: int, n_classes: int, n_heads: int, dropout: float):
        super().__init__()

First graph attention layer where we concatenate the heads

        self.layer1 = GraphAttentionLayer(in_features, n_hidden, n_heads, is_concat=True, dropout=dropout)

Activation function after first graph attention layer

        self.activation = nn.ELU()

Final graph attention layer where we average the heads

        self.output = GraphAttentionLayer(n_hidden, n_classes, 1, is_concat=False, dropout=dropout)

Dropout

        self.dropout = nn.Dropout(dropout)
  • x is the feature vectors of shape [n_nodes, in_features]
  • adj_mat is the adjacency matrix of shape [n_nodes, n_nodes, n_heads] or [n_nodes, n_nodes, 1]
    def forward(self, x: torch.Tensor, adj_mat: torch.Tensor):

Apply dropout to the input

        x = self.dropout(x)

First graph attention layer

        x = self.layer1(x, adj_mat)

Activation function

        x = self.activation(x)

Dropout

        x = self.dropout(x)

Output layer (without activation) for logits

        return self.output(x, adj_mat)
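To see the shapes flowing through the two layers, here is a minimal sketch with random inputs. The sizes are illustrative and not taken from the experiment; the only constraint is that n_hidden should be divisible by n_heads, since the first layer concatenates the heads.

# Illustrative sketch, not part of the experiment
n_nodes, in_features, n_hidden, n_classes, n_heads = 10, 16, 8, 7, 4
model = GAT(in_features, n_hidden, n_classes, n_heads, dropout=0.6)
x = torch.rand(n_nodes, in_features)
# Self-loops only; the trailing dimension of 1 is shared across heads
adj_mat = torch.eye(n_nodes, dtype=torch.bool).unsqueeze(-1)
logits = model(x, adj_mat)
print(logits.shape)  # torch.Size([10, 7])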

A simple function to calculate the accuracy

def accuracy(output: torch.Tensor, labels: torch.Tensor):
    return output.argmax(dim=-1).eq(labels).sum().item() / len(labels)
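For example, with three predictions of which two match the labels, this returns 2/3 (a toy check, not part of the experiment):

# Predictions argmax to [1, 0, 1]; labels are [1, 0, 0]
out = torch.tensor([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
lbl = torch.tensor([1, 0, 0])
print(accuracy(out, lbl))  # 0.666...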

Configurations

class Configs(BaseConfigs):

Model

    model: GAT

Number of nodes to train on

    training_samples: int = 500

Number of features per node in the input

    in_features: int

Number of features in the first graph attention layer

    n_hidden: int = 64

Number of heads

    n_heads: int = 8

Number of classes for classification

    n_classes: int

Dropout probability

    dropout: float = 0.6

Whether to include the citation network

    include_edges: bool = True

Dataset

    dataset: CoraDataset

Number of training iterations

    epochs: int = 1_000

Loss function

    loss_func = nn.CrossEntropyLoss()

Device to train on

This creates the device configurations, so that we can change the device by passing a config value

    device: torch.device = DeviceConfigs()

Optimizer

    optimizer: torch.optim.Adam

Training loop

We do full-batch training since the dataset is small. If we were to sample and train, we would have to sample a set of nodes for each training step, along with the edges that span those selected nodes. A sketch of such sampling is shown below.
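For reference, here is a hedged sketch of what such sampling could look like. The sample size and variable names are illustrative and this is not part of the experiment.

# Illustrative only: sample a subset of nodes and take the induced sub-graph
dataset = CoraDataset()
sample_size = 256
idx = torch.randperm(len(dataset.labels))[:sample_size]
sub_features = dataset.features[idx]    # [sample_size, in_features]
sub_labels = dataset.labels[idx]        # [sample_size]
# Keep only the edges whose both end points fall inside the sample
sub_adj = dataset.adj_mat[idx][:, idx]  # [sample_size, sample_size]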

    def run(self):

Move the feature vectors to the device

        features = self.dataset.features.to(self.device)

Move the labels to the device

        labels = self.dataset.labels.to(self.device)

Move the adjacency matrix to the device

        edges_adj = self.dataset.adj_mat.to(self.device)

Add an empty third dimension for the heads

        edges_adj = edges_adj.unsqueeze(-1)

Random indexes

        idx_rand = torch.randperm(len(labels))

Nodes for training

        idx_train = idx_rand[:self.training_samples]

Nodes for validation

        idx_valid = idx_rand[self.training_samples:]

Training loop

        for epoch in monit.loop(self.epochs):

Set the model to training mode

            self.model.train()

Make all the gradients zero

            self.optimizer.zero_grad()

Evaluate the model

            output = self.model(features, edges_adj)

Get the loss for training nodes

            loss = self.loss_func(output[idx_train], labels[idx_train])

Calculate gradients

            loss.backward()

Take optimization step

            self.optimizer.step()

Log the loss

            tracker.add('loss.train', loss)

Log the accuracy

            tracker.add('accuracy.train', accuracy(output[idx_train], labels[idx_train]))

Set the model to evaluation mode for validation

            self.model.eval()

No need to compute gradients

            with torch.no_grad():

Evaluate the model again

                output = self.model(features, edges_adj)

Calculate the loss for validation nodes

                loss = self.loss_func(output[idx_valid], labels[idx_valid])

Log the loss

                tracker.add('loss.valid', loss)

Log the accuracy

                tracker.add('accuracy.valid', accuracy(output[idx_valid], labels[idx_valid]))

Save logs

            tracker.save()

Create Cora dataset

@option(Configs.dataset)
def cora_dataset(c: Configs):
    return CoraDataset(c.include_edges)

Get the number of classes

calculate(Configs.n_classes, lambda c: len(c.dataset.classes))

Number of features in the input

calculate(Configs.in_features, lambda c: c.dataset.features.shape[1])

Create GAT model

@option(Configs.model)
def gat_model(c: Configs):
    return GAT(c.in_features, c.n_hidden, c.n_classes, c.n_heads, c.dropout).to(c.device)

Create configurable optimizer

@option(Configs.optimizer)
def _optimizer(c: Configs):
    opt_conf = OptimizerConfigs()
    opt_conf.parameters = c.model.parameters()
    return opt_conf


def main():

Create configurations

    conf = Configs()

Create an experiment

    experiment.create(name='gat')

Calculate configurations.

    experiment.configs(conf, {

Adam optimizer

        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 5e-3,
        'optimizer.weight_decay': 5e-4,
    })

Start and watch the experiment

    with experiment.start():

Run the training

        conf.run()

if __name__ == '__main__':
    main()