Graph 注意力网络 v2 (Gatv2)

这是 PyTorch 对 Gatv2 运算符的实现,摘自《图注意力网络有多专心?

Gatv2 处理的图形数据与 GAT 类似。图由节点和连接节点的边组成。例如,在 Cora 数据集中,节点是研究论文,边缘是连接论文的引文。

Gatv2 操作员修复了标准 G AT 的静态注意力问题。静态注意力是指任何查询节点对关键节点的关注等级(顺序)相同。GAT 将从查询节点到关键节点的注意力计算为,

请注意,对于任何查询节点,键的注意力等级 () 仅取决于。因此,所有查询的键的注意力等级保持不变(静态)。

Gatv2 通过改变注意力机制来允许动态关注,

该论文表明,GAT的静态注意力机制在合成字典查找数据集的某些图形问题上会失败。这是一个完全连接的二分图,其中一组节点(查询节点)具有与之关联的密钥,而另一组节点既有键又有与之关联的值。目标是预测查询节点的值。GAT 无法完成此任务,因为其静态注意力有限。

以下是在 Cora 数据集上训练双层 Gatv2 的训练代码

57import torch
58from torch import nn
59
60from labml_helpers.module import Module

Graph 注意力 v2 层

这是单图关注 v2 层。GATv2 由多个这样的层组成。它需要,其中作为输入和输出,在哪里

63class GraphAttentionV2Layer(Module):
  • in_features ,是每个节点的输入要素数
  • out_features ,是每个节点的输出要素数
  • n_heads ,是注意头的数量
  • is_concat 多头结果应该是串联还是求平均值
  • dropout 是辍学概率
  • leaky_relu_negative_slope 是泄漏的 relu 激活的负斜率
  • share_weights 如果设置为True ,则同一矩阵将应用于每条边的源节点和目标节点
76    def __init__(self, in_features: int, out_features: int, n_heads: int,
77                 is_concat: bool = True,
78                 dropout: float = 0.6,
79                 leaky_relu_negative_slope: float = 0.2,
80                 share_weights: bool = False):
90        super().__init__()
91
92        self.is_concat = is_concat
93        self.n_heads = n_heads
94        self.share_weights = share_weights

计算每头的尺寸数

97        if is_concat:
98            assert out_features % n_heads == 0

如果我们要连接多个头

100            self.n_hidden = out_features // n_heads
101        else:

如果我们平均多头

103            self.n_hidden = out_features

用于初始源变换的线性层;即在自我关注之前转换源节点嵌入

107        self.linear_l = nn.Linear(in_features, self.n_hidden * n_heads, bias=False)

如果share_weightsTrue ,则为目标节点使用相同的线性层

109        if share_weights:
110            self.linear_r = self.linear_l
111        else:
112            self.linear_r = nn.Linear(in_features, self.n_hidden * n_heads, bias=False)

用于计算注意力分数的线性图层

114        self.attn = nn.Linear(self.n_hidden, 1, bias=False)

激活注意力分数

116        self.activation = nn.LeakyReLU(negative_slope=leaky_relu_negative_slope)

Softmax 需要计算注意力

118        self.softmax = nn.Softmax(dim=1)

要应用的掉落层以引起注意

120        self.dropout = nn.Dropout(dropout)
  • h是 shape 的输入节点嵌入[n_nodes, in_features]
  • adj_mat 是形状的邻接矩阵[n_nodes, n_nodes, n_heads] 。我们使用形状,[n_nodes, n_nodes, 1] 因为每个头部的邻接是相同的。邻接矩阵表示节点之间的边(或连接)。adj_mat[i][j] True 如果节点与节i 点之间存在边缘j
122    def forward(self, h: torch.Tensor, adj_mat: torch.Tensor):

节点数量

132        n_nodes = h.shape[0]

每个头部的初始变换。我们做了两个线性变换,然后将其拆分为每个头部。

138        g_l = self.linear_l(h).view(n_nodes, self.n_heads, self.n_hidden)
139        g_r = self.linear_r(h).view(n_nodes, self.n_heads, self.n_hidden)

计算注意力分数

我们为每个头部计算这些为简单起见,我们省略了

是从一个节点到另一个节点的注意力分数(重要性)。我们为每个头部计算这个值。

是计算注意力分数的注意力机制。本文求和然后是 a,然后使用权重向量进行线性变换

注意:本文描述的内容等同于我们在此处使用的定义。

首先,我们计算所有对.

g_l_repeat 获取每个节点嵌入重复n_nodes 次数的位置。

177        g_l_repeat = g_l.repeat(n_nodes, 1, 1)

g_r_repeat_interleave 获取每个节点嵌入重复n_nodes 次数的位置。

182        g_r_repeat_interleave = g_r.repeat_interleave(n_nodes, dim=0)

现在我们添加两个张量来获得

190        g_sum = g_l_repeat + g_r_repeat_interleave

重塑g_sum[i, j] 就是这样

192        g_sum = g_sum.view(n_nodes, n_nodes, self.n_heads, self.n_hidden)

计算e 是形状的[n_nodes, n_nodes, n_heads, 1]

200        e = self.attn(self.activation(g_sum))

移除大小的最后一个维度1

202        e = e.squeeze(-1)

邻接矩阵的形状应[n_nodes, n_nodes, n_heads][n_nodes, n_nodes, 1]

206        assert adj_mat.shape[0] == 1 or adj_mat.shape[0] == n_nodes
207        assert adj_mat.shape[1] == 1 or adj_mat.shape[1] == n_nodes
208        assert adj_mat.shape[2] == 1 or adj_mat.shape[2] == self.n_heads

基于邻接矩阵的掩码。如果没有从到的边缘,则设置

211        e = e.masked_fill(adj_mat == 0, float('-inf'))

然后,我们将注意力分数(或系数)归一化

其中是连接到的节点集

我们通过将未连接的配对设置为未连接的配对来实现此目的。

221        a = self.softmax(e)

应用辍学正则化

224        a = self.dropout(a)

计算每个头的最终输出

228        attn_res = torch.einsum('ijh,jhf->ihf', a, g_r)

连接头部

231        if self.is_concat:

233            return attn_res.reshape(n_nodes, self.n_heads * self.n_hidden)

以头脑的意思为例

235        else:

237            return attn_res.mean(dim=1)