# Recurrent Highway Networks

from typing import Optional

import torch
from torch import nn

from labml_helpers.module import Module

## Recurrent Highway Network cell

class RHNCell(Module):
    """
    ## Recurrent Highway Network cell

    `input_size` is the feature length of the input and `hidden_size`
    is the feature length of the cell state; `depth` is the number of
    recurrence micro-steps applied per call.
    """

    def __init__(self, input_size: int, hidden_size: int, depth: int):
        super().__init__()

        self.hidden_size = hidden_size
        self.depth = depth

        # One linear map per micro-step; each emits the candidate and the
        # gate together, hence the 2 * hidden_size output width.
        self.hidden_lin = nn.ModuleList([nn.Linear(hidden_size, 2 * hidden_size) for _ in range(depth)])

        # The input only contributes at the first micro-step; its bias is
        # dropped because hidden_lin[0] already provides one.
        self.input_lin = nn.Linear(input_size, 2 * hidden_size, bias=False)

    def forward(self, x: torch.Tensor, s: torch.Tensor):
        """
        `x` has shape `[batch_size, input_size]` and
        `s` has shape `[batch_size, hidden_size]`.
        """
        for step in range(self.depth):
            # Transform the current state; mix in the input only on the
            # first micro-step.
            hg = self.hidden_lin[step](s)
            if step == 0:
                hg = hg + self.input_lin(x)

            # Split into the candidate activation and the gate pre-activations.
            h_pre, g_pre = hg.chunk(2, dim=-1)
            h = torch.tanh(h_pre)
            g = torch.sigmoid(g_pre)

            # Highway update: the gate interpolates between the candidate
            # and the carried-over state.
            s = h * g + s * (1 - g)

        return s

## Multilayer Recurrent Highway Network

class RHN(Module):
    """
    ## Multilayer Recurrent Highway Network

    Stacks `n_layers` RHN cells; each layer feeds its state to the layer
    above at every time step.
    """

    def __init__(self, input_size: int, hidden_size: int, depth: int, n_layers: int):
        super().__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size

        # First cell consumes the raw input; the rest consume the
        # hidden-size output of the layer below.
        self.cells = nn.ModuleList([RHNCell(input_size, hidden_size, depth)] +
                                   [RHNCell(hidden_size, hidden_size, depth) for _ in range(n_layers - 1)])

    def forward(self, x: torch.Tensor, state: Optional[torch.Tensor] = None):
        """
        `x` has shape `[seq_len, batch_size, input_size]` and
        `state`, if given, has shape `[n_layers, batch_size, hidden_size]`.

        Returns the last layer's output at every time step, shape
        `[seq_len, batch_size, hidden_size]`, and the stacked final state.
        """
        time_steps, batch_size = x.shape[:2]

        # Start from a zero state when none is given.
        if state is None:
            s = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
        else:
            # 📝 You could use the tensor itself, but a list of per-layer
            # states is easier to debug.
            # Fix: `torch.unbind` returns an immutable tuple; convert to a
            # list so the per-layer assignment below doesn't raise TypeError.
            s = list(torch.unbind(state))

        # Last layer's output at each time step
        out = []

        # Iterate over the sequence
        for t in range(time_steps):
            # Input to the first layer is the current input
            inp = x[t]
            # Pass through each layer, feeding the updated state upward
            for layer in range(self.n_layers):
                s[layer] = self.cells[layer](inp, s[layer])
                # This layer's state is the next layer's input
                inp = s[layer]
            # Collect the top layer's state as the output at this step
            out.append(s[-1])

        # Stack the outputs and the final per-layer states into tensors
        out = torch.stack(out)
        s = torch.stack(s)

        return out, s