グラフ・アテンション・ネットワーク (GAT)

これは論文の「グラフ・アテンション・ネットワーク」の PyTorch 実装です。

GAT はグラフデータを処理します。グラフは、ノードとノードを接続するエッジで構成されます。たとえば、Coraデータセットでは、ノードは研究論文で、端は論文をつなぐ引用です

GATは、トランスフォーマーに似た、マスクされたセルフアテンションを使います。GATは、グラフアテンションレイヤーが互いに重なり合って構成されています。各グラフアテンションレイヤーは、入力としてノード埋め込みを取得し、変換された埋め込みを出力します。ノード埋め込みは、接続されている他のノードの埋め込みに注目します。グラフアテンションレイヤーの詳細は、実装とともに含まれています。

Cora データセットで 2 層 GAT をトレーニングするためのトレーニングコードを次に示します

28import torch
29from torch import nn
30
31from labml_helpers.module import Module

グラフ・アテンション・レイヤー

これは単一のグラフアテンションレイヤーです。GAT はこのような複数のレイヤーで構成されています

入力として、where を、出力として、where を取ります。

34class GraphAttentionLayer(Module):
  • in_features 、はノードあたりの入力フィーチャの数です
  • out_features 、はノードごとの出力フィーチャの数です
  • n_heads 、はアテンション・ヘッドの数
  • is_concat マルチヘッドの結果を連結すべきか平均化すべきか
  • dropout は脱落確率です
  • leaky_relu_negative_slope リークのあるリレーアクティベーションの負の傾きです
48    def __init__(self, in_features: int, out_features: int, n_heads: int,
49                 is_concat: bool = True,
50                 dropout: float = 0.6,
51                 leaky_relu_negative_slope: float = 0.2):
60        super().__init__()
61
62        self.is_concat = is_concat
63        self.n_heads = n_heads

頭あたりの寸法数の計算

66        if is_concat:
67            assert out_features % n_heads == 0

複数のヘッドを連結する場合

69            self.n_hidden = out_features // n_heads
70        else:

複数のヘッドを平均化する場合

72            self.n_hidden = out_features

初期変換用の線形レイヤー。つまり、自己処理の前にノード埋め込みを変換する

76        self.linear = nn.Linear(in_features, self.n_hidden * n_heads, bias=False)

アテンションスコアを計算する線形レイヤー

78        self.attn = nn.Linear(self.n_hidden * 2, 1, bias=False)

アテンションスコアのアクティベーション

80        self.activation = nn.LeakyReLU(negative_slope=leaky_relu_negative_slope)

注意力を計算するソフトマックス

82        self.softmax = nn.Softmax(dim=1)

注目すべきドロップアウト層

84        self.dropout = nn.Dropout(dropout)
  • hはシェイプの入力ノード埋め込みです。[n_nodes, in_features]
  • adj_mat [n_nodes, n_nodes, n_heads] は形状の隣接行列です。[n_nodes, n_nodes, 1] 各ヘッドの隣接関係が同じなので、形状を使用します

隣接マトリックスは、ノード間のエッジ (または接続) を表します。adj_mat[i][j] True i ノード間でエッジがある場合ですj

86    def forward(self, h: torch.Tensor, adj_mat: torch.Tensor):

ノード数

97        n_nodes = h.shape[0]

各ヘッドの初期変形。線形変換を1つ行い、それを頭ごとに分割します。

102        g = self.linear(h).view(n_nodes, self.n_heads, self.n_hidden)

アテンションスコアの計算

これらは頭ごとに計算します。わかりやすくするために省略しました。

ノードごとのアテンションスコア(重要度)です。これを頭ごとに計算します。

アテンションスコアを計算するアテンションメカニズムです。この論文では重みベクトルの後にaを連結し、線形変換を行います

まず、すべてのペアを計算します.

g_repeat n_nodes 各ノードの埋め込みが何度も繰り返される場所を取得します。

133        g_repeat = g.repeat(n_nodes, 1, 1)

g_repeat_interleave n_nodes 各ノードの埋め込みが何度も繰り返される場所を取得します。

138        g_repeat_interleave = g.repeat_interleave(n_nodes, dim=0)

次に、連結して次のようになります。

146        g_concat = torch.cat([g_repeat_interleave, g_repeat], dim=-1)

g_concat[i, j] そのように形を変えてください

148        g_concat = g_concat.view(n_nodes, n_nodes, self.n_heads, 2 * self.n_hidden)

e 形状の計算 [n_nodes, n_nodes, n_heads, 1]

156        e = self.activation(self.attn(g_concat))

サイズの最後のディメンションを削除 1

158        e = e.squeeze(-1)

隣接マトリックスは、[n_nodes, n_nodes, n_heads] またはの形状でなければなりません [n_nodes, n_nodes, 1]

162        assert adj_mat.shape[0] == 1 or adj_mat.shape[0] == n_nodes
163        assert adj_mat.shape[1] == 1 or adj_mat.shape[1] == n_nodes
164        assert adj_mat.shape[2] == 1 or adj_mat.shape[2] == self.n_heads

隣接マトリックスに基づくマスク。からまでのエッジがない場合は、に設定されます

167        e = e.masked_fill(adj_mat == 0, float('-inf'))

次に、アテンションスコア (または係数) を正規化します

は接続先のノードセットがどこにあるか

そのためには、「未接続」を「未接続」に設定することでペアが接続されていない状態になります

177        a = self.softmax(e)

ドロップアウト正則化を適用

180        a = self.dropout(a)

各ヘッドの最終出力を計算

注:このホワイトペーパーでは、Graph Attention Layer の実装からは省略し、他の PyTorch モジュールの定義に合わせて GAT モデルで使用しています。つまり、アクティベーションは別のレイヤーとして行われます。

189        attn_res = torch.einsum('ijh,jhf->ihf', a, g)

ヘッドを連結してください

192        if self.is_concat:

194            return attn_res.reshape(n_nodes, self.n_heads * self.n_hidden)

頭の中を平均して

196        else:

198            return attn_res.mean(dim=1)