11from typing import Optional
12
13import torch
14from torch import nn
15
16from labml_helpers.module import Module
これは方程式を実装します。
どこ
そしてにとって
要素ごとの乗算の略です。
ここでは、論文の表記にいくつか変更を加えました。時間との混同を避けるため、ゲートは紙に載っていた「」 で表している。複数のレイヤーと混同しないように、紙の奥行きと全体の奥行きを紙の代わりに使用しています
。また、方程式に含まれる重み行列とバイアスベクトルを線形変換に置き換えました。これが実装のようになるためです。
論文に記載されているように、ウエイトタイイングを実施しています。
19class RHNCell(Module):
input_size
は入力のフィーチャ長、hidden_size
はセルのフィーチャ長です。depth
です。
57 def __init__(self, input_size: int, hidden_size: int, depth: int):
63 super().__init__()
64
65 self.hidden_size = hidden_size
66 self.depth = depth
と、を単一の線形レイヤーと組み合わせます。その後、結果を分割しておよびのコンポーネントを取得できます。これがおよび用です。
70 self.hidden_lin = nn.ModuleList([nn.Linear(hidden_size, 2 * hidden_size) for _ in range(depth)])
同様に、とを組み合わせます。
73 self.input_lin = nn.Linear(input_size, 2 * hidden_size, bias=False)
x
[batch_size, input_size]
s
[batch_size, hidden_size]
形があって形がある
75 def forward(self, x: torch.Tensor, s: torch.Tensor):
繰り返し
82 for d in range(self.depth):
との線形変換の連結を計算します
84 if d == 0:
入力は、の場合にのみ使用されます。
86 hg = self.input_lin(x) + self.hidden_lin[d](s)
87 else:
88 hg = self.hidden_lin[d](s)
96 h = torch.tanh(hg[:, :self.hidden_size])
103 g = torch.sigmoid(hg[:, self.hidden_size:])
104
105 s = h * g + s * (1 - g)
106
107 return s
110class RHN(Module):
n_layers
depth
それぞれ奥行きのある高速道路ネットワークレイヤーから成るネットワークを作成します。
115 def __init__(self, input_size: int, hidden_size: int, depth: int, n_layers: int):
120 super().__init__()
121 self.n_layers = n_layers
122 self.hidden_size = hidden_size
レイヤーごとにセルを作成します。最初のレイヤーだけが直接入力を取得することに注意してください。残りのレイヤーは、下のレイヤーから入力を取得します
。125 self.cells = nn.ModuleList([RHNCell(input_size, hidden_size, depth)] +
126 [RHNCell(hidden_size, hidden_size, depth) for _ in range(n_layers - 1)])
x
[seq_len, batch_size, input_size]
state
[batch_size, hidden_size]
形があって形がある
128 def forward(self, x: torch.Tensor, state: Optional[torch.Tensor] = None):
133 time_steps, batch_size = x.shape[:2]
次の場合にステートを初期化します None
136 if state is None:
137 s = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
138 else:
142 s = torch.unbind(state)
各タイムステップで最終レイヤーの出力を収集する配列。
145 out = []
タイムステップごとにネットワーク経由で実行
148 for t in range(time_steps):
最初のレイヤーへの入力は入力そのものです
150 inp = x[t]
レイヤーをループする
152 for layer in range(self.n_layers):
レイヤーの状態を取得
154 s[layer] = self.cells[layer](inp, s[layer])
次のレイヤーへの入力は、このレイヤーの状態です
156 inp = s[layer]
最終レイヤーの出力を集める
158 out.append(s[-1])
出力とステートを積み重ねる
161 out = torch.stack(out)
162 s = torch.stack(s)
163
164 return out, s