Cora データセットでのグラフアテンションネットワーク v2 (GATv2) のトレーニング

11import torch
12from torch import nn
13
14from labml import experiment
15from labml.configs import option
16from labml_helpers.module import Module
17from labml_nn.graphs.gat.experiment import Configs as GATConfigs
18from labml_nn.graphs.gatv2 import GraphAttentionV2Layer

グラフアテンションネットワーク v2 (GATv2)

このグラフアテンションネットワークには 2 つのグラフアテンションレイヤーがあります

21class GATv2(Module):
  • in_features はノードあたりのフィーチャ数
  • n_hidden は最初のグラフアテンションレイヤーに含まれるフィーチャの数です
  • n_classes はクラスの数
  • n_heads グラフ・アテンション・レイヤーのヘッド数です
  • dropout は脱落確率です
  • share_weights True に設定すると、すべてのエッジのソースノードとターゲットノードに同じマトリックスが適用されます
28    def __init__(self, in_features: int, n_hidden: int, n_classes: int, n_heads: int, dropout: float,
29                 share_weights: bool = True):
38        super().__init__()

ヘッドを連結する最初のグラフ・アテンション・レイヤー

41        self.layer1 = GraphAttentionV2Layer(in_features, n_hidden, n_heads,
42                                            is_concat=True, dropout=dropout, share_weights=share_weights)

最初のグラフアテンションレイヤー後のアクティベーション機能

44        self.activation = nn.ELU()

ヘッドを平均化する最後のグラフ・アテンション・レイヤー

46        self.output = GraphAttentionV2Layer(n_hidden, n_classes, 1,
47                                            is_concat=False, dropout=dropout, share_weights=share_weights)

ドロップアウト

49        self.dropout = nn.Dropout(dropout)
  • x は形状の特徴ベクトルです [n_nodes, in_features]
  • adj_mat [n_nodes, n_nodes, n_heads] は次の形式の隣接行列です [n_nodes, n_nodes, 1]
51    def forward(self, x: torch.Tensor, adj_mat: torch.Tensor):

入力にドロップアウトを適用

58        x = self.dropout(x)

最初のグラフアテンションレイヤー

60        x = self.layer1(x, adj_mat)

アクティベーション機能

62        x = self.activation(x)

ドロップアウト

64        x = self.dropout(x)

ロジットの出力レイヤー (アクティベーションなし)

66        return self.output(x, adj_mat)
69class Configs(GATConfigs):

エッジのソースノードとターゲットノードのウェイトを共有するかどうか

78    share_weights: bool = False

モデルを設定する

80    model: GATv2 = 'gat_v2_model'

GATv2 モデルの作成

83@option(Configs.model)
84def gat_v2_model(c: Configs):
88    return GATv2(c.in_features, c.n_hidden, c.n_classes, c.n_heads, c.dropout, c.share_weights).to(c.device)
91def main():

構成の作成

93    conf = Configs()

テストを作成

95    experiment.create(name='gatv2')

構成を計算します。

97    experiment.configs(conf, {

アダム・オプティマイザー

99        'optimizer.optimizer': 'Adam',
100        'optimizer.learning_rate': 5e-3,
101        'optimizer.weight_decay': 5e-4,
102
103        'dropout': 0.7,
104    })

実験を開始して見る

107    with experiment.start():

トレーニングを実行

109        conf.run()

113if __name__ == '__main__':
114    main()