长期短期记忆 (LSTM)

这是长短期记忆的 PyTorch 实现。

12from typing import Optional, Tuple
13
14import torch
15from torch import nn
16
17from labml_helpers.module import Module

长短期记忆细胞

LSTM 细胞计算、和就像长期记忆,就像短期记忆。我们使用输入来更新长期内存。在更新中,的某些功能使用忘记门清除,而某些功能则通过大门添加

新的短期存储器是长期存储器乘以输出门的值

请注意,在进行更新时,单元格不会查看长期内存。它只会修改它。也永远不会经历线性变换。这就是解决渐变消失和爆炸的问题。

以下是更新规则。

代表逐元素乘法。

中间值和门计算为隐藏状态和输入的线性变换。

20class LSTMCell(Module):
57    def __init__(self, input_size: int, hidden_size: int, layer_norm: bool = False):
58        super().__init__()

这些是用于变换input 和向hidden 量的线性图层。其中一个不需要偏差,因为我们添加了变换。

这结合、和转换。

64        self.hidden_lin = nn.Linear(hidden_size, 4 * hidden_size)

这结合、和转换。

66        self.input_lin = nn.Linear(input_size, 4 * hidden_size, bias=False)

是否应用层归一化。

应用图层归一化可以获得更好的结果。嵌入是标准化的,并在

73        if layer_norm:
74            self.layer_norm = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(4)])
75            self.layer_norm_c = nn.LayerNorm(hidden_size)
76        else:
77            self.layer_norm = nn.ModuleList([nn.Identity() for _ in range(4)])
78            self.layer_norm_c = nn.Identity()
80    def forward(self, x: torch.Tensor, h: torch.Tensor, c: torch.Tensor):

我们使用相同的线性层计算和的线性变换。

83        ifgo = self.hidden_lin(h) + self.input_lin(x)

每层产生 4 倍的输出hidden_size ,我们将其拆分

85        ifgo = ifgo.chunk(4, dim=-1)

应用图层归一化(不在原始纸张中,但效果更好)

88        ifgo = [self.layer_norm[i](ifgo[i]) for i in range(4)]

91        i, f, g, o = ifgo

94        c_next = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)

或者,将图层规范应用于

98        h_next = torch.sigmoid(o) * torch.tanh(self.layer_norm_c(c_next))
99
100        return h_next, c_next

多层 LSTM

103class LSTM(Module):

创建一个由 LSTMn_layers 组成的网络。

108    def __init__(self, input_size: int, hidden_size: int, n_layers: int):
113        super().__init__()
114        self.n_layers = n_layers
115        self.hidden_size = hidden_size

为每层创建单元。请注意,只有第一层直接获得输入。其余图层从下面的图层获取输入

118        self.cells = nn.ModuleList([LSTMCell(input_size, hidden_size)] +
119                                   [LSTMCell(hidden_size, hidden_size) for _ in range(n_layers - 1)])

x 有形状[n_steps, batch_size, input_size] 并且state和的元组,每个元组的形状为[batch_size, hidden_size]

121    def forward(self, x: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
126        n_steps, batch_size = x.shape[:2]

初始化状态如果None

129        if state is None:
130            h = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
131            c = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
132        else:
133            (h, c) = state

反向堆叠张量以获得每层的状态

📝 你可以只使用张量本身,但这更容易调试

137            h, c = list(torch.unbind(h)), list(torch.unbind(c))

用于在每个时间步收集最后一层输出的数组。

140        out = []
141        for t in range(n_steps):

第一层的输入是输入本身

143            inp = x[t]

循环穿过图层

145            for layer in range(self.n_layers):

获取图层的状态

147                h[layer], c[layer] = self.cells[layer](inp, h[layer], c[layer])

下一层的输入是该图层的状态

149                inp = h[layer]

收集最后一层的输出

151            out.append(h[-1])

堆叠输出和状态

154        out = torch.stack(out)
155        h = torch.stack(h)
156        c = torch.stack(c)
157
158        return out, (h, c)