長期短期記憶 (LSTM)

これは長短期記憶のPyTorch実装です

12from typing import Optional, Tuple
13
14import torch
15from torch import nn
16
17from labml_helpers.module import Module

長短期メモリーセル

LSTM セルが計算し、 長期記憶のようなもので、短期記憶のようなものです。入力とを使用して長期記憶を更新します。今回のアップデートでは、一部のフィーチャはフォーゲートゲートでクリアされ一部のフィーチャはゲート経由で追加されます

新しい短期記憶は、長期記憶に出力ゲートを掛けたものです。

更新を行うとき、セルは長期記憶を見ないことに注意してください。変更するだけです。また、線形変換を行うこともありません。これがグラデーションの消失と爆発を解消するものです

更新ルールは次のとおりです。

要素ごとの乗算の略です。

中間値とゲートは、隠れ状態と入力の線形変換として計算されます。

20class LSTMCell(Module):
57    def __init__(self, input_size: int, hidden_size: int, layer_norm: bool = False):
58        super().__init__()

これらは、input hidden およびベクトルを変換する線形レイヤーです。そのうちの 1 つは、変換を追加するのでバイアスを必要としません

これにより、、変換が組み合わされます。

64        self.hidden_lin = nn.Linear(hidden_size, 4 * hidden_size)

これにより、、変換が組み合わされます。

66        self.input_lin = nn.Linear(input_size, 4 * hidden_size, bias=False)

レイヤー正規化を適用するかどうか。

レイヤーの正規化を適用すると、より良い結果が得られます。および埋め込みは正規化され、次の形式で正規化されます

73        if layer_norm:
74            self.layer_norm = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(4)])
75            self.layer_norm_c = nn.LayerNorm(hidden_size)
76        else:
77            self.layer_norm = nn.ModuleList([nn.Identity() for _ in range(4)])
78            self.layer_norm_c = nn.Identity()
80    def forward(self, x: torch.Tensor, h: torch.Tensor, c: torch.Tensor):

、、の線形変換は同じ線形層を使用して計算します。

83        ifgo = self.hidden_lin(h) + self.input_lin(x)

各レイヤーは 4 hidden_size 倍の出力を生成し、それらを分割します

85        ifgo = ifgo.chunk(4, dim=-1)

レイヤーの正規化を適用(元の用紙には使用しないが、より良い結果になる)

88        ifgo = [self.layer_norm[i](ifgo[i]) for i in range(4)]

91        i, f, g, o = ifgo

94        c_next = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)

オプションでレイヤーノルムを適用

98        h_next = torch.sigmoid(o) * torch.tanh(self.layer_norm_c(c_next))
99
100        return h_next, c_next

マルチレイヤーLSTM

103class LSTM(Module):

LSTM n_layers のネットワークを作成します。

108    def __init__(self, input_size: int, hidden_size: int, n_layers: int):
113        super().__init__()
114        self.n_layers = n_layers
115        self.hidden_size = hidden_size

レイヤーごとにセルを作成します。最初のレイヤーだけが直接入力を取得することに注意してください。残りのレイヤーは、下のレイヤーから入力を取得します

118        self.cells = nn.ModuleList([LSTMCell(input_size, hidden_size)] +
119                                   [LSTMCell(hidden_size, hidden_size) for _ in range(n_layers - 1)])

x [n_steps, batch_size, input_size] state の形状はとが組み合わさったもので[batch_size, hidden_size] それぞれの形状はです。

121    def forward(self, x: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
126        n_steps, batch_size = x.shape[:2]

次の場合にステートを初期化します None

129        if state is None:
130            h = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
131            c = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
132        else:
133            (h, c) = state

テンソルを逆に積み重ねて各レイヤーの状態を取得します

📝 テンソル自体で作業することもできますが、デバッグは簡単です

137            h, c = list(torch.unbind(h)), list(torch.unbind(c))

各タイムステップで最終レイヤーの出力を収集する配列。

140        out = []
141        for t in range(n_steps):

最初のレイヤーへの入力は入力そのものです

143            inp = x[t]

レイヤーをループする

145            for layer in range(self.n_layers):

レイヤーの状態を取得

147                h[layer], c[layer] = self.cells[layer](inp, h[layer], c[layer])

次のレイヤーへの入力は、このレイヤーの状態です

149                inp = h[layer]

最終レイヤーの出力を集める

151            out.append(h[-1])

出力とステートを積み重ねる

154        out = torch.stack(out)
155        h = torch.stack(h)
156        c = torch.stack(c)
157
158        return out, (h, c)