これは、「グラフアテンションネットワークはどの程度注意深いのか?」という論文のGATv2演算子をPyTorchで実装したものです。
。GATv2は、GATと同様にグラフデータを処理します。グラフは、ノードとノードを接続するエッジで構成されます。たとえば、Coraデータセットでは、ノードは研究論文で、端は論文をつなぐ引用です
。GATv2 オペレータは、標準 GAT のスタティックアテンションの問題を解決します。スタティックアテンションとは、どのクエリノードでもキーノードへのアテンションのランク(順序)が同じであることです。GAT は、クエリノードからキーノードへのアテンションを次のように計算します
。どのクエリノードでも、キーのアテンションランク () は以下にのみ依存することに注意してください。したがって、キーのアテンションランクはすべてのクエリで同じ(静的)ままです。
GATv2はアテンションメカニズムを変更することで動的なアテンションを可能にします。
この論文は、GATの静的注意メカニズムが、合成辞書検索データセットのグラフ問題の一部で失敗することを示しています。これは完全に接続された二部グラフで、一方のノード(クエリノード)にはキーが関連付けられ、もう一方のノードセットにはキーと値の両方が関連付けられています。目標は、クエリノードの値を予測することです。GAT は静的処理が制限されているため、このタスクは失敗します。
57import torch
58from torch import nn
59
60from labml_helpers.module import Module
これはシングルグラフアテンションv2レイヤーです。GATv2は、このような複数のレイヤーで構成されています。入力として、where を、出力として、where を取ります。
63class GraphAttentionV2Layer(Module):
in_features
、はノードあたりの入力フィーチャの数ですout_features
、はノードごとの出力フィーチャの数ですn_heads
、はアテンション・ヘッドの数is_concat
マルチヘッドの結果を連結すべきか平均化すべきかdropout
は脱落確率ですleaky_relu_negative_slope
リークのあるリレーアクティベーションの負の傾きですshare_weights
に設定するとTrue
、すべてのエッジのソースノードとターゲットノードに同じマトリックスが適用されます76 def __init__(self, in_features: int, out_features: int, n_heads: int,
77 is_concat: bool = True,
78 dropout: float = 0.6,
79 leaky_relu_negative_slope: float = 0.2,
80 share_weights: bool = False):
90 super().__init__()
91
92 self.is_concat = is_concat
93 self.n_heads = n_heads
94 self.share_weights = share_weights
頭あたりの寸法数の計算
97 if is_concat:
98 assert out_features % n_heads == 0
複数のヘッドを連結する場合
100 self.n_hidden = out_features // n_heads
101 else:
複数のヘッドを平均化する場合
103 self.n_hidden = out_features
初期ソース変換用の線形レイヤー。つまり、自己処理の前にソースノードの埋め込みを変換する
107 self.linear_l = nn.Linear(in_features, self.n_hidden * n_heads, bias=False)
share_weights
True
ターゲットノードに同じリニアレイヤーが使用されている場合
109 if share_weights:
110 self.linear_r = self.linear_l
111 else:
112 self.linear_r = nn.Linear(in_features, self.n_hidden * n_heads, bias=False)
アテンションスコアを計算する線形レイヤー
114 self.attn = nn.Linear(self.n_hidden, 1, bias=False)
アテンションスコアのアクティベーション
116 self.activation = nn.LeakyReLU(negative_slope=leaky_relu_negative_slope)
注意力を計算するソフトマックス
118 self.softmax = nn.Softmax(dim=1)
注目すべきドロップアウト層
120 self.dropout = nn.Dropout(dropout)
h
、はシェイプの入力ノード埋め込みです。[n_nodes, in_features]
adj_mat
[n_nodes, n_nodes, n_heads]
は形状の隣接行列です。[n_nodes, n_nodes, 1]
各ヘッドの隣接関係が同じなので、形状を使用します。隣接マトリックスは、ノード間のエッジ (または接続) を表します。adj_mat[i][j]
True
i
ノード間でエッジがある場合ですj
。122 def forward(self, h: torch.Tensor, adj_mat: torch.Tensor):
ノード数
132 n_nodes = h.shape[0]
各ヘッドの初期変換。線形変換を 2 回行い、それを各ヘッドに分割します
。138 g_l = self.linear_l(h).view(n_nodes, self.n_heads, self.n_hidden)
139 g_r = self.linear_r(h).view(n_nodes, self.n_heads, self.n_hidden)
これらは頭ごとに計算します。わかりやすくするために省略しました。
ノードごとのアテンションスコア(重要度)です。これを頭ごとに計算します。
アテンションスコアを計算するアテンションメカニズムです。紙は合計し、その後にAとが続き、重みベクトルを使用して線形変換を行います
注:この論文では、どちらがここで使用している定義と同等であるかが説明されています。
177 g_l_repeat = g_l.repeat(n_nodes, 1, 1)
g_r_repeat_interleave
n_nodes
各ノードの埋め込みが何度も繰り返される場所を取得します。
182 g_r_repeat_interleave = g_r.repeat_interleave(n_nodes, dim=0)
次に、2 つのテンソルを追加して
190 g_sum = g_l_repeat + g_r_repeat_interleave
g_sum[i, j]
そのように形を変えてください
192 g_sum = g_sum.view(n_nodes, n_nodes, self.n_heads, self.n_hidden)
e
形状の計算 [n_nodes, n_nodes, n_heads, 1]
200 e = self.attn(self.activation(g_sum))
サイズの最後のディメンションを削除 1
202 e = e.squeeze(-1)
隣接マトリックスは、[n_nodes, n_nodes, n_heads]
またはの形状でなければなりません [n_nodes, n_nodes, 1]
206 assert adj_mat.shape[0] == 1 or adj_mat.shape[0] == n_nodes
207 assert adj_mat.shape[1] == 1 or adj_mat.shape[1] == n_nodes
208 assert adj_mat.shape[2] == 1 or adj_mat.shape[2] == self.n_heads
隣接マトリックスに基づくマスク。からまでのエッジがない場合は、に設定されます。
211 e = e.masked_fill(adj_mat == 0, float('-inf'))
221 a = self.softmax(e)
ドロップアウト正則化を適用
224 a = self.dropout(a)
各ヘッドの最終出力を計算
228 attn_res = torch.einsum('ijh,jhf->ihf', a, g_r)
ヘッドを連結してください
231 if self.is_concat:
233 return attn_res.reshape(n_nodes, self.n_heads * self.n_hidden)
頭の中を平均して
235 else:
237 return attn_res.mean(dim=1)