PonderNet: 熟考することを学ぶ

これは、論文「PonderNet: 熟考を学ぼうをPyTorchで実装したものです

PonderNet は入力に基づいて計算を調整します。入力に基づいてリカレントネットワークで実行するステップの数を変更します。PonderNetはこれを端から端までの勾配降下法で学習します

PonderNetには次のような形式のステップ機能があります

ここで、は入力、は状態、はステップでの予測は現在のステップで停止 (停止) する確率です。

任意のニューラルネットワーク(LSTM、MLP、GRU、アテンションレイヤーなど)にすることができます。

その場合、ステップで停止する無条件の確率は、

これは、前のどのステップでも停止せず、ステップで停止する確率です。

推論中は、停止確率に基づいてサンプリングして停止し、停止層での予測を最終出力として取得します。

トレーニング中、すべてのレイヤーから予測を取得し、各レイヤーの損失を計算します。次に、各レイヤーで停止する確率に基づいて、損失の加重平均を取ります

ステップ機能は、から寄付されたステップの最大数に適用されます。

PonderNet の全体的な損失は

はターゲットと予測の間の正規損失関数です。

はカルバックとライブラーのダイバージェンスです

はでパラメータ化された幾何分布です関係ありません。ただ紙と同じ表記法にこだわっているだけです

正則化損失は、ネットワークを手順を実行する方向に偏らせ、すべてのステップでゼロ以外の確率を誘発します。つまり、探索が促進されます。

パリティタスクで PonderNet experiment.py をトレーニングするためのトレーニングコードは次のとおりです

63from typing import Tuple
64
65import torch
66from torch import nn
67
68from labml_helpers.module import Module

パリティタスク用の GRU 搭載の PonderNet

これは、ステップ関数としてGRU Cellを使用するシンプルなモデルです

このモデルは、n_elems 入力がのベクトルであるパリティタスク用です。ベクトルの各要素は01 -1 またはのいずれかで、出力はパリティです。これは 1 s の数が奇数の場合は true、それ以外の場合は false のバイナリ値です

モデルの予測は、パリティが存在する対数確率です。

71class ParityPonderGRU(Module):
  • n_elems は入力ベクトルの要素数です
  • n_hidden GRU の状態ベクトルサイズです
  • max_steps は、最大ステップ数です
85    def __init__(self, n_elems: int, n_hidden: int, max_steps: int):
91        super().__init__()
92
93        self.max_steps = max_steps
94        self.n_hidden = n_hidden

GRU

98        self.gru = nn.GRUCell(n_elems, n_hidden)

とを入力として連結するレイヤーを使用することもできますが、簡略化のためにこれを使用しました。

102        self.output_layer = nn.Linear(n_hidden, 1)

104        self.lambda_layer = nn.Linear(n_hidden, 1)
105        self.lambda_prob = nn.Sigmoid()

推論時に計算が実際に停止するように推論中に設定するオプション

107        self.is_halt = False
  • x 形状の入力です [batch_size, n_elems]

これは4つのテンソルのタプルを出力します。

1。[N, batch_size] 形状2のテンソルで。[N, batch_size] 形状のテンソルでは、パリティの対数確率は3です。[batch_size] シェイプ4の。[batch_size] ステップで計算が中止された形状の

109    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:

122        batch_size = x.shape[0]

初期状態を取得しました

125        h = x.new_zeros((x.shape[0], self.n_hidden))
126        h = self.gru(x, h)

保存するリストと

129        p = []
130        y = []

132        un_halted_prob = h.new_ones((batch_size,))

どのサンプルが計算を停止したかを維持するためのベクトル

135        halted = h.new_zeros((batch_size,))

ステップで計算が中止されたところ

137        p_m = h.new_zeros((batch_size,))
138        y_m = h.new_zeros((batch_size,))

ステップごとに繰り返す

141        for n in range(1, self.max_steps + 1):

最後のステップで停止する確率

143            if n == self.max_steps:
144                lambda_n = h.new_ones(h.shape[0])

146            else:
147                lambda_n = self.lambda_prob(self.lambda_layer(h))[:, 0]

149            y_n = self.output_layer(h)[:, 0]

152            p_n = un_halted_prob * lambda_n

[更新]

154            un_halted_prob = un_halted_prob * (1 - lambda_n)

停止確率に基づく停止

157            halt = torch.bernoulli(lambda_n) * (1 - halted)

収集して

160            p.append(p_n)
161            y.append(y_n)

更新および現在のステップで停止された内容に基づく

164            p_m = p_m * (1 - halt) + p_n * halt
165            y_m = y_m * (1 - halt) + y_n * halt

停止したサンプルの更新

168            halted = halted + halt

次のステートを取得

170            h = self.gru(x, h)

すべてのサンプルが停止したら、計算を停止します

173            if self.is_halt and halted.sum() == batch_size:
174                break

177        return torch.stack(p), torch.stack(y), p_m, y_m

復興損失

はターゲットと予測の間の正規損失関数です。

180class ReconstructionLoss(Module):
  • loss_func は損失関数です
189    def __init__(self, loss_func: nn.Module):
193        super().__init__()
194        self.loss_func = loss_func
  • p テンソルの形をしています [N, batch_size]
  • y_hat テンソルの形をしています [N, batch_size, ...]
  • y 形状のターゲットです [batch_size, ...]
196    def forward(self, p: torch.Tensor, y_hat: torch.Tensor, y: torch.Tensor):

合計

204        total_loss = p.new_tensor(0.)

最大まで繰り返す

206        for n in range(p.shape[0]):

各サンプルとその平均について

208            loss = (p[n] * self.loss_func(y_hat[n], y)).mean()

総損失に加算

210            total_loss = total_loss + loss

213        return total_loss

正則化損失

はカルバックとライブラーのダイバージェンスです

はでパラメータ化された幾何分布です関係ありません。ただ紙と同じ表記法にこだわっているだけです

正則化損失は、ネットワークを手順を実行する方向に偏らせ、すべてのステップでゼロ以外の確率を誘発します。つまり、探索が促進されます。

216class RegularizationLoss(Module):
  • lambda_p is -幾何分布の成功確率
  • max_steps が最高です。これを使って事前計算します
232    def __init__(self, lambda_p: float, max_steps: int = 1_000):
237        super().__init__()

計算する空のベクトル

240        p_g = torch.zeros((max_steps,))

242        not_halted = 1.

最大まで繰り返す max_steps

244        for k in range(max_steps):

246            p_g[k] = not_halted * lambda_p

[更新]

248            not_halted = not_halted * (1 - lambda_p)

[保存]

251        self.p_g = nn.Parameter(p_g, requires_grad=False)

KL-ダイバージェンスロス

254        self.kl_div = nn.KLDivLoss(reduction='batchmean')
  • p テンソルの形をしています [N, batch_size]
256    def forward(self, p: torch.Tensor):

に転置 p [batch_size, N]

261        p = p.transpose(0, 1)

バッチディメンション全体への展開と拡張

263        p_g = self.p_g[None, :p.shape[1]].expand_as(p)

KL ダイバージェンスを計算します。

PyTorch KL-ダイバージェンス実装は対数確率を受け入れます
268        return self.kl_div(p.log(), p_g)