ハイパーネットワーク-HyperLSTM

論文で紹介したHyperNetworksで紹介したHyperLSTMを、PyTorchを使ったアノテーション付きで実装しましたDavid Ha によるこのブログ記事では、ハイパーネットワークについてわかりやすく説明しています

シェイクスピアデータセットのテキストを予測するように HyperLSTM をトレーニングする実験を行っています。コードへのリンクは次のとおりです。experiment.py

Open In Colab

ハイパーネットワークは、小さいネットワークを使用して大きなネットワークの重みを生成します。静的ハイパーネットワークと動的ハイパーネットワークの 2 種類があります。静的ハイパーネットワークには、畳み込みネットワークのウェイト (カーネル) を生成する小規模なネットワークがあります。動的ハイパーネットワークは、ステップごとにリカレントニューラルネットワークのパラメーターを生成します。これは後者の実装です。

ダイナミック・ハイパーネットワーク

RNN では、パラメータは各ステップで一定に保たれます。ダイナミックハイパーネットワークは、ステップごとに異なるパラメーターを生成します。HyperLSTMはLSTMの構造ですが、各ステップのパラメータは小規模なLSTMネットワークによって変更されます

基本的には、動的ハイパーネットワークには小さな再帰ネットワークがあり、大きい方の再帰ネットワークの各パラメーターテンソルに対応する特徴ベクトルを生成します。たとえば、大きいネットワークに何らかのパラメータがあるとします。小さいネットワークは特徴ベクトルを生成し、線形変換として動的に計算します。たとえば、は 3 次元のテンソルパラメーターで、はテンソルとベクトルの乗算です。通常は、小規模な再帰ネットワークの出力を線形変換したものです

計算の代わりにウェイトスケーリングを行う

大規模なリカレントネットワークには、動的に計算されるパラメーターが大きくなります。これらは、特徴ベクトルの線形変換を使用して計算されます。そして、この変換にはさらに大きな重みテンソルが必要です。つまり形があれば形になります

これを解決するために、同じサイズの行列の各行を動的にスケーリングすることにより、再帰ネットワークの重みパラメーターを計算します。

ここではパラメータマトリックスです。

ここで、要素単位の乗算を意味するので、計算時にこれをさらに最適化できます。

72from typing import Optional, Tuple
73
74import torch
75from torch import nn
76
77from labml_helpers.module import Module
78from labml_nn.lstm import LSTMCell

ハイパーLSTMセル

HyperLSTM では、小さいネットワークと大きいネットワークの両方が LSTM 構造になっています。これは論文の付録A.2.2で定義されています

81class HyperLSTMCell(Module):

input_size は入力のサイズhidden_size は LSTM のサイズ、は小さい方の LSTM のサイズで、大きい方の外側の LSTM hyper_size の重みは変化します。n_z LSTM ウェイトの変更に使用される特徴ベクトルのサイズです

小さい方のLSTMの出力を使用して計算し、線形変換を使用します。再度、線形変換を使用して計算し、これらから行います。次に、これらを使用してメインLSTMの重みとバイアステンソルの行をスケーリングします

📝 との計算は2つの連続した線形変換なので、これらを1つの線形変換にまとめることができます。ただし、論文の説明と一致するようにこれを個別に実装しました。

89    def __init__(self, input_size: int, hidden_size: int, hyper_size: int, n_z: int):
107        super().__init__()

HyperLSTM への入力は、前のステップで外部の LSTM の入力と出力です。したがって、入力サイズはですhidden_size + input_size

HyperLSTM の出力はおよびです。

120        self.hyper = LSTMCell(hidden_size + input_size, hyper_size, layer_norm=True)

🤔 論文ではタイプミスだと思うので明記されていました。

126        self.z_h = nn.Linear(hyper_size, 4 * n_z)

128        self.z_x = nn.Linear(hyper_size, 4 * n_z)

130        self.z_b = nn.Linear(hyper_size, 4 * n_z, bias=False)

133        d_h = [nn.Linear(n_z, hidden_size, bias=False) for _ in range(4)]
134        self.d_h = nn.ModuleList(d_h)

136        d_x = [nn.Linear(n_z, hidden_size, bias=False) for _ in range(4)]
137        self.d_x = nn.ModuleList(d_x)

139        d_b = [nn.Linear(n_z, hidden_size) for _ in range(4)]
140        self.d_b = nn.ModuleList(d_b)

ウェイトマトリックス

143        self.w_h = nn.ParameterList([nn.Parameter(torch.zeros(hidden_size, hidden_size)) for _ in range(4)])

ウェイトマトリックス

145        self.w_x = nn.ParameterList([nn.Parameter(torch.zeros(hidden_size, input_size)) for _ in range(4)])

レイヤー正規化

148        self.layer_norm = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(4)])
149        self.layer_norm_c = nn.LayerNorm(hidden_size)
151    def forward(self, x: torch.Tensor,
152                h: torch.Tensor, c: torch.Tensor,
153                h_hat: torch.Tensor, c_hat: torch.Tensor):

160        x_hat = torch.cat((h, x), dim=-1)

162        h_hat, c_hat = self.hyper(x_hat, h_hat, c_hat)

165        z_h = self.z_h(h_hat).chunk(4, dim=-1)

167        z_x = self.z_x(h_hat).chunk(4, dim=-1)

169        z_b = self.z_b(h_hat).chunk(4, dim=-1)

をループで計算します

172        ifgo = []
173        for i in range(4):

175            d_h = self.d_h[i](z_h[i])

177            d_x = self.d_x[i](z_x[i])

184            y = d_h * torch.einsum('ij,bj->bi', self.w_h[i], h) + \
185                d_x * torch.einsum('ij,bj->bi', self.w_x[i], x) + \
186                self.d_b[i](z_b[i])
187
188            ifgo.append(self.layer_norm[i](y))

191        i, f, g, o = ifgo

194        c_next = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)

197        h_next = torch.sigmoid(o) * torch.tanh(self.layer_norm_c(c_next))
198
199        return h_next, c_next, h_hat, c_hat

HyperLSTM モジュール

202class HyperLSTM(Module):

HyperLSTM n_layers のネットワークを作成します。

207    def __init__(self, input_size: int, hidden_size: int, hyper_size: int, n_z: int, n_layers: int):
212        super().__init__()

サイズを保存して状態を初期化します

215        self.n_layers = n_layers
216        self.hidden_size = hidden_size
217        self.hyper_size = hyper_size

レイヤーごとにセルを作成します。最初のレイヤーだけが直接入力を取得することに注意してください。残りのレイヤーは、下のレイヤーから入力を取得します

221        self.cells = nn.ModuleList([HyperLSTMCell(input_size, hidden_size, hyper_size, n_z)] +
222                                   [HyperLSTMCell(hidden_size, hidden_size, hyper_size, n_z) for _ in
223                                    range(n_layers - 1)])
  • x [n_steps, batch_size, input_size] 形があって
  • state のタプルです。[batch_size, hidden_size] [batch_size, hyper_size] 形があって形がある
225    def forward(self, x: torch.Tensor,
226                state: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = None):
233        n_steps, batch_size = x.shape[:2]

次の場合は、状態をゼロで初期化します None

236        if state is None:
237            h = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
238            c = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
239            h_hat = [x.new_zeros(batch_size, self.hyper_size) for _ in range(self.n_layers)]
240            c_hat = [x.new_zeros(batch_size, self.hyper_size) for _ in range(self.n_layers)]

242        else:
243            (h, c, h_hat, c_hat) = state

テンソルを逆に積み重ねて各レイヤーの状態を取得します

📝 テンソル自体で作業することもできますが、デバッグは簡単です

247            h, c = list(torch.unbind(h)), list(torch.unbind(c))
248            h_hat, c_hat = list(torch.unbind(h_hat)), list(torch.unbind(c_hat))

各ステップで最終レイヤーの出力を集める

251        out = []
252        for t in range(n_steps):

最初のレイヤーへの入力は入力そのものです

254            inp = x[t]

レイヤーをループする

256            for layer in range(self.n_layers):

レイヤーの状態を取得

258                h[layer], c[layer], h_hat[layer], c_hat[layer] = \
259                    self.cells[layer](inp, h[layer], c[layer], h_hat[layer], c_hat[layer])

次のレイヤーへの入力は、このレイヤーの状態です

261                inp = h[layer]

最終レイヤーの出力を集める

263            out.append(h[-1])

出力とステートを積み重ねる

266        out = torch.stack(out)
267        h = torch.stack(h)
268        c = torch.stack(c)
269        h_hat = torch.stack(h_hat)
270        c_hat = torch.stack(c_hat)

273        return out, (h, c, h_hat, c_hat)