11import torch
12from torch import nn
13
14from labml import experiment
15from labml.configs import option
16from labml_helpers.module import Module
17from labml_nn.graphs.gat.experiment import Configs as GATConfigs
18from labml_nn.graphs.gatv2 import GraphAttentionV2Layer
21class GATv2(Module):
in_features
はノードあたりのフィーチャ数n_hidden
は最初のグラフアテンションレイヤーに含まれるフィーチャの数ですn_classes
はクラスの数n_heads
グラフ・アテンション・レイヤーのヘッド数ですdropout
は脱落確率ですshare_weights
True に設定すると、すべてのエッジのソースノードとターゲットノードに同じマトリックスが適用されます28 def __init__(self, in_features: int, n_hidden: int, n_classes: int, n_heads: int, dropout: float,
29 share_weights: bool = True):
38 super().__init__()
ヘッドを連結する最初のグラフ・アテンション・レイヤー
41 self.layer1 = GraphAttentionV2Layer(in_features, n_hidden, n_heads,
42 is_concat=True, dropout=dropout, share_weights=share_weights)
最初のグラフアテンションレイヤー後のアクティベーション機能
44 self.activation = nn.ELU()
ヘッドを平均化する最後のグラフ・アテンション・レイヤー
46 self.output = GraphAttentionV2Layer(n_hidden, n_classes, 1,
47 is_concat=False, dropout=dropout, share_weights=share_weights)
ドロップアウト
49 self.dropout = nn.Dropout(dropout)
x
は形状の特徴ベクトルです [n_nodes, in_features]
adj_mat
[n_nodes, n_nodes, n_heads]
は次の形式の隣接行列です [n_nodes, n_nodes, 1]
51 def forward(self, x: torch.Tensor, adj_mat: torch.Tensor):
入力にドロップアウトを適用
58 x = self.dropout(x)
最初のグラフアテンションレイヤー
60 x = self.layer1(x, adj_mat)
アクティベーション機能
62 x = self.activation(x)
ドロップアウト
64 x = self.dropout(x)
ロジットの出力レイヤー (アクティベーションなし)
66 return self.output(x, adj_mat)
69class Configs(GATConfigs):
エッジのソースノードとターゲットノードのウェイトを共有するかどうか
78 share_weights: bool = False
モデルを設定する
80 model: GATv2 = 'gat_v2_model'
GATv2 モデルの作成
83@option(Configs.model)
84def gat_v2_model(c: Configs):
88 return GATv2(c.in_features, c.n_hidden, c.n_classes, c.n_heads, c.dropout, c.share_weights).to(c.device)
91def main():
構成の作成
93 conf = Configs()
テストを作成
95 experiment.create(name='gatv2')
構成を計算します。
97 experiment.configs(conf, {
アダム・オプティマイザー
99 'optimizer.optimizer': 'Adam',
100 'optimizer.learning_rate': 5e-3,
101 'optimizer.weight_decay': 5e-4,
102
103 'dropout': 0.7,
104 })
実験を開始して見る
107 with experiment.start():
トレーニングを実行
109 conf.run()
113if __name__ == '__main__':
114 main()