Graph Attention Networks (GAT)

This is a PyTorch implementation of the paper Graph Attention Networks.

GATs operate on graph data. A graph consists of nodes and edges connecting them. For example, in the Cora dataset the nodes are research papers and the edges are citations that connect the papers.

GAT uses masked self-attention, similar in spirit to the self-attention in transformers. A GAT consists of graph attention layers stacked on top of each other. Each graph attention layer takes node embeddings as input and outputs transformed embeddings. Each node's embedding attends to the embeddings of the nodes it is connected to. The details of the graph attention layer are explained alongside the implementation below.

Here is the training code for training a two-layer GAT on the Cora dataset.


import torch
from torch import nn

from labml_helpers.module import Module

Graph attention layer

This is a single graph attention layer. A GAT is made up of multiple such layers.

It takes $\mathbf{h} = \{ \overrightarrow{h_1}, \overrightarrow{h_2}, \dots, \overrightarrow{h_N} \}$, where $\overrightarrow{h_i} \in \mathbb{R}^F$, as input and outputs $\mathbf{h'} = \{ \overrightarrow{h'_1}, \overrightarrow{h'_2}, \dots, \overrightarrow{h'_N} \}$, where $\overrightarrow{h'_i} \in \mathbb{R}^{F'}$.

class GraphAttentionLayer(Module):
  • in_features, $F$, is the number of input features per node
  • out_features, $F'$, is the number of output features per node
  • n_heads, $K$, is the number of attention heads
  • is_concat whether the multi-head results should be concatenated or averaged
  • dropout is the dropout probability
  • leaky_relu_negative_slope is the negative slope for leaky relu activation
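For example (hypothetical values, not taken from the accompanying training code), a layer producing 64 output features from 8 concatenated heads could be constructed as:

# hypothetical layer: 1433 input features (the size of Cora's bag-of-words vectors),
# 64 output features split across 8 heads (8 features per head)
layer = GraphAttentionLayer(in_features=1433, out_features=64, n_heads=8,
                            is_concat=True, dropout=0.6)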
    def __init__(self, in_features: int, out_features: int, n_heads: int,
                 is_concat: bool = True,
                 dropout: float = 0.6,
                 leaky_relu_negative_slope: float = 0.2):
        super().__init__()

        self.is_concat = is_concat
        self.n_heads = n_heads

Calculate the number of dimensions per head

        if is_concat:
            assert out_features % n_heads == 0

If we are concatenating the multiple heads

            self.n_hidden = out_features // n_heads
        else:

If we are averaging the multiple heads

            self.n_hidden = out_features
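For example (hypothetical values), with out_features = 64 and n_heads = 8, each head works with n_hidden = 8 features when concatenating, but with all n_hidden = 64 features when averaging.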

Linear layer for initial transformation; i.e. to transform the node embeddings before self-attention

        self.linear = nn.Linear(in_features, self.n_hidden * n_heads, bias=False)

Linear layer to compute attention score

        self.attn = nn.Linear(self.n_hidden * 2, 1, bias=False)

The activation for attention score

        self.activation = nn.LeakyReLU(negative_slope=leaky_relu_negative_slope)

Softmax to compute attention

        self.softmax = nn.Softmax(dim=1)

Dropout layer to be applied for attention

        self.dropout = nn.Dropout(dropout)
  • h, $\mathbf{h}$, is the input node embeddings of shape [n_nodes, in_features].
  • adj_mat is the adjacency matrix of shape [n_nodes, n_nodes, n_heads]. We use shape [n_nodes, n_nodes, 1] since the adjacency is the same for each head.

The adjacency matrix represents the edges (or connections) among nodes. adj_mat[i][j] is True if there is an edge from node i to node j.
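As a hypothetical illustration (not part of the implementation), the adjacency matrix for a small three-node directed graph could be built as:

import torch

# hypothetical 3-node graph with edges 0 -> 1, 1 -> 2 and 2 -> 0;
# the last dimension of size 1 is shared across all heads
adj_mat = torch.zeros(3, 3, 1, dtype=torch.bool)
adj_mat[0, 1, 0] = True
adj_mat[1, 2, 0] = True
adj_mat[2, 0, 0] = True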

    def forward(self, h: torch.Tensor, adj_mat: torch.Tensor):

Number of nodes

        n_nodes = h.shape[0]

The initial transformation, $\overrightarrow{g^k_i} = \mathbf{W}^k \overrightarrow{h_i}$, for each head $k$. We do a single linear transformation and then split it up for each head, so g has shape [n_nodes, n_heads, n_hidden].

        g = self.linear(h).view(n_nodes, self.n_heads, self.n_hidden)

Calculate attention score

We calculate these for each head $k$. We have omitted $\cdot^k$ for simplicity.

$$e_{ij} = a(\mathbf{W} \overrightarrow{h_i}, \mathbf{W} \overrightarrow{h_j}) = a(\overrightarrow{g_i}, \overrightarrow{g_j})$$

$e_{ij}$ is the attention score (importance) from node $j$ to node $i$. We calculate this for each head.

$a$ is the attention mechanism that calculates the attention score. The paper concatenates $\overrightarrow{g_i}$ and $\overrightarrow{g_j}$ and does a linear transformation with a weight vector $\mathbf{a} \in \mathbb{R}^{2 F'}$ followed by a $\text{LeakyReLU}$:

$$e_{ij} = \text{LeakyReLU} \Big( \mathbf{a}^\top \Big[ \overrightarrow{g_i} \Vert \overrightarrow{g_j} \Big] \Big)$$

First we calculate $\Big[ \overrightarrow{g_i} \Vert \overrightarrow{g_j} \Big]$ for all pairs of $i, j$.

g_repeat gets $\{\overrightarrow{g_1}, \overrightarrow{g_2}, \dots, \overrightarrow{g_N}, \overrightarrow{g_1}, \overrightarrow{g_2}, \dots, \overrightarrow{g_N}, \dots\}$ where each node embedding is repeated n_nodes times.

        g_repeat = g.repeat(n_nodes, 1, 1)

g_repeat_interleave gets $\{\overrightarrow{g_1}, \overrightarrow{g_1}, \dots, \overrightarrow{g_1}, \overrightarrow{g_2}, \overrightarrow{g_2}, \dots, \overrightarrow{g_2}, \dots\}$ where each node embedding is repeated n_nodes times.

        g_repeat_interleave = g.repeat_interleave(n_nodes, dim=0)
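To see the difference between the two, here is a toy illustration (not part of the implementation) with three one-dimensional node embeddings and the head dimension omitted:

import torch

g = torch.tensor([[1.], [2.], [3.]])  # 3 nodes, 1 feature each
g.repeat(3, 1)                 # rows: 1, 2, 3, 1, 2, 3, 1, 2, 3
g.repeat_interleave(3, dim=0)  # rows: 1, 1, 1, 2, 2, 2, 3, 3, 3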

Now we concatenate to get $\{\overrightarrow{g_1} \Vert \overrightarrow{g_1}, \overrightarrow{g_1} \Vert \overrightarrow{g_2}, \dots, \overrightarrow{g_1} \Vert \overrightarrow{g_N}, \overrightarrow{g_2} \Vert \overrightarrow{g_1}, \overrightarrow{g_2} \Vert \overrightarrow{g_2}, \dots, \overrightarrow{g_2} \Vert \overrightarrow{g_N}, \dots\}$

        g_concat = torch.cat([g_repeat_interleave, g_repeat], dim=-1)

Reshape so that g_concat[i, j] is $\overrightarrow{g_i} \Vert \overrightarrow{g_j}$

        g_concat = g_concat.view(n_nodes, n_nodes, self.n_heads, 2 * self.n_hidden)

Calculate $e_{ij} = \text{LeakyReLU} \Big( \mathbf{a}^\top \Big[ \overrightarrow{g_i} \Vert \overrightarrow{g_j} \Big] \Big)$. e is of shape [n_nodes, n_nodes, n_heads, 1].

        e = self.activation(self.attn(g_concat))

Remove the last dimension of size 1

        e = e.squeeze(-1)

The adjacency matrix should have shape [n_nodes, n_nodes, n_heads] or [n_nodes, n_nodes, 1]

        assert adj_mat.shape[0] == 1 or adj_mat.shape[0] == n_nodes
        assert adj_mat.shape[1] == 1 or adj_mat.shape[1] == n_nodes
        assert adj_mat.shape[2] == 1 or adj_mat.shape[2] == self.n_heads

Mask $e_{ij}$ based on the adjacency matrix. $e_{ij}$ is set to $-\infty$ if there is no edge from $i$ to $j$.

        e = e.masked_fill(adj_mat == 0, float('-inf'))

We then normalize attention scores (or coefficients)

$$\alpha_{ij} = \text{softmax}_j(e_{ij}) = \frac{\exp(e_{ij})}{\sum_{j' \in \mathcal{N}_i} \exp(e_{ij'})}$$

where $\mathcal{N}_i$ is the set of nodes connected to $i$.

We do this by setting unconnected $e_{ij}$ to $-\infty$, which makes $\exp(e_{ij}) \sim 0$ for unconnected pairs.
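As a quick sanity check (not part of the implementation), a softmax over a vector containing $-\infty$ assigns zero weight to those entries:

import torch

e = torch.tensor([0.5, float('-inf'), 1.0])
torch.softmax(e, dim=0)  # tensor([0.3775, 0.0000, 0.6225])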

        a = self.softmax(e)

Apply dropout regularization

        a = self.dropout(a)

Calculate final output for each head

$$\overrightarrow{h'^k_i} = \sum_{j \in \mathcal{N}_i} \alpha^k_{ij} \overrightarrow{g^k_j}$$

Note: The paper includes the final activation $\sigma$ in $\overrightarrow{h'^k_i}$. We have omitted this from the Graph Attention Layer implementation and apply it in the GAT model, to match how other PyTorch modules are defined - activation as a separate layer.

        attn_res = torch.einsum('ijh,jhf->ihf', a, g)
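The einsum above is a per-head matrix product of the attention matrix with that head's node embeddings; a hypothetical sanity check (not part of the implementation):

import torch

n_nodes, n_heads, n_hidden = 4, 2, 3
a = torch.rand(n_nodes, n_nodes, n_heads)
g = torch.rand(n_nodes, n_heads, n_hidden)
res = torch.einsum('ijh,jhf->ihf', a, g)
# for each head h: [n_nodes, n_nodes] attention matrix times [n_nodes, n_hidden] embeddings
ref = torch.stack([a[:, :, h] @ g[:, h, :] for h in range(n_heads)], dim=1)
assert torch.allclose(res, ref)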

Concatenate the heads

        if self.is_concat:

            return attn_res.reshape(n_nodes, self.n_heads * self.n_hidden)

Take the mean of the heads

        else:

            return attn_res.mean(dim=1)
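A minimal sketch of how two such layers could be stacked into a two-layer GAT for node classification on Cora, assuming the GraphAttentionLayer defined above (the class and parameter names below are illustrative, not the exact training code):

import torch
from torch import nn


class GAT(nn.Module):
    def __init__(self, in_features: int, n_hidden: int, n_classes: int,
                 n_heads: int, dropout: float):
        super().__init__()
        # first layer concatenates the heads; n_hidden must be divisible by n_heads
        self.layer1 = GraphAttentionLayer(in_features, n_hidden, n_heads,
                                          is_concat=True, dropout=dropout)
        self.activation = nn.ELU()
        # output layer uses a single head and produces one logit per class
        self.output = GraphAttentionLayer(n_hidden, n_classes, 1,
                                          is_concat=False, dropout=dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, adj_mat: torch.Tensor):
        x = self.dropout(x)
        x = self.activation(self.layer1(x, adj_mat))
        x = self.dropout(x)
        return self.output(x, adj_mat)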