11from typing import Dict
12
13import numpy as np
14import torch
15from torch import nn
16
17from labml import lab, monit, tracker, experiment
18from labml.configs import BaseConfigs, option, calculate
19from labml.utils import download
20from labml_helpers.device import DeviceConfigs
21from labml_helpers.module import Module
22from labml_nn.graphs.gat import GraphAttentionLayer
23from labml_nn.optimizers.configs import OptimizerConfigs
Coraデータセットは研究論文のデータセットです。各論文には、単語の存在を示すバイナリ特徴ベクトルが与えられます。各論文は7つのクラスのいずれかに分類されます。データセットには引用ネットワークもあります
。論文はグラフの節点で、端は引用です。
タスクは、特徴ベクトルと引用ネットワークを入力として、ノードを7つのクラスに分類することです。
26class CoraDataset:
各ノードのラベル
41 labels: torch.Tensor
クラス名と一意の整数インデックスのセット
43 classes: Dict[str, int]
全ノードの特徴ベクトル
45 features: torch.Tensor
エッジ情報を含む隣接マトリックス。adj_mat[i][j]
True
i
j
もしもから端があったらね
48 adj_mat: torch.Tensor
データセットのダウンロード
50 @staticmethod
51 def _download():
55 if not (lab.get_data_path() / 'cora').exists():
56 download.download_file('https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz',
57 lab.get_data_path() / 'cora.tgz')
58 download.extract_tar(lab.get_data_path() / 'cora.tgz', lab.get_data_path())
データセットの読み込み
60 def __init__(self, include_edges: bool = True):
エッジを含めるかどうか。これは、引用ネットワークを無視すると精度がどれだけ失われるかをテストするものです
。67 self.include_edges = include_edges
データセットのダウンロード
70 self._download()
論文ID、特徴ベクトル、ラベルを読む
73 with monit.section('Read content file'):
74 content = np.genfromtxt(str(lab.get_data_path() / 'cora/cora.content'), dtype=np.dtype(str))
引用をロードします。整数のペアのリストです。
76 with monit.section('Read citations file'):
77 citations = np.genfromtxt(str(lab.get_data_path() / 'cora/cora.cites'), dtype=np.int32)
特徴ベクトルを取得
80 features = torch.tensor(np.array(content[:, 1:-1], dtype=np.float32))
特徴ベクトルを正規化
82 self.features = features / features.sum(dim=1, keepdim=True)
クラス名を取得し、それぞれに一意の整数を割り当てます。
85 self.classes = {s: i for i, s in enumerate(set(content[:, -1]))}
ラベルをそれらの整数として取得
87 self.labels = torch.tensor([self.classes[i] for i in content[:, -1]], dtype=torch.long)
紙の ID を入手
90 paper_ids = np.array(content[:, 0], dtype=np.int32)
紙IDと索引のマップ
92 ids_to_idx = {id_: i for i, id_ in enumerate(paper_ids)}
空の隣接行列-単位行列
95 self.adj_mat = torch.eye(len(self.labels), dtype=torch.bool)
引用文献を隣接マトリックスに記入
98 if self.include_edges:
99 for e in citations:
一対のペーパーインデックス
101 e1, e2 = ids_to_idx[e[0]], ids_to_idx[e[1]]
対称的なグラフを作成します。紙が参照している紙の場合は、端を端から端に、端を端として配置します。
105 self.adj_mat[e1][e2] = True
106 self.adj_mat[e2][e1] = True
109class GAT(Module):
in_features
はノードあたりのフィーチャ数n_hidden
は最初のグラフアテンションレイヤーに含まれるフィーチャの数ですn_classes
はクラスの数n_heads
グラフ・アテンション・レイヤーのヘッド数ですdropout
は脱落確率です116 def __init__(self, in_features: int, n_hidden: int, n_classes: int, n_heads: int, dropout: float):
124 super().__init__()
ヘッドを連結する最初のグラフ・アテンション・レイヤー
127 self.layer1 = GraphAttentionLayer(in_features, n_hidden, n_heads, is_concat=True, dropout=dropout)
最初のグラフアテンションレイヤー後のアクティベーション機能
129 self.activation = nn.ELU()
ヘッドを平均化する最後のグラフ・アテンション・レイヤー
131 self.output = GraphAttentionLayer(n_hidden, n_classes, 1, is_concat=False, dropout=dropout)
ドロップアウト
133 self.dropout = nn.Dropout(dropout)
x
は形状の特徴ベクトルです [n_nodes, in_features]
adj_mat
[n_nodes, n_nodes, n_heads]
は次の形式の隣接行列です [n_nodes, n_nodes, 1]
135 def forward(self, x: torch.Tensor, adj_mat: torch.Tensor):
入力にドロップアウトを適用
142 x = self.dropout(x)
最初のグラフアテンションレイヤー
144 x = self.layer1(x, adj_mat)
アクティベーション機能
146 x = self.activation(x)
ドロップアウト
148 x = self.dropout(x)
ロジットの出力レイヤー (アクティベーションなし)
150 return self.output(x, adj_mat)
精度を計算する簡単な関数
153def accuracy(output: torch.Tensor, labels: torch.Tensor):
157 return output.argmax(dim=-1).eq(labels).sum().item() / len(labels)
160class Configs(BaseConfigs):
モデル
166 model: GAT
トレーニングするノード数
168 training_samples: int = 500
入力内のノードあたりのフィーチャ数
170 in_features: int
最初のグラフアテンションレイヤーに含まれるフィーチャの数
172 n_hidden: int = 64
ヘッド数
174 n_heads: int = 8
分類するクラス数
176 n_classes: int
脱落確率
178 dropout: float = 0.6
引用ネットワークを含めるかどうか
180 include_edges: bool = True
データセット
182 dataset: CoraDataset
トレーニングの反復回数
184 epochs: int = 1_000
損失関数
186 loss_func = nn.CrossEntropyLoss()
191 device: torch.device = DeviceConfigs()
オプティマイザー
193 optimizer: torch.optim.Adam
データセットが小さいので、フルバッチトレーニングを行います。サンプリングしてトレーニングする場合、トレーニングステップごとに一連のノードと、選択したノードにまたがるエッジをサンプリングする必要があります。
195 def run(self):
特徴ベクトルをデバイスに移動します
205 features = self.dataset.features.to(self.device)
ラベルをデバイスに移動
207 labels = self.dataset.labels.to(self.device)
隣接マトリックスをデバイスに移動
209 edges_adj = self.dataset.adj_mat.to(self.device)
頭部に空の 3 番目のディメンションを追加
211 edges_adj = edges_adj.unsqueeze(-1)
ランダムインデックス
214 idx_rand = torch.randperm(len(labels))
トレーニング用ノード
216 idx_train = idx_rand[:self.training_samples]
検証用ノード
218 idx_valid = idx_rand[self.training_samples:]
トレーニングループ
221 for epoch in monit.loop(self.epochs):
モデルをトレーニングモードに設定
223 self.model.train()
すべてのグラデーションをゼロにする
225 self.optimizer.zero_grad()
モデルの評価
227 output = self.model(features, edges_adj)
トレーニングノードで損失を被る
229 loss = self.loss_func(output[idx_train], labels[idx_train])
勾配の計算
231 loss.backward()
最適化の一歩を踏み出す
233 self.optimizer.step()
損失を記録する
235 tracker.add('loss.train', loss)
精度を記録する
237 tracker.add('accuracy.train', accuracy(output[idx_train], labels[idx_train]))
検証用にモードを評価モードに設定
240 self.model.eval()
勾配を計算する必要はありません
243 with torch.no_grad():
モデルを再度評価してください
245 output = self.model(features, edges_adj)
検証ノードの損失の計算
247 loss = self.loss_func(output[idx_valid], labels[idx_valid])
損失を記録する
249 tracker.add('loss.valid', loss)
精度を記録する
251 tracker.add('accuracy.valid', accuracy(output[idx_valid], labels[idx_valid]))
ログを保存
254 tracker.save()
Cora データセットの作成
257@option(Configs.dataset)
258def cora_dataset(c: Configs):
262 return CoraDataset(c.include_edges)
クラス数を取得
266calculate(Configs.n_classes, lambda c: len(c.dataset.classes))
入力内のフィーチャの数
268calculate(Configs.in_features, lambda c: c.dataset.features.shape[1])
GAT モデルの作成
271@option(Configs.model)
272def gat_model(c: Configs):
276 return GAT(c.in_features, c.n_hidden, c.n_classes, c.n_heads, c.dropout).to(c.device)
設定可能なオプティマイザーの作成
279@option(Configs.optimizer)
280def _optimizer(c: Configs):
284 opt_conf = OptimizerConfigs()
285 opt_conf.parameters = c.model.parameters()
286 return opt_conf
289def main():
構成の作成
291 conf = Configs()
テストを作成
293 experiment.create(name='gat')
構成を計算します。
295 experiment.configs(conf, {
アダム・オプティマイザー
297 'optimizer.optimizer': 'Adam',
298 'optimizer.learning_rate': 5e-3,
299 'optimizer.weight_decay': 5e-4,
300 })
実験を開始して見る
303 with experiment.start():
トレーニングを実行
305 conf.run()
309if __name__ == '__main__':
310 main()