Recurrent Highway Networks

This is a PyTorch implementation of Recurrent Highway Networks (Zilly et al., 2016, https://arxiv.org/abs/1607.03474).

from typing import Optional

import torch
from torch import nn

Recurrent Highway Network Cell

This implements equations (6) to (9) of the paper.

$$s_d^t = h_d^t \odot g_d^t + s_{d-1}^t \odot c_d^t$$

where

$$h_0^t = \tanh(lin_{hx}(x) + lin_{hs}^0(s_D^{t-1}))$$
$$g_0^t = \sigma(lin_{gx}(x) + lin_{gs}^0(s_D^{t-1}))$$

and for $0 < d < D$

$$h_d^t = \tanh(lin_{hs}^d(s_{d-1}^t))$$
$$g_d^t = \sigma(lin_{gs}^d(s_{d-1}^t))$$

At depth $0$ the incoming state is $s_D^{t-1}$, the final-depth state from the previous time step. $\odot$ stands for element-wise multiplication.

Here we have made a couple of changes to the notation from the paper. To avoid confusion with time, the gate is represented with $g$, which was $t$ in the paper. To avoid confusion with multiple layers, we use $d$ for depth and $D$ for total depth instead of $l$ and $L$ from the paper.

We have also replaced the weight matrices and bias vectors from the equations with linear transforms, because that's how the implementation is going to look.

We implement weight tying, as described in the paper: $c_d^t = 1 - g_d^t$.
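To make the update concrete, here is a minimal sketch of a single depth step written with plain tensor operations; the sizes and random inputs are illustrative only, standing in for the outputs of the linear transforms:

```python
import torch

batch_size, hidden_size = 2, 4

# Incoming state s_{d-1}^t (at depth 0 this is s_D^{t-1})
s = torch.randn(batch_size, hidden_size)

# Stand-ins for the outputs of the linear transforms at depth d
h = torch.tanh(torch.randn(batch_size, hidden_size))     # h_d^t
g = torch.sigmoid(torch.randn(batch_size, hidden_size))  # g_d^t

# Weight-tied update: s_d^t = h * g + s * (1 - g), i.e. c_d^t = 1 - g_d^t
s_next = h * g + s * (1 - g)
assert s_next.shape == (batch_size, hidden_size)
```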

class RHNCell(nn.Module):

input_size is the feature length of the input and hidden_size is the feature length of the cell. depth is $D$.

    def __init__(self, input_size: int, hidden_size: int, depth: int):
        super().__init__()

        self.hidden_size = hidden_size
        self.depth = depth

We combine $lin_{hs}$ and $lin_{gs}$ into a single linear layer. We can then split the results to get the $lin_{hs}$ and $lin_{gs}$ components. This is the $lin_{hs}^d$ and $lin_{gs}^d$ for $0 \le d < D$.

        self.hidden_lin = nn.ModuleList([nn.Linear(hidden_size, 2 * hidden_size) for _ in range(depth)])

Similarly, we combine $lin_{hx}$ and $lin_{gx}$.

        self.input_lin = nn.Linear(input_size, 2 * hidden_size, bias=False)
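The combine-and-split trick is equivalent to keeping two separate hidden_size to hidden_size linears whose weights are the top and bottom halves of the combined weight matrix. A small sanity check, with illustrative sizes:

```python
import torch
from torch import nn

hidden_size = 4
combined = nn.Linear(hidden_size, 2 * hidden_size)

s = torch.randn(3, hidden_size)
hg = combined(s)

# The first half of the output...
h_part = hg[:, :hidden_size]
# ...equals an affine map built from the top half of the combined weights
top = s @ combined.weight[:hidden_size].T + combined.bias[:hidden_size]
assert torch.allclose(h_part, top, atol=1e-6)
```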

x has shape [batch_size, input_size] and s has shape [batch_size, hidden_size].

    def forward(self, x: torch.Tensor, s: torch.Tensor):

Iterate $0 \dots D - 1$

        for d in range(self.depth):

We calculate the concatenation of linear transforms for $h$ and $g$

            if d == 0:

The input is used only when $d$ is $0$.

                hg = self.input_lin(x) + self.hidden_lin[d](s)
            else:
                hg = self.hidden_lin[d](s)

Use the first half of hg to get $h_d^t$

            h = torch.tanh(hg[:, :self.hidden_size])

Use the second half of hg to get $g_d^t$

            g = torch.sigmoid(hg[:, self.hidden_size:])
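This is the weight-tied update, $s_d^t = h_d^t \odot g_d^t + s_{d-1}^t \odot (1 - g_d^t)$.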
            s = h * g + s * (1 - g)

        return s
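As a quick usage sketch (the sizes below are arbitrary, and RHNCell is assumed to be defined as above), a single cell step maps a [batch_size, input_size] input and a [batch_size, hidden_size] state to the next state:

```python
batch_size, input_size, hidden_size, depth = 2, 3, 5, 4
cell = RHNCell(input_size, hidden_size, depth)

x = torch.randn(batch_size, input_size)
s = torch.zeros(batch_size, hidden_size)

s = cell(x, s)
assert s.shape == (batch_size, hidden_size)
```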

Multilayer Recurrent Highway Network

class RHN(nn.Module):

Create a network of n_layers recurrent highway network layers, each with depth depth, $D$.

    def __init__(self, input_size: int, hidden_size: int, depth: int, n_layers: int):
        super().__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size

Create cells for each layer. Note that only the first layer gets the input directly; the rest of the layers get their input from the layer below.

        self.cells = nn.ModuleList([RHNCell(input_size, hidden_size, depth)] +
                                   [RHNCell(hidden_size, hidden_size, depth) for _ in range(n_layers - 1)])

x has shape [seq_len, batch_size, input_size] and state has shape [n_layers, batch_size, hidden_size].

    def forward(self, x: torch.Tensor, state: Optional[torch.Tensor] = None):
        time_steps, batch_size = x.shape[:2]

Initialize the state if None

        if state is None:
            s = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
        else:

Reverse stack the state to get the state of each layer. We keep them in a list because the per-layer states are reassigned in the loop below.

📝 You can just work with the tensor itself but this is easier to debug

            s = list(torch.unbind(state))

Array to collect the outputs of the final layer at each time step.

        out = []

Run through the network for each time step

        for t in range(time_steps):

Input to the first layer is the input itself

            inp = x[t]

Loop through the layers

            for layer in range(self.n_layers):

Get the new state of the layer

                s[layer] = self.cells[layer](inp, s[layer])

Input to the next layer is the state of this layer

                inp = s[layer]

Collect the output of the final layer

            out.append(s[-1])

Stack the outputs and states

        out = torch.stack(out)
        s = torch.stack(s)

        return out, s
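Finally, a usage sketch for the full network, with illustrative sizes. out stacks the top layer's state at every time step, and the returned state can seed the next chunk of a long sequence:

```python
seq_len, batch_size, input_size, hidden_size = 7, 2, 3, 5
rhn = RHN(input_size, hidden_size, depth=4, n_layers=2)

x = torch.randn(seq_len, batch_size, input_size)
out, state = rhn(x)

assert out.shape == (seq_len, batch_size, hidden_size)
assert state.shape == (2, batch_size, hidden_size)

# Continue from the returned state on the next chunk
out2, state2 = rhn(torch.randn(seq_len, batch_size, input_size), state)
```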