11from typing import Optional
12
13import torch
14from torch import nn
15
16from labml_helpers.module import Module
这实现了方程。
在哪里
还有为了
代表逐元素乘法。
在这里,我们对论文中的符号进行了一些更改。为了避免与时间混淆,gate 用在报纸上表示。为了避免与多层混淆,我们使用深度和总深度来代替纸张和来自纸张。
我们还用线性变换取代了方程中的权重矩阵和偏置向量,因为这就是实现的样子。
我们实施重量捆绑,如纸中所述。
19class RHNCell(Module):
input_size
是输入的要素长度,hidden_size
是像元的要素长度。depth
是。
57 def __init__(self, input_size: int, hidden_size: int, depth: int):
63 super().__init__()
64
65 self.hidden_size = hidden_size
66 self.depth = depth
我们将和与单个线性层结合起来。然后,我们可以拆分结果以获得和组件。这是 an d f or。
70 self.hidden_lin = nn.ModuleList([nn.Linear(hidden_size, 2 * hidden_size) for _ in range(depth)])
同样,我们将和。
73 self.input_lin = nn.Linear(input_size, 2 * hidden_size, bias=False)
x
有形状[batch_size, input_size]
和s
形状[batch_size, hidden_size]
。
75 def forward(self, x: torch.Tensor, s: torch.Tensor):
迭代
82 for d in range(self.depth):
我们计算了和的线性变换的级联
84 if d == 0:
在为时才使用输入。
86 hg = self.input_lin(x) + self.hidden_lin[d](s)
87 else:
88 hg = self.hidden_lin[d](s)
96 h = torch.tanh(hg[:, :self.hidden_size])
103 g = torch.sigmoid(hg[:, self.hidden_size:])
104
105 s = h * g + s * (1 - g)
106
107 return s
110class RHN(Module):
创建一个由n_layers
循环高速公路网络图层组成的网络,每个图层的深度depth
为。
115 def __init__(self, input_size: int, hidden_size: int, depth: int, n_layers: int):
120 super().__init__()
121 self.n_layers = n_layers
122 self.hidden_size = hidden_size
为每层创建单元。请注意,只有第一层直接获得输入。其余图层从下面的图层获取输入
125 self.cells = nn.ModuleList([RHNCell(input_size, hidden_size, depth)] +
126 [RHNCell(hidden_size, hidden_size, depth) for _ in range(n_layers - 1)])
x
有形状[seq_len, batch_size, input_size]
和state
形状[batch_size, hidden_size]
。
128 def forward(self, x: torch.Tensor, state: Optional[torch.Tensor] = None):
133 time_steps, batch_size = x.shape[:2]
初始化状态如果None
136 if state is None:
137 s = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
138 else:
142 s = torch.unbind(state)
用于在每个时间步收集最后一层输出的数组。
145 out = []
在网络中运行每个时间步长
148 for t in range(time_steps):
第一层的输入是输入本身
150 inp = x[t]
循环穿过图层
152 for layer in range(self.n_layers):
获取图层的状态
154 s[layer] = self.cells[layer](inp, s[layer])
下一层的输入是该图层的状态
156 inp = s[layer]
收集最后一层的输出
158 out.append(s[-1])
堆叠输出和状态
161 out = torch.stack(out)
162 s = torch.stack(s)
163
164 return out, s