Cora データセットのグラフアテンションネットワーク (GAT) のトレーニング

11from typing import Dict
12
13import numpy as np
14import torch
15from torch import nn
16
17from labml import lab, monit, tracker, experiment
18from labml.configs import BaseConfigs, option, calculate
19from labml.utils import download
20from labml_helpers.device import DeviceConfigs
21from labml_helpers.module import Module
22from labml_nn.graphs.gat import GraphAttentionLayer
23from labml_nn.optimizers.configs import OptimizerConfigs

コーラデータセット

Coraデータセットは研究論文のデータセットです。各論文には、単語の存在を示すバイナリ特徴ベクトルが与えられます。各論文は7つのクラスのいずれかに分類されます。データセットには引用ネットワークもあります

論文はグラフの節点で、端は引用です。

タスクは、特徴ベクトルと引用ネットワークを入力として、ノードを7つのクラスに分類することです。

26class CoraDataset:

各ノードのラベル

41    labels: torch.Tensor

クラス名と一意の整数インデックスのセット

43    classes: Dict[str, int]

全ノードの特徴ベクトル

45    features: torch.Tensor

エッジ情報を含む隣接マトリックス。adj_mat[i][j] True i j もしもから端があったらね

48    adj_mat: torch.Tensor

データセットのダウンロード

50    @staticmethod
51    def _download():
55        if not (lab.get_data_path() / 'cora').exists():
56            download.download_file('https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz',
57                                   lab.get_data_path() / 'cora.tgz')
58            download.extract_tar(lab.get_data_path() / 'cora.tgz', lab.get_data_path())

データセットの読み込み

60    def __init__(self, include_edges: bool = True):

エッジを含めるかどうか。これは、引用ネットワークを無視すると精度がどれだけ失われるかをテストするものです

67        self.include_edges = include_edges

データセットのダウンロード

70        self._download()

論文ID、特徴ベクトル、ラベルを読む

73        with monit.section('Read content file'):
74            content = np.genfromtxt(str(lab.get_data_path() / 'cora/cora.content'), dtype=np.dtype(str))

引用をロードします。整数のペアのリストです。

76        with monit.section('Read citations file'):
77            citations = np.genfromtxt(str(lab.get_data_path() / 'cora/cora.cites'), dtype=np.int32)

特徴ベクトルを取得

80        features = torch.tensor(np.array(content[:, 1:-1], dtype=np.float32))

特徴ベクトルを正規化

82        self.features = features / features.sum(dim=1, keepdim=True)

クラス名を取得し、それぞれに一意の整数を割り当てます。

85        self.classes = {s: i for i, s in enumerate(set(content[:, -1]))}

ラベルをそれらの整数として取得

87        self.labels = torch.tensor([self.classes[i] for i in content[:, -1]], dtype=torch.long)

紙の ID を入手

90        paper_ids = np.array(content[:, 0], dtype=np.int32)

紙IDと索引のマップ

92        ids_to_idx = {id_: i for i, id_ in enumerate(paper_ids)}

空の隣接行列-単位行列

95        self.adj_mat = torch.eye(len(self.labels), dtype=torch.bool)

引用文献を隣接マトリックスに記入

98        if self.include_edges:
99            for e in citations:

一対のペーパーインデックス

101                e1, e2 = ids_to_idx[e[0]], ids_to_idx[e[1]]

対称的なグラフを作成します。紙が参照している紙の場合は、端を端から端に、端を端として配置します。

105                self.adj_mat[e1][e2] = True
106                self.adj_mat[e2][e1] = True

グラフ・アテンション・ネットワーク (GAT)

このグラフアテンションネットワークには 2 つのグラフアテンションレイヤーがあります

109class GAT(Module):
  • in_features はノードあたりのフィーチャ数
  • n_hidden は最初のグラフアテンションレイヤーに含まれるフィーチャの数です
  • n_classes はクラスの数
  • n_heads グラフ・アテンション・レイヤーのヘッド数です
  • dropout は脱落確率です
116    def __init__(self, in_features: int, n_hidden: int, n_classes: int, n_heads: int, dropout: float):
124        super().__init__()

ヘッドを連結する最初のグラフ・アテンション・レイヤー

127        self.layer1 = GraphAttentionLayer(in_features, n_hidden, n_heads, is_concat=True, dropout=dropout)

最初のグラフアテンションレイヤー後のアクティベーション機能

129        self.activation = nn.ELU()

ヘッドを平均化する最後のグラフ・アテンション・レイヤー

131        self.output = GraphAttentionLayer(n_hidden, n_classes, 1, is_concat=False, dropout=dropout)

ドロップアウト

133        self.dropout = nn.Dropout(dropout)
  • x は形状の特徴ベクトルです [n_nodes, in_features]
  • adj_mat [n_nodes, n_nodes, n_heads] は次の形式の隣接行列です [n_nodes, n_nodes, 1]
135    def forward(self, x: torch.Tensor, adj_mat: torch.Tensor):

入力にドロップアウトを適用

142        x = self.dropout(x)

最初のグラフアテンションレイヤー

144        x = self.layer1(x, adj_mat)

アクティベーション機能

146        x = self.activation(x)

ドロップアウト

148        x = self.dropout(x)

ロジットの出力レイヤー (アクティベーションなし)

150        return self.output(x, adj_mat)

精度を計算する簡単な関数

153def accuracy(output: torch.Tensor, labels: torch.Tensor):
157    return output.argmax(dim=-1).eq(labels).sum().item() / len(labels)

コンフィギュレーション

160class Configs(BaseConfigs):

モデル

166    model: GAT

トレーニングするノード数

168    training_samples: int = 500

入力内のノードあたりのフィーチャ数

170    in_features: int

最初のグラフアテンションレイヤーに含まれるフィーチャの数

172    n_hidden: int = 64

ヘッド数

174    n_heads: int = 8

分類するクラス数

176    n_classes: int

脱落確率

178    dropout: float = 0.6

引用ネットワークを含めるかどうか

180    include_edges: bool = True

データセット

182    dataset: CoraDataset

トレーニングの反復回数

184    epochs: int = 1_000

損失関数

186    loss_func = nn.CrossEntropyLoss()

トレーニングするデバイス

これによりデバイスの設定が作成されるので、設定値を渡すことでデバイスを変更できます

191    device: torch.device = DeviceConfigs()

オプティマイザー

193    optimizer: torch.optim.Adam

トレーニングループ

データセットが小さいので、フルバッチトレーニングを行います。サンプリングしてトレーニングする場合、トレーニングステップごとに一連のノードと、選択したノードにまたがるエッジをサンプリングする必要があります。

195    def run(self):

特徴ベクトルをデバイスに移動します

205        features = self.dataset.features.to(self.device)

ラベルをデバイスに移動

207        labels = self.dataset.labels.to(self.device)

隣接マトリックスをデバイスに移動

209        edges_adj = self.dataset.adj_mat.to(self.device)

頭部に空の 3 番目のディメンションを追加

211        edges_adj = edges_adj.unsqueeze(-1)

ランダムインデックス

214        idx_rand = torch.randperm(len(labels))

トレーニング用ノード

216        idx_train = idx_rand[:self.training_samples]

検証用ノード

218        idx_valid = idx_rand[self.training_samples:]

トレーニングループ

221        for epoch in monit.loop(self.epochs):

モデルをトレーニングモードに設定

223            self.model.train()

すべてのグラデーションをゼロにする

225            self.optimizer.zero_grad()

モデルの評価

227            output = self.model(features, edges_adj)

トレーニングノードで損失を被る

229            loss = self.loss_func(output[idx_train], labels[idx_train])

勾配の計算

231            loss.backward()

最適化の一歩を踏み出す

233            self.optimizer.step()

損失を記録する

235            tracker.add('loss.train', loss)

精度を記録する

237            tracker.add('accuracy.train', accuracy(output[idx_train], labels[idx_train]))

検証用にモードを評価モードに設定

240            self.model.eval()

勾配を計算する必要はありません

243            with torch.no_grad():

モデルを再度評価してください

245                output = self.model(features, edges_adj)

検証ノードの損失の計算

247                loss = self.loss_func(output[idx_valid], labels[idx_valid])

損失を記録する

249                tracker.add('loss.valid', loss)

精度を記録する

251                tracker.add('accuracy.valid', accuracy(output[idx_valid], labels[idx_valid]))

ログを保存

254            tracker.save()

Cora データセットの作成

257@option(Configs.dataset)
258def cora_dataset(c: Configs):
262    return CoraDataset(c.include_edges)

クラス数を取得

266calculate(Configs.n_classes, lambda c: len(c.dataset.classes))

入力内のフィーチャの数

268calculate(Configs.in_features, lambda c: c.dataset.features.shape[1])

GAT モデルの作成

271@option(Configs.model)
272def gat_model(c: Configs):
276    return GAT(c.in_features, c.n_hidden, c.n_classes, c.n_heads, c.dropout).to(c.device)

設定可能なオプティマイザーの作成

279@option(Configs.optimizer)
280def _optimizer(c: Configs):
284    opt_conf = OptimizerConfigs()
285    opt_conf.parameters = c.model.parameters()
286    return opt_conf
289def main():

構成の作成

291    conf = Configs()

テストを作成

293    experiment.create(name='gat')

構成を計算します。

295    experiment.configs(conf, {

アダム・オプティマイザー

297        'optimizer.optimizer': 'Adam',
298        'optimizer.learning_rate': 5e-3,
299        'optimizer.weight_decay': 5e-4,
300    })

実験を開始して見る

303    with experiment.start():

トレーニングを実行

305        conf.run()

309if __name__ == '__main__':
310    main()