リカレントハイウェイネットワーク

これは、リカレントハイウェイネットワークのPyTorch実装です

11from typing import Optional
12
13import torch
14from torch import nn
15
16from labml_helpers.module import Module

リカレントハイウェイネットワークセル

これは方程式を実装します。

どこ

そしてにとって

要素ごとの乗算の略です。

ここでは、論文の表記にいくつか変更を加えました。時間との混同を避けるため、ゲートは紙に載っていた「」 で表している。複数のレイヤーと混同しないように、紙の奥行きと全体の奥行きを紙の代わりに使用しています

また、方程式に含まれる重み行列とバイアスベクトルを線形変換に置き換えました。これが実装のようになるためです。

論文に記載されているように、ウエイトタイイングを実施しています。

19class RHNCell(Module):

input_size は入力のフィーチャ長、hidden_size はセルのフィーチャ長です。depth です

57    def __init__(self, input_size: int, hidden_size: int, depth: int):
63        super().__init__()
64
65        self.hidden_size = hidden_size
66        self.depth = depth

、を単一の線形レイヤーと組み合わせます。その後、結果を分割しておよびのコンポーネントを取得できます。これがおよび用です。

70        self.hidden_lin = nn.ModuleList([nn.Linear(hidden_size, 2 * hidden_size) for _ in range(depth)])

同様に、とを組み合わせます

73        self.input_lin = nn.Linear(input_size, 2 * hidden_size, bias=False)

x [batch_size, input_size] s [batch_size, hidden_size] 形があって形がある

75    def forward(self, x: torch.Tensor, s: torch.Tensor):

繰り返し

82        for d in range(self.depth):

との線形変換の連結を計算します

84            if d == 0:

入力は、の場合にのみ使用されます

86                hg = self.input_lin(x) + self.hidden_lin[d](s)
87            else:
88                hg = self.hidden_lin[d](s)

hg 前半を使うと

96            h = torch.tanh(hg[:, :self.hidden_size])

hg の後半を使うと

103            g = torch.sigmoid(hg[:, self.hidden_size:])
104
105            s = h * g + s * (1 - g)
106
107        return s

多層リカレントハイウェイネットワーク

110class RHN(Module):

n_layers depth それぞれ奥行きのある高速道路ネットワークレイヤーから成るネットワークを作成します。

115    def __init__(self, input_size: int, hidden_size: int, depth: int, n_layers: int):
120        super().__init__()
121        self.n_layers = n_layers
122        self.hidden_size = hidden_size

レイヤーごとにセルを作成します。最初のレイヤーだけが直接入力を取得することに注意してください。残りのレイヤーは、下のレイヤーから入力を取得します

125        self.cells = nn.ModuleList([RHNCell(input_size, hidden_size, depth)] +
126                                   [RHNCell(hidden_size, hidden_size, depth) for _ in range(n_layers - 1)])

x [seq_len, batch_size, input_size] state [batch_size, hidden_size] 形があって形がある

128    def forward(self, x: torch.Tensor, state: Optional[torch.Tensor] = None):
133        time_steps, batch_size = x.shape[:2]

次の場合にステートを初期化します None

136        if state is None:
137            s = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
138        else:

ステートを逆にスタックして各レイヤーの状態を取得します。

📝 テンソル自体で作業することもできますが、デバッグは簡単です

142            s = torch.unbind(state)

各タイムステップで最終レイヤーの出力を収集する配列。

145        out = []

タイムステップごとにネットワーク経由で実行

148        for t in range(time_steps):

最初のレイヤーへの入力は入力そのものです

150            inp = x[t]

レイヤーをループする

152            for layer in range(self.n_layers):

レイヤーの状態を取得

154                s[layer] = self.cells[layer](inp, s[layer])

次のレイヤーへの入力は、このレイヤーの状態です

156                inp = s[layer]

最終レイヤーの出力を集める

158            out.append(s[-1])

出力とステートを積み重ねる

161        out = torch.stack(out)
162        s = torch.stack(s)
163
164        return out, s