This is a PyTorch implementation of the paper Graph Attention Networks.
GATs work on graph data. A graph consists of nodes and edges connecting nodes. For example, in the Cora dataset the nodes are research papers and the edges are citations that connect the papers.
GAT uses masked self-attention, similar to transformers. A GAT consists of graph attention layers stacked on top of each other. Each graph attention layer takes node embeddings as input and outputs transformed embeddings. Each node embedding attends to the embeddings of the nodes it is connected to. The details of graph attention layers are included alongside the implementation.
Here is the training code for training a two-layer GAT on the Cora dataset.
import torch
from torch import nn

from labml_helpers.module import Module
This is a single graph attention layer. A GAT is made up of multiple such layers.
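It takes $\mathbf{h} = \{ \overrightarrow{h_1}, \overrightarrow{h_2}, \dots, \overrightarrow{h_N} \}$, where $\overrightarrow{h_i} \in \mathbb{R}^F$, as input and outputs $\mathbf{h'} = \{ \overrightarrow{h'_1}, \overrightarrow{h'_2}, \dots, \overrightarrow{h'_N} \}$, where $\overrightarrow{h'_i} \in \mathbb{R}^{F'}$.
in_features, $F$, is the number of input features per node
out_features, $F'$, is the number of output features per node
n_heads, $K$, is the number of attention heads
is_concat is whether the multi-head results should be concatenated or averaged
dropout is the dropout probability
leaky_relu_negative_slope is the negative slope for the leaky ReLU activation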
    def __init__(self, in_features: int, out_features: int, n_heads: int,
                 is_concat: bool = True,
                 dropout: float = 0.6,
                 leaky_relu_negative_slope: float = 0.2):
        super().__init__()

        self.is_concat = is_concat
        self.n_heads = n_heads
Calculate the number of dimensions per head
        if is_concat:
            assert out_features % n_heads == 0
If we are concatenating the multiple heads
            self.n_hidden = out_features // n_heads
        else:
If we are averaging the multiple heads
            self.n_hidden = out_features
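To make the two cases concrete, here is a small sketch of the per-head size calculation. The helper name `n_hidden_per_head` and the sizes used are hypothetical and only serve as an illustration; they are not part of the implementation.

```python
def n_hidden_per_head(out_features: int, n_heads: int, is_concat: bool) -> int:
    """Hypothetical helper mirroring the per-head size calculation above."""
    if is_concat:
        # Concatenated heads must add up to exactly out_features
        assert out_features % n_heads == 0
        return out_features // n_heads
    # Averaged heads each produce the full out_features
    return out_features


concat_hidden = n_hidden_per_head(64, 8, is_concat=True)
mean_hidden = n_hidden_per_head(64, 8, is_concat=False)
```

With concatenation each of the 8 heads produces 8 features; with averaging each head produces all 64.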
Linear layer for initial transformation; i.e. to transform the node embeddings before self-attention
        self.linear = nn.Linear(in_features, self.n_hidden * n_heads, bias=False)
Linear layer to compute attention score
        self.attn = nn.Linear(self.n_hidden * 2, 1, bias=False)
The activation for attention score
        self.activation = nn.LeakyReLU(negative_slope=leaky_relu_negative_slope)
Softmax to compute attention
        self.softmax = nn.Softmax(dim=1)
Dropout layer to be applied for attention
        self.dropout = nn.Dropout(dropout)
h, $\mathbf{h}$, is the input node embeddings of shape [n_nodes, in_features].
adj_mat is the adjacency matrix of shape [n_nodes, n_nodes, n_heads]. We use shape [n_nodes, n_nodes, 1] since the adjacency is the same for each head.
The adjacency matrix represents the edges (or connections) among nodes. adj_mat[i][j] is True if there is an edge from node i to node j.
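As an illustration of the expected input, the following sketch builds an adjacency matrix with a trailing head dimension of size 1 from an edge list. The toy graph is made up, and adding self-loops is an assumption of this sketch rather than something the layer requires.

```python
import torch

# Hypothetical toy graph with 4 nodes; each pair is (from_node, to_node)
edges = [(0, 1), (1, 2), (2, 3)]
n_nodes = 4

# Boolean adjacency with a trailing head dimension of size 1
adj_mat = torch.zeros(n_nodes, n_nodes, 1, dtype=torch.bool)
for i, j in edges:
    adj_mat[i, j, 0] = True

# Self-loops (an assumption of this sketch) let every node attend to itself
idx = torch.arange(n_nodes)
adj_mat[idx, idx, 0] = True
```

The size-1 last dimension broadcasts against the per-head scores, so the same mask is applied to every head.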
    def forward(self, h: torch.Tensor, adj_mat: torch.Tensor):
Number of nodes
        n_nodes = h.shape[0]
The initial transformation, for each head. We do a single linear transformation and then split it up for each head.
        g = self.linear(h).view(n_nodes, self.n_heads, self.n_hidden)
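The split into heads can be checked on a small example. This is a sketch with arbitrary illustrative sizes; it shows that head `k` of the reshaped output equals the input multiplied by the corresponding block of rows of the shared weight matrix.

```python
import torch
from torch import nn

n_nodes, in_features, n_heads, n_hidden = 5, 16, 4, 8  # illustrative sizes

linear = nn.Linear(in_features, n_heads * n_hidden, bias=False)
h = torch.randn(n_nodes, in_features)

# One matrix multiplication for all heads, then split the feature axis per head
g = linear(h).view(n_nodes, n_heads, n_hidden)

# Head k uses rows k * n_hidden up to (k + 1) * n_hidden of the weight matrix
k = 2
g_k = h @ linear.weight[k * n_hidden:(k + 1) * n_hidden].T
```

Here `g[:, k]` and `g_k` agree up to floating point error, so a single wide linear layer behaves like `n_heads` separate per-head projections.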
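$$e_{ij} = a(\mathbf{W} \overrightarrow{h_i}, \mathbf{W} \overrightarrow{h_j}) = a(\overrightarrow{g_i}, \overrightarrow{g_j})$$
We calculate these for each head $k$. We have omitted $\cdot^k$ for simplicity.
$e_{ij}$ is the attention score (importance) from node $j$ to node $i$. We calculate this for each head.
$a$ is the attention mechanism that calculates the attention score. The paper concatenates $\overrightarrow{g_i}$ and $\overrightarrow{g_j}$, and applies a linear transformation with a weight vector $\mathbf{a} \in \mathbb{R}^{2 F'}$ followed by a $\text{LeakyReLU}$.
$$e_{ij} = \text{LeakyReLU} \Big( \mathbf{a}^\top \Big[ \overrightarrow{g_i} \Vert \overrightarrow{g_j} \Big] \Big)$$
First we calculate $\Big[ \overrightarrow{g_i} \Vert \overrightarrow{g_j} \Big]$ for all pairs of $i, j$.
g_repeat gets $\{ \overrightarrow{g_1}, \overrightarrow{g_2}, \dots, \overrightarrow{g_N}, \overrightarrow{g_1}, \overrightarrow{g_2}, \dots, \overrightarrow{g_N}, \dots \}$ where each node embedding is repeated n_nodes times.
        g_repeat = g.repeat(n_nodes, 1, 1)
g_repeat_interleave gets $\{ \overrightarrow{g_1}, \overrightarrow{g_1}, \dots, \overrightarrow{g_1}, \overrightarrow{g_2}, \overrightarrow{g_2}, \dots, \overrightarrow{g_2}, \dots \}$ where each node embedding is repeated n_nodes times.
        g_repeat_interleave = g.repeat_interleave(n_nodes, dim=0)
Now we concatenate to get $\{ \overrightarrow{g_1} \Vert \overrightarrow{g_1}, \overrightarrow{g_1} \Vert \overrightarrow{g_2}, \dots, \overrightarrow{g_1} \Vert \overrightarrow{g_N}, \overrightarrow{g_2} \Vert \overrightarrow{g_1}, \overrightarrow{g_2} \Vert \overrightarrow{g_2}, \dots, \overrightarrow{g_2} \Vert \overrightarrow{g_N}, \dots \}$
        g_concat = torch.cat([g_repeat_interleave, g_repeat], dim=-1)
Reshape so that g_concat[i, j] is $\overrightarrow{g_i} \Vert \overrightarrow{g_j}$
        g_concat = g_concat.view(n_nodes, n_nodes, self.n_heads, 2 * self.n_hidden)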
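The difference between `repeat` and `repeat_interleave` is easy to mix up, so here is a small standalone check of the pairing logic. The sizes and the chosen pair `(i, j)` are arbitrary.

```python
import torch

n_nodes, n_heads, n_hidden = 3, 2, 4  # illustrative sizes
g = torch.randn(n_nodes, n_heads, n_hidden)

# Tiles the whole node sequence: g_1, g_2, g_3, g_1, g_2, g_3, ...
g_repeat = g.repeat(n_nodes, 1, 1)
# Repeats each node in place: g_1, g_1, g_1, g_2, g_2, g_2, ...
g_repeat_interleave = g.repeat_interleave(n_nodes, dim=0)

g_concat = torch.cat([g_repeat_interleave, g_repeat], dim=-1)
g_concat = g_concat.view(n_nodes, n_nodes, n_heads, 2 * n_hidden)

# Entry [i, j] should hold g_i followed by g_j for every head
i, j = 2, 1
expected = torch.cat([g[i], g[j]], dim=-1)
```

Row `i * n_nodes + j` of the flattened tensors pairs `g[i]` (from the interleaved copy) with `g[j]` (from the tiled copy), which is why the final `view` places the pair at `[i, j]`.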
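Calculate $e_{ij} = \text{LeakyReLU} \Big( \mathbf{a}^\top \Big[ \overrightarrow{g_i} \Vert \overrightarrow{g_j} \Big] \Big)$. e is of shape [n_nodes, n_nodes, n_heads, 1].
        e = self.activation(self.attn(g_concat))
Remove the last dimension of size 1.
        e = e.squeeze(-1)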
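Because the score is a linear function of the concatenated pair, the weight vector can be split into the half that multiplies the first embedding and the half that multiplies the second. The sketch below is an alternative formulation and not the approach used in this implementation; it checks numerically that broadcasting two per-node terms gives the same scores without building all the concatenated pairs. All names and sizes are illustrative.

```python
import torch
from torch import nn

torch.manual_seed(0)
n_nodes, n_heads, n_hidden = 4, 2, 3  # illustrative sizes
g = torch.randn(n_nodes, n_heads, n_hidden)
attn = nn.Linear(2 * n_hidden, 1, bias=False)
activation = nn.LeakyReLU(negative_slope=0.2)

# Pairwise-concatenation version, following the steps in the walkthrough
g_concat = torch.cat([g.repeat_interleave(n_nodes, dim=0),
                      g.repeat(n_nodes, 1, 1)], dim=-1)
g_concat = g_concat.view(n_nodes, n_nodes, n_heads, 2 * n_hidden)
e = activation(attn(g_concat)).squeeze(-1)

# Split the weight vector into the halves applied to g_i and g_j
a_left, a_right = attn.weight[0, :n_hidden], attn.weight[0, n_hidden:]
score_i = g @ a_left   # shape [n_nodes, n_heads]
score_j = g @ a_right  # shape [n_nodes, n_heads]
# Broadcast the sum to [n_nodes, n_nodes, n_heads]
e_broadcast = activation(score_i.unsqueeze(1) + score_j.unsqueeze(0))
```

The two results match up to floating point error. The broadcast form stores two tensors of size `n_nodes * n_heads` before the sum instead of a tensor of size `n_nodes * n_nodes * n_heads * 2 * n_hidden`.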
The adjacency matrix should have shape [n_nodes, n_nodes, n_heads] or [n_nodes, n_nodes, 1].
        assert adj_mat.shape[0] == 1 or adj_mat.shape[0] == n_nodes
        assert adj_mat.shape[1] == 1 or adj_mat.shape[1] == n_nodes
        assert adj_mat.shape[2] == 1 or adj_mat.shape[2] == self.n_heads
Mask $e_{ij}$ based on the adjacency matrix. $e_{ij}$ is set to $-\infty$ if there is no edge from $i$ to $j$.
        e = e.masked_fill(adj_mat == 0, float('-inf'))
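We then normalize attention scores (or coefficients)
$$\alpha_{ij} = \text{softmax}_j(e_{ij}) = \frac{\exp(e_{ij})}{\sum_{j' \in \mathcal{N}_i} \exp(e_{ij'})}$$
where $\mathcal{N}_i$ is the set of nodes connected to $i$.
We do this by setting unconnected $e_{ij}$ to $-\infty$, which makes $\exp(e_{ij}) = 0$ for unconnected pairs.
        a = self.softmax(e)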
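A small numeric sketch of the masking and normalization, using a made-up score tensor and adjacency for three nodes and one head. It also shows one caveat: a row that is masked everywhere produces NaN, which is one reason the hypothetical adjacency below includes self-loops.

```python
import torch

# Scores for 3 nodes and a single head, shape [n_nodes, n_nodes, n_heads]
e = torch.tensor([[1.0, 2.0, 3.0],
                  [1.0, 1.0, 1.0],
                  [0.5, 0.0, 2.0]]).unsqueeze(-1)
# Hypothetical adjacency with self-loops; nodes 0 and 2 are not connected
adj_mat = torch.tensor([[1, 1, 0],
                        [1, 1, 1],
                        [0, 1, 1]], dtype=torch.bool).unsqueeze(-1)

e_masked = e.masked_fill(adj_mat == 0, float('-inf'))
# dim=1 normalizes over j, the nodes that node i attends to
a = torch.softmax(e_masked, dim=1)
row_sums = a.sum(dim=1)

# A row with every entry masked has no finite score, and softmax returns NaN
isolated = torch.softmax(torch.full((3,), float('-inf')), dim=0)
```

Masked pairs receive exactly zero attention, and each row still sums to one over the connected nodes.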
Apply dropout regularization
        a = self.dropout(a)
Calculate final output for each head
$$\overrightarrow{h'^k_i} = \sum_{j \in \mathcal{N}_i} \alpha^k_{ij} \overrightarrow{g^k_j}$$
Note: The paper includes the final activation $\sigma$ in $\overrightarrow{h'_i}$. We have omitted this from the Graph Attention Layer implementation and use it on the GAT model to match how other PyTorch modules are defined, with the activation as a separate layer.
        attn_res = torch.einsum('ijh,jhf->ihf', a, g)
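The einsum subscripts can be read as: for each node `i`, head `h` and feature `f`, sum over the neighbors `j`. The sketch below spells the same sum out with explicit loops on random illustrative inputs, purely as a check of the index pattern.

```python
import torch

n_nodes, n_heads, n_hidden = 4, 2, 3  # illustrative sizes
a = torch.softmax(torch.randn(n_nodes, n_nodes, n_heads), dim=1)
g = torch.randn(n_nodes, n_heads, n_hidden)

# Sum over j of a[i, j, h] * g[j, h, f]
attn_res = torch.einsum('ijh,jhf->ihf', a, g)

# The same sum written out with explicit loops, for checking only
expected = torch.zeros(n_nodes, n_heads, n_hidden)
for i in range(n_nodes):
    for j in range(n_nodes):
        # a[i, j] has shape [n_heads]; broadcast it across the feature axis
        expected[i] += a[i, j].unsqueeze(-1) * g[j]
```

Both versions produce a tensor of shape `[n_nodes, n_heads, n_hidden]` with matching values up to floating point error.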
Concatenate the heads
        if self.is_concat:
            return attn_res.reshape(n_nodes, self.n_heads * self.n_hidden)
Take the mean of the heads
        return attn_res.mean(dim=1)
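To see what the two return paths produce, here is a sketch on a small tensor with easily traceable values. The sizes are illustrative only.

```python
import torch

n_nodes, n_heads, n_hidden = 4, 2, 3  # illustrative sizes
attn_res = torch.arange(n_nodes * n_heads * n_hidden, dtype=torch.float32)
attn_res = attn_res.view(n_nodes, n_heads, n_hidden)

# Concatenation lays the heads side by side: [n_nodes, n_heads * n_hidden]
concatenated = attn_res.reshape(n_nodes, n_heads * n_hidden)
# Averaging collapses the head axis: [n_nodes, n_hidden]
averaged = attn_res.mean(dim=1)
```

For node 0 the two heads hold `[0, 1, 2]` and `[3, 4, 5]`, so concatenation gives `[0, 1, 2, 3, 4, 5]` and averaging gives `[1.5, 2.5, 3.5]`.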