これは、論文「PonderNet: 熟考を学ぼう」をPyTorchで実装したものです。
PonderNet は入力に基づいて計算を調整します。入力に基づいてリカレントネットワークで実行するステップの数を変更します。PonderNetはこれを端から端までの勾配降下法で学習します
。PonderNetには次のような形式のステップ機能があります
ここで、は入力、は状態、はステップでの予測、は現在のステップで停止 (停止) する確率です。
任意のニューラルネットワーク(LSTM、MLP、GRU、アテンションレイヤーなど)にすることができます。
その場合、ステップで停止する無条件の確率は、
これは、前のどのステップでも停止せず、ステップで停止する確率です。
推論中は、停止確率に基づいてサンプリングして停止し、停止層での予測を最終出力として取得します。
トレーニング中、すべてのレイヤーから予測を取得し、各レイヤーの損失を計算します。次に、各レイヤーで停止する確率に基づいて、損失の加重平均を取ります
。ステップ機能は、から寄付されたステップの最大数に適用されます。
PonderNet の全体的な損失は
はターゲットと予測の間の正規損失関数です。
はでパラメータ化された幾何分布です。関係ありません。ただ紙と同じ表記法にこだわっているだけです。
。正則化損失は、ネットワークを手順を実行する方向に偏らせ、すべてのステップでゼロ以外の確率を誘発します。つまり、探索が促進されます。
パリティタスクで PonderNet experiment.py
をトレーニングするためのトレーニングコードは次のとおりです。
63from typing import Tuple
64
65import torch
66from torch import nn
67
68from labml_helpers.module import Module
これは、ステップ関数としてGRU Cellを使用するシンプルなモデルです。
このモデルは、n_elems
入力がのベクトルであるパリティタスク用です。ベクトルの各要素は0
、1
-1
またはのいずれかで、出力はパリティです。これは 1
s の数が奇数の場合は true、それ以外の場合は false のバイナリ値です
モデルの予測は、パリティが存在する対数確率です。
71class ParityPonderGRU(Module):
n_elems
は入力ベクトルの要素数ですn_hidden
GRU の状態ベクトルサイズですmax_steps
は、最大ステップ数です 85 def __init__(self, n_elems: int, n_hidden: int, max_steps: int):
91 super().__init__()
92
93 self.max_steps = max_steps
94 self.n_hidden = n_hidden
GRU
98 self.gru = nn.GRUCell(n_elems, n_hidden)
とを入力として連結するレイヤーを使用することもできますが、簡略化のためにこれを使用しました。
102 self.output_layer = nn.Linear(n_hidden, 1)
104 self.lambda_layer = nn.Linear(n_hidden, 1)
105 self.lambda_prob = nn.Sigmoid()
推論時に計算が実際に停止するように推論中に設定するオプション
107 self.is_halt = False
x
形状の入力です [batch_size, n_elems]
これは4つのテンソルのタプルを出力します。
1。[N, batch_size]
形状2のテンソルで。[N, batch_size]
形状のテンソルでは、パリティの対数確率は3です。[batch_size]
シェイプ4の。[batch_size]
ステップで計算が中止された形状の
109 def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
122 batch_size = x.shape[0]
初期状態を取得しました
125 h = x.new_zeros((x.shape[0], self.n_hidden))
126 h = self.gru(x, h)
保存するリストと
129 p = []
130 y = []
132 un_halted_prob = h.new_ones((batch_size,))
どのサンプルが計算を停止したかを維持するためのベクトル
135 halted = h.new_zeros((batch_size,))
ステップで計算が中止されたところ
137 p_m = h.new_zeros((batch_size,))
138 y_m = h.new_zeros((batch_size,))
ステップごとに繰り返す
141 for n in range(1, self.max_steps + 1):
最後のステップで停止する確率
143 if n == self.max_steps:
144 lambda_n = h.new_ones(h.shape[0])
146 else:
147 lambda_n = self.lambda_prob(self.lambda_layer(h))[:, 0]
149 y_n = self.output_layer(h)[:, 0]
152 p_n = un_halted_prob * lambda_n
[更新]
154 un_halted_prob = un_halted_prob * (1 - lambda_n)
停止確率に基づく停止
157 halt = torch.bernoulli(lambda_n) * (1 - halted)
収集して
160 p.append(p_n)
161 y.append(y_n)
更新および現在のステップで停止された内容に基づく
164 p_m = p_m * (1 - halt) + p_n * halt
165 y_m = y_m * (1 - halt) + y_n * halt
停止したサンプルの更新
168 halted = halted + halt
次のステートを取得
170 h = self.gru(x, h)
すべてのサンプルが停止したら、計算を停止します
173 if self.is_halt and halted.sum() == batch_size:
174 break
177 return torch.stack(p), torch.stack(y), p_m, y_m
180class ReconstructionLoss(Module):
loss_func
は損失関数です 189 def __init__(self, loss_func: nn.Module):
193 super().__init__()
194 self.loss_func = loss_func
p
テンソルの形をしています [N, batch_size]
y_hat
テンソルの形をしています [N, batch_size, ...]
y
形状のターゲットです [batch_size, ...]
196 def forward(self, p: torch.Tensor, y_hat: torch.Tensor, y: torch.Tensor):
合計
204 total_loss = p.new_tensor(0.)
最大まで繰り返す
206 for n in range(p.shape[0]):
各サンプルとその平均について
208 loss = (p[n] * self.loss_func(y_hat[n], y)).mean()
総損失に加算
210 total_loss = total_loss + loss
213 return total_loss
はでパラメータ化された幾何分布です。関係ありません。ただ紙と同じ表記法にこだわっているだけです。
。正則化損失は、ネットワークを手順を実行する方向に偏らせ、すべてのステップでゼロ以外の確率を誘発します。つまり、探索が促進されます。
216class RegularizationLoss(Module):
lambda_p
is -幾何分布の成功確率max_steps
が最高です。これを使って事前計算します 232 def __init__(self, lambda_p: float, max_steps: int = 1_000):
237 super().__init__()
計算する空のベクトル
240 p_g = torch.zeros((max_steps,))
242 not_halted = 1.
最大まで繰り返す max_steps
244 for k in range(max_steps):
246 p_g[k] = not_halted * lambda_p
[更新]
248 not_halted = not_halted * (1 - lambda_p)
[保存]
251 self.p_g = nn.Parameter(p_g, requires_grad=False)
KL-ダイバージェンスロス
254 self.kl_div = nn.KLDivLoss(reduction='batchmean')
p
テンソルの形をしています [N, batch_size]
256 def forward(self, p: torch.Tensor):
に転置 p
[batch_size, N]
261 p = p.transpose(0, 1)
バッチディメンション全体への展開と拡張
263 p_g = self.p_g[None, :p.shape[1]].expand_as(p)
268 return self.kl_div(p.log(), p_g)