グラフ・アテンション・ネットワークス v2 (GATv2)

これは、「グラフアテンションネットワークはどの程度注意深いのか?」という論文のGATv2演算子をPyTorchで実装したものです

GATv2は、GATと同様にグラフデータを処理します。グラフは、ノードとノードを接続するエッジで構成されます。たとえば、Coraデータセットでは、ノードは研究論文で、端は論文をつなぐ引用です

GATv2 オペレータは、標準 GAT のスタティックアテンションの問題を解決します。スタティックアテンションとは、どのクエリノードでもキーノードへのアテンションのランク(順序)が同じであることです。GAT は、クエリノードからキーノードへのアテンションを次のように計算します

どのクエリノードでも、キーのアテンションランク () は以下にのみ依存することに注意してください。したがって、キーのアテンションランクはすべてのクエリで同じ(静的)ままです。

GATv2はアテンションメカニズムを変更することで動的なアテンションを可能にします。

この論文は、GATの静的注意メカニズムが、合成辞書検索データセットのグラフ問題の一部で失敗することを示しています。これは完全に接続された二部グラフで、一方のノード(クエリノード)にはキーが関連付けられ、もう一方のノードセットにはキーと値の両方が関連付けられています。目標は、クエリノードの値を予測することです。GAT は静的処理が制限されているため、このタスクは失敗します。

これは、Coraデータセットで2層GATv2をトレーニングするためのトレーニングコードです

57import torch
58from torch import nn
59
60from labml_helpers.module import Module

グラフアテンション v2 レイヤー

これはシングルグラフアテンションv2レイヤーです。GATv2は、このような複数のレイヤーで構成されています。入力として、where を、出力として、where を取ります。

63class GraphAttentionV2Layer(Module):
  • in_features 、はノードあたりの入力フィーチャの数です
  • out_features 、はノードごとの出力フィーチャの数です
  • n_heads 、はアテンション・ヘッドの数
  • is_concat マルチヘッドの結果を連結すべきか平均化すべきか
  • dropout は脱落確率です
  • leaky_relu_negative_slope リークのあるリレーアクティベーションの負の傾きです
  • share_weights に設定するとTrue 、すべてのエッジのソースノードとターゲットノードに同じマトリックスが適用されます
76    def __init__(self, in_features: int, out_features: int, n_heads: int,
77                 is_concat: bool = True,
78                 dropout: float = 0.6,
79                 leaky_relu_negative_slope: float = 0.2,
80                 share_weights: bool = False):
90        super().__init__()
91
92        self.is_concat = is_concat
93        self.n_heads = n_heads
94        self.share_weights = share_weights

頭あたりの寸法数の計算

97        if is_concat:
98            assert out_features % n_heads == 0

複数のヘッドを連結する場合

100            self.n_hidden = out_features // n_heads
101        else:

複数のヘッドを平均化する場合

103            self.n_hidden = out_features

初期ソース変換用の線形レイヤー。つまり、自己処理の前にソースノードの埋め込みを変換する

107        self.linear_l = nn.Linear(in_features, self.n_hidden * n_heads, bias=False)

share_weights True ターゲットノードに同じリニアレイヤーが使用されている場合

109        if share_weights:
110            self.linear_r = self.linear_l
111        else:
112            self.linear_r = nn.Linear(in_features, self.n_hidden * n_heads, bias=False)

アテンションスコアを計算する線形レイヤー

114        self.attn = nn.Linear(self.n_hidden, 1, bias=False)

アテンションスコアのアクティベーション

116        self.activation = nn.LeakyReLU(negative_slope=leaky_relu_negative_slope)

注意力を計算するソフトマックス

118        self.softmax = nn.Softmax(dim=1)

注目すべきドロップアウト層

120        self.dropout = nn.Dropout(dropout)
  • hはシェイプの入力ノード埋め込みです。[n_nodes, in_features]
  • adj_mat [n_nodes, n_nodes, n_heads] は形状の隣接行列です。[n_nodes, n_nodes, 1] 各ヘッドの隣接関係が同じなので、形状を使用します。隣接マトリックスは、ノード間のエッジ (または接続) を表します。adj_mat[i][j] True i ノード間でエッジがある場合ですj
122    def forward(self, h: torch.Tensor, adj_mat: torch.Tensor):

ノード数

132        n_nodes = h.shape[0]

各ヘッドの初期変換。線形変換を 2 回行い、それを各ヘッドに分割します

138        g_l = self.linear_l(h).view(n_nodes, self.n_heads, self.n_hidden)
139        g_r = self.linear_r(h).view(n_nodes, self.n_heads, self.n_hidden)

アテンションスコアの計算

これらは頭ごとに計算します。わかりやすくするために省略しました。

ノードごとのアテンションスコア(重要度)です。これを頭ごとに計算します。

アテンションスコアを計算するアテンションメカニズムです。紙は合計しその後にAとが続き、重みベクトルを使用して線形変換を行います

注:この論文では、どちらがここで使用している定義と同等であるかが説明されています。

まず、すべてのペアを計算します.

g_l_repeat n_nodes 各ノードの埋め込みが何度も繰り返される場所を取得します。

177        g_l_repeat = g_l.repeat(n_nodes, 1, 1)

g_r_repeat_interleave n_nodes 各ノードの埋め込みが何度も繰り返される場所を取得します。

182        g_r_repeat_interleave = g_r.repeat_interleave(n_nodes, dim=0)

次に、2 つのテンソルを追加して

190        g_sum = g_l_repeat + g_r_repeat_interleave

g_sum[i, j] そのように形を変えてください

192        g_sum = g_sum.view(n_nodes, n_nodes, self.n_heads, self.n_hidden)

e 形状の計算 [n_nodes, n_nodes, n_heads, 1]

200        e = self.attn(self.activation(g_sum))

サイズの最後のディメンションを削除 1

202        e = e.squeeze(-1)

隣接マトリックスは、[n_nodes, n_nodes, n_heads] またはの形状でなければなりません [n_nodes, n_nodes, 1]

206        assert adj_mat.shape[0] == 1 or adj_mat.shape[0] == n_nodes
207        assert adj_mat.shape[1] == 1 or adj_mat.shape[1] == n_nodes
208        assert adj_mat.shape[2] == 1 or adj_mat.shape[2] == self.n_heads

隣接マトリックスに基づくマスク。からまでのエッジがない場合は、に設定されます

211        e = e.masked_fill(adj_mat == 0, float('-inf'))

次に、アテンションスコア (または係数) を正規化します

は接続先のノードセットがどこにあるか

そのためには、「未接続」を「未接続」に設定することでペアが接続されていない状態になります

221        a = self.softmax(e)

ドロップアウト正則化を適用

224        a = self.dropout(a)

各ヘッドの最終出力を計算

228        attn_res = torch.einsum('ijh,jhf->ihf', a, g_r)

ヘッドを連結してください

231        if self.is_concat:

233            return attn_res.reshape(n_nodes, self.n_heads * self.n_hidden)

頭の中を平均して

235        else:

237            return attn_res.mean(dim=1)