循环高速公路网络

这是循环高速公路网络PyTorch 实现。

11from typing import Optional
12
13import torch
14from torch import nn
15
16from labml_helpers.module import Module

循环高速公路网络单元

这实现了方程

在哪里

还有为了

代表逐元素乘法。

在这里,我们对论文中的符号进行了一些更改。为了避免与时间混淆,gate 用在报纸上表示。为了避免与多层混淆,我们使用深度和总深度来代替纸张来自纸张。

我们还用线性变换取代了方程中的权重矩阵和偏置向量,因为这就是实现的样子。

我们实施重量捆绑,如纸中所述

19class RHNCell(Module):

input_size 是输入的要素长度,hidden_size 是像元的要素长度。depth

57    def __init__(self, input_size: int, hidden_size: int, depth: int):
63        super().__init__()
64
65        self.hidden_size = hidden_size
66        self.depth = depth

我们将与单个线性层结合起来。然后,我们可以拆分结果以获得和组件。这是 an d f or

70        self.hidden_lin = nn.ModuleList([nn.Linear(hidden_size, 2 * hidden_size) for _ in range(depth)])

同样,我们将

73        self.input_lin = nn.Linear(input_size, 2 * hidden_size, bias=False)

x 有形状[batch_size, input_size]s 形状[batch_size, hidden_size]

75    def forward(self, x: torch.Tensor, s: torch.Tensor):

迭代

82        for d in range(self.depth):

我们计算了和的线性变换的级联

84            if d == 0:
只有@@

在为时才使用输入

86                hg = self.input_lin(x) + self.hidden_lin[d](s)
87            else:
88                hg = self.hidden_lin[d](s)

使用的前半部分hg 获得

96            h = torch.tanh(hg[:, :self.hidden_size])

使用后半部分hg 获取

103            g = torch.sigmoid(hg[:, self.hidden_size:])
104
105            s = h * g + s * (1 - g)
106
107        return s

多层循环高速公路网

110class RHN(Module):

创建一个由n_layers 循环高速公路网络图层组成的网络,每个图层的深度depth

115    def __init__(self, input_size: int, hidden_size: int, depth: int, n_layers: int):
120        super().__init__()
121        self.n_layers = n_layers
122        self.hidden_size = hidden_size

为每层创建单元。请注意,只有第一层直接获得输入。其余图层从下面的图层获取输入

125        self.cells = nn.ModuleList([RHNCell(input_size, hidden_size, depth)] +
126                                   [RHNCell(hidden_size, hidden_size, depth) for _ in range(n_layers - 1)])

x 有形状[seq_len, batch_size, input_size]state 形状[batch_size, hidden_size]

128    def forward(self, x: torch.Tensor, state: Optional[torch.Tensor] = None):
133        time_steps, batch_size = x.shape[:2]

初始化状态如果None

136        if state is None:
137            s = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
138        else:

反向堆叠状态以获取每层的状态

📝 你可以只使用张量本身,但这更容易调试

142            s = torch.unbind(state)

用于在每个时间步收集最后一层输出的数组。

145        out = []

在网络中运行每个时间步长

148        for t in range(time_steps):

第一层的输入是输入本身

150            inp = x[t]

循环穿过图层

152            for layer in range(self.n_layers):

获取图层的状态

154                s[layer] = self.cells[layer](inp, s[layer])

下一层的输入是该图层的状态

156                inp = s[layer]

收集最后一层的输出

158            out.append(s[-1])

堆叠输出和状态

161        out = torch.stack(out)
162        s = torch.stack(s)
163
164        return out, s