これは論文の「グラフ・アテンション・ネットワーク」の PyTorch 実装です。
GAT はグラフデータを処理します。グラフは、ノードとノードを接続するエッジで構成されます。たとえば、Coraデータセットでは、ノードは研究論文で、端は論文をつなぐ引用です
。GATは、トランスフォーマーに似た、マスクされたセルフアテンションを使います。GATは、グラフアテンションレイヤーが互いに重なり合って構成されています。各グラフアテンションレイヤーは、入力としてノード埋め込みを取得し、変換された埋め込みを出力します。ノード埋め込みは、接続されている他のノードの埋め込みに注目します。グラフアテンションレイヤーの詳細は、実装とともに含まれています。
28import torch
29from torch import nn
30
31from labml_helpers.module import Module
これは単一のグラフアテンションレイヤーです。GAT はこのような複数のレイヤーで構成されています
。入力として、where を、出力として、where を取ります。
34class GraphAttentionLayer(Module):
in_features
、はノードあたりの入力フィーチャの数ですout_features
、はノードごとの出力フィーチャの数ですn_heads
、はアテンション・ヘッドの数is_concat
マルチヘッドの結果を連結すべきか平均化すべきかdropout
は脱落確率ですleaky_relu_negative_slope
リークのあるリレーアクティベーションの負の傾きです48 def __init__(self, in_features: int, out_features: int, n_heads: int,
49 is_concat: bool = True,
50 dropout: float = 0.6,
51 leaky_relu_negative_slope: float = 0.2):
60 super().__init__()
61
62 self.is_concat = is_concat
63 self.n_heads = n_heads
頭あたりの寸法数の計算
66 if is_concat:
67 assert out_features % n_heads == 0
複数のヘッドを連結する場合
69 self.n_hidden = out_features // n_heads
70 else:
複数のヘッドを平均化する場合
72 self.n_hidden = out_features
初期変換用の線形レイヤー。つまり、自己処理の前にノード埋め込みを変換する
76 self.linear = nn.Linear(in_features, self.n_hidden * n_heads, bias=False)
アテンションスコアを計算する線形レイヤー
78 self.attn = nn.Linear(self.n_hidden * 2, 1, bias=False)
アテンションスコアのアクティベーション
80 self.activation = nn.LeakyReLU(negative_slope=leaky_relu_negative_slope)
注意力を計算するソフトマックス
82 self.softmax = nn.Softmax(dim=1)
注目すべきドロップアウト層
84 self.dropout = nn.Dropout(dropout)
h
、はシェイプの入力ノード埋め込みです。[n_nodes, in_features]
adj_mat
[n_nodes, n_nodes, n_heads]
は形状の隣接行列です。[n_nodes, n_nodes, 1]
各ヘッドの隣接関係が同じなので、形状を使用します隣接マトリックスは、ノード間のエッジ (または接続) を表します。adj_mat[i][j]
True
i
ノード間でエッジがある場合ですj
。
86 def forward(self, h: torch.Tensor, adj_mat: torch.Tensor):
ノード数
97 n_nodes = h.shape[0]
各ヘッドの初期変形。線形変換を1つ行い、それを頭ごとに分割します。
102 g = self.linear(h).view(n_nodes, self.n_heads, self.n_hidden)
これらは頭ごとに計算します。わかりやすくするために省略しました。
ノードごとのアテンションスコア(重要度)です。これを頭ごとに計算します。
アテンションスコアを計算するアテンションメカニズムです。この論文では、重みベクトルの後にaを連結し、線形変換を行います
。
133 g_repeat = g.repeat(n_nodes, 1, 1)
g_repeat_interleave
n_nodes
各ノードの埋め込みが何度も繰り返される場所を取得します。
138 g_repeat_interleave = g.repeat_interleave(n_nodes, dim=0)
次に、連結して次のようになります。
146 g_concat = torch.cat([g_repeat_interleave, g_repeat], dim=-1)
g_concat[i, j]
そのように形を変えてください
148 g_concat = g_concat.view(n_nodes, n_nodes, self.n_heads, 2 * self.n_hidden)
e
形状の計算 [n_nodes, n_nodes, n_heads, 1]
156 e = self.activation(self.attn(g_concat))
サイズの最後のディメンションを削除 1
158 e = e.squeeze(-1)
隣接マトリックスは、[n_nodes, n_nodes, n_heads]
またはの形状でなければなりません [n_nodes, n_nodes, 1]
162 assert adj_mat.shape[0] == 1 or adj_mat.shape[0] == n_nodes
163 assert adj_mat.shape[1] == 1 or adj_mat.shape[1] == n_nodes
164 assert adj_mat.shape[2] == 1 or adj_mat.shape[2] == self.n_heads
隣接マトリックスに基づくマスク。からまでのエッジがない場合は、に設定されます。
167 e = e.masked_fill(adj_mat == 0, float('-inf'))
177 a = self.softmax(e)
ドロップアウト正則化を適用
180 a = self.dropout(a)
各ヘッドの最終出力を計算
注:このホワイトペーパーでは、Graph Attention Layer の実装からは省略し、他の PyTorch モジュールの定義に合わせて GAT モデルで使用しています。つまり、アクティベーションは別のレイヤーとして行われます。
189 attn_res = torch.einsum('ijh,jhf->ihf', a, g)
ヘッドを連結してください
192 if self.is_concat:
194 return attn_res.reshape(n_nodes, self.n_heads * self.n_hidden)
頭の中を平均して
196 else:
198 return attn_res.mean(dim=1)