This is a PyTorch implementation of Recurrent Highway Networks.
```python
from typing import Optional

import torch
from torch import nn

from labml_helpers.module import Module
```
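This implements the recurrence from the paper. At time step $t$, for each depth step $0 \leq d < D$,

$$s_d^t = h_d^t \odot g_d^t + s_{d-1}^t \odot c_d^t$$

where $s_{-1}^t = s_{D-1}^{t-1}$ is the state from the previous time step. At the first depth step the input $x^t$ also contributes:

$$\begin{aligned}
h_0^t &= \tanh\big(lin_{hx}(x^t) + lin_{hs}^0(s_{D-1}^{t-1})\big) \\
g_0^t &= \sigma\big(lin_{gx}(x^t) + lin_{gs}^0(s_{D-1}^{t-1})\big) \\
c_0^t &= \sigma\big(lin_{cx}(x^t) + lin_{cs}^0(s_{D-1}^{t-1})\big)
\end{aligned}$$

and for $0 < d < D$:

$$\begin{aligned}
h_d^t &= \tanh\big(lin_{hs}^d(s_{d-1}^t)\big) \\
g_d^t &= \sigma\big(lin_{gs}^d(s_{d-1}^t)\big) \\
c_d^t &= \sigma\big(lin_{cs}^d(s_{d-1}^t)\big)
\end{aligned}$$

$\odot$ stands for element-wise multiplication.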
We have made a couple of changes to the paper's notation. To avoid confusion with time, the gate is written as $g$ instead of the paper's $t$. To avoid confusion with multiple layers, we use $d$ for depth and $D$ for total depth instead of the paper's $l$ and $L$.
We have also replaced the weight matrices and bias vectors in the equations with linear transforms ($lin$), because that is how the implementation looks.
We implement weight tying as described in the paper, $c_d^t = 1 - g_d^t$, so the carry gate needs no parameters of its own.
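To make the tied update concrete, here is a minimal sketch in plain Python, assuming lists of floats stand in for tensors; `highway_update` is a hypothetical helper, not part of the implementation. It applies $s_d^t = h_d^t \odot g_d^t + s_{d-1}^t \odot (1 - g_d^t)$ element by element.

```python
from typing import List


def highway_update(h: List[float], g: List[float], s_prev: List[float]) -> List[float]:
    """Element-wise s = h * g + s_prev * (1 - g), i.e. the carry gate is tied as c = 1 - g."""
    return [h_i * g_i + s_i * (1.0 - g_i) for h_i, g_i, s_i in zip(h, g, s_prev)]


# g = 1 replaces the state with h, g = 0 carries s_prev through unchanged,
# and g = 0.5 takes the average of the two.
s_new = highway_update(h=[0.8, 0.8, 0.8], g=[1.0, 0.0, 0.5], s_prev=[0.2, 0.2, 0.2])
print(s_new)  # [0.8, 0.2, 0.5]
```

Each gate value therefore decides, per feature, how much of the new candidate $h$ replaces the previous state.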
`input_size` is the feature length of the input and `hidden_size` is the feature length of the cell. `depth` is $D$.
```python
class RHNCell(Module):
    def __init__(self, input_size: int, hidden_size: int, depth: int):
        super().__init__()

        self.hidden_size = hidden_size
        self.depth = depth
```
We combine $lin_{hs}$ and $lin_{gs}$ into a single linear layer, then split its output to get the $lin_{hs}$ and $lin_{gs}$ components. This gives $lin_{hs}^d$ and $lin_{gs}^d$ for $0 \leq d < D$.
```python
        self.hidden_lin = nn.ModuleList([nn.Linear(hidden_size, 2 * hidden_size) for _ in range(depth)])
```
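Fusing the two transforms works because stacking the weight rows (and biases) of $lin_{hs}$ and $lin_{gs}$ gives one transform whose output is the two results concatenated. A minimal plain-Python check of that claim, using made-up weights and a hypothetical `linear` helper that mimics the `[out_features][in_features]` weight layout of `nn.Linear`:

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def linear(weight: Matrix, bias: Vector, v: Vector) -> Vector:
    """y = W v + b, with W stored as [out_features][in_features] like nn.Linear."""
    return [sum(w * vi for w, vi in zip(row, v)) + b for row, b in zip(weight, bias)]


hidden_size = 2
# Separate transforms lin_hs and lin_gs (illustrative weights only)
w_h, b_h = [[1.0, 2.0], [0.5, -1.0]], [0.1, 0.2]
w_g, b_g = [[-1.0, 0.0], [3.0, 1.0]], [0.0, -0.5]

# Fused transform: stack weight rows and biases so one call yields 2 * hidden_size outputs
w_fused, b_fused = w_h + w_g, b_h + b_g

s = [0.3, -0.7]
hg = linear(w_fused, b_fused, s)
h_part = hg[:hidden_size]  # equals lin_hs(s)
g_part = hg[hidden_size:]  # equals lin_gs(s)
print(h_part == linear(w_h, b_h, s), g_part == linear(w_g, b_g, s))  # True True
```

One fused matrix multiply replaces two smaller ones, which is usually cheaper on a GPU.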
Similarly, we combine $lin_{hx}$ and $lin_{gx}$.
```python
        self.input_lin = nn.Linear(input_size, 2 * hidden_size, bias=False)
```
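The text does not say why `input_lin` is created with `bias=False`. One plausible reason: at $d = 0$ its output is added to `hidden_lin[0](s)`, which already has a bias, and two added biases collapse into one, so a second bias would add parameters without adding expressiveness. The sketch below (plain Python, hypothetical `linear` helper, made-up weights) checks that `input_lin(x) + hidden_lin[0](s)` equals a single affine transform of the concatenated vector $[x; s]$ with one bias.

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def linear(weight: Matrix, bias: Vector, v: Vector) -> Vector:
    """y = W v + b, with W stored as [out_features][in_features] like nn.Linear."""
    return [sum(w * vi for w, vi in zip(row, v)) + b for row, b in zip(weight, bias)]


x = [1.0, -2.0]  # input_size = 2
s = [0.5]        # hidden_size = 1, so both transforms produce 2 * hidden_size = 2 outputs

w_x = [[0.25, 1.0], [-0.5, 2.0]]           # input_lin weights, no bias
w_s, b_s = [[4.0], [-1.0]], [0.125, 0.75]  # hidden_lin[0] weights and bias

# Two transforms summed, as at depth step d = 0
separate = [a + b for a, b in zip(linear(w_x, [0.0, 0.0], x), linear(w_s, b_s, s))]

# One affine transform of the concatenation [x; s] with a single bias
w_cat = [row_x + row_s for row_x, row_s in zip(w_x, w_s)]
combined = linear(w_cat, b_s, x + s)

print(separate, combined)  # [0.375, -4.25] [0.375, -4.25]
```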
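```python
    def forward(self, x: torch.Tensor, s: torch.Tensor):
        # x has shape [batch_size, input_size] and s has shape [batch_size, hidden_size]
        # Iterate over the depth steps 0 <= d < D
        for d in range(self.depth):
```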
We calculate the concatenation of the linear transforms for $h$ and $g$.
```python
            if d == 0:
```
The input is used only when $d$ is $0$.
```python
                hg = self.input_lin(x) + self.hidden_lin[d](s)
            else:
                hg = self.hidden_lin[d](s)
```
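Use the first half of `hg` to get $h_d^t$:

$$h_d^t = \begin{cases} \tanh\big(lin_{hx}(x^t) + lin_{hs}^0(s_{D-1}^{t-1})\big) & d = 0 \\ \tanh\big(lin_{hs}^d(s_{d-1}^t)\big) & 0 < d < D \end{cases}$$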
```python
            h = torch.tanh(hg[:, :self.hidden_size])
```
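Use the second half of `hg` to get $g_d^t$:

$$g_d^t = \begin{cases} \sigma\big(lin_{gx}(x^t) + lin_{gs}^0(s_{D-1}^{t-1})\big) & d = 0 \\ \sigma\big(lin_{gs}^d(s_{d-1}^t)\big) & 0 < d < D \end{cases}$$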
```python
            g = torch.sigmoid(hg[:, self.hidden_size:])

            s = h * g + s * (1 - g)

        return s
```
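For readers who want to trace the depth loop by hand, here is a minimal plain-Python sketch of one cell step for a single example. It is an illustration under stated assumptions, not the module above: the batch dimension is dropped, weights are random, and `linear`, `sigmoid`, `rhn_cell_forward` and `random_matrix` are hypothetical helpers.

```python
import math
import random
from typing import List, Tuple

Matrix = List[List[float]]
Vector = List[float]


def linear(weight: Matrix, bias: Vector, v: Vector) -> Vector:
    """y = W v + b, with W stored as [out_features][in_features] like nn.Linear."""
    return [sum(w * vi for w, vi in zip(row, v)) + b for row, b in zip(weight, bias)]


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def rhn_cell_forward(x: Vector, s: Vector, input_w: Matrix,
                     hidden_params: List[Tuple[Matrix, Vector]]) -> Vector:
    """One time step of an RHN cell for a single example; len(hidden_params) is the depth D.

    input_w plays the role of input_lin (no bias) and hidden_params[d] that of hidden_lin[d].
    Every transform yields 2 * hidden_size values: the h half first, then the g half.
    """
    hidden_size = len(s)
    no_bias = [0.0] * (2 * hidden_size)
    for d, (w, b) in enumerate(hidden_params):
        if d == 0:
            # The input only contributes at the first depth step
            hg = [a + c for a, c in zip(linear(input_w, no_bias, x), linear(w, b, s))]
        else:
            hg = linear(w, b, s)
        h = [math.tanh(z) for z in hg[:hidden_size]]
        g = [sigmoid(z) for z in hg[hidden_size:]]
        # Highway update with the tied carry gate c = 1 - g
        s = [hi * gi + si * (1.0 - gi) for hi, gi, si in zip(h, g, s)]
    return s


def random_matrix(rng: random.Random, rows: int, cols: int) -> Matrix:
    return [[rng.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
input_size, hidden_size, depth = 3, 4, 2
input_w = random_matrix(rng, 2 * hidden_size, input_size)
hidden_params = [(random_matrix(rng, 2 * hidden_size, hidden_size), [0.0] * (2 * hidden_size))
                 for _ in range(depth)]

s = rhn_cell_forward([1.0, -1.0, 0.5], [0.0] * hidden_size, input_w, hidden_params)
print(s)
```

Because $h_d^t \in (-1, 1)$ and $g_d^t \in (0, 1)$, each update is a convex combination of $h_d^t$ and the previous state, so a state that starts at zero stays strictly inside $(-1, 1)$.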
Create a network of `n_layers` recurrent highway network layers, each with depth `depth` ($D$).
```python
    def __init__(self, input_size: int, hidden_size: int, depth: int, n_layers: int):
        super().__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
```
Create a cell for each layer. Only the first layer receives the input directly; the rest of the layers take their input from the layer below.
```python
        self.cells = nn.ModuleList([RHNCell(input_size, hidden_size, depth)] +
                                   [RHNCell(hidden_size, hidden_size, depth) for _ in range(n_layers - 1)])
```
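Since only the first cell reads `input_size` features and every later cell reads `hidden_size` features, the parameter count of the stack follows directly from these layer shapes. A small worked sketch, where `rhn_cell_param_count` and `rhn_param_count` are hypothetical helpers: each cell has an `input_lin` weight of `input_size * 2 * hidden_size` (no bias) and $D$ `hidden_lin` layers, each with a `2 * hidden_size` by `hidden_size` weight plus a bias of length `2 * hidden_size`.

```python
def rhn_cell_param_count(input_size: int, hidden_size: int, depth: int) -> int:
    """Parameters of one cell: input_lin (no bias) plus depth copies of hidden_lin."""
    input_params = input_size * 2 * hidden_size
    hidden_params = depth * (hidden_size * 2 * hidden_size + 2 * hidden_size)
    return input_params + hidden_params


def rhn_param_count(input_size: int, hidden_size: int, depth: int, n_layers: int) -> int:
    """The first cell reads input_size features; the other n_layers - 1 cells read hidden_size."""
    first = rhn_cell_param_count(input_size, hidden_size, depth)
    rest = (n_layers - 1) * rhn_cell_param_count(hidden_size, hidden_size, depth)
    return first + rest


print(rhn_param_count(input_size=10, hidden_size=20, depth=3, n_layers=2))  # 6240
```

For a constructed module, this count should agree with `sum(p.numel() for p in model.parameters())`.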
`x` has shape `[seq_len, batch_size, input_size]` and `state`, if given, has shape `[n_layers, batch_size, hidden_size]`.
```python
    def forward(self, x: torch.Tensor, state: Optional[torch.Tensor] = None):
        time_steps, batch_size = x.shape[:2]
```
Initialize the state to zeros if `state` is `None`.
```python
        if state is None:
            s = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
        else:
```
Unstack the state to get the state of each layer.

📝 You could work with the tensor directly, but a list of per-layer tensors is easier to debug.
```python
            # torch.unbind returns a tuple; convert it to a list so s[layer] can be reassigned
            s = list(torch.unbind(state))
```
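`torch.unbind` returns a tuple, and tuples do not support item assignment, so the loop's `s[layer] = ...` needs the per-layer states in a list. The plain-Python sketch below uses a tuple of lists as a stand-in for the result of `torch.unbind(state)` with two layers, an assumption made so the example runs without PyTorch.

```python
# Stand-in for torch.unbind(state) with n_layers = 2: a tuple of per-layer states
per_layer = ([0.0, 0.0], [1.0, 1.0])

assignment_failed = False
try:
    per_layer[0] = [0.5, 0.5]  # what s[layer] = ... would attempt on a tuple
except TypeError:
    assignment_failed = True

# Converting to a list makes each layer's entry replaceable
per_layer = list(per_layer)
per_layer[0] = [0.5, 0.5]
print(assignment_failed, per_layer)  # True [[0.5, 0.5], [1.0, 1.0]]
```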
A list to collect the outputs of the final layer at each time step.
```python
        out = []
```
Run through the network for each time step
```python
        for t in range(time_steps):
```
Input to the first layer is the input itself
```python
            inp = x[t]
```
Loop through the layers
```python
            for layer in range(self.n_layers):
```
Update the state of the layer
```python
                s[layer] = self.cells[layer](inp, s[layer])
```
Input to the next layer is the state of this layer
```python
                inp = s[layer]
            # Collect the output of the final layer
            out.append(s[-1])
```
Stack the outputs and states
```python
        out = torch.stack(out)
        s = torch.stack(s)

        return out, s
```
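The time-step and layer loops can be traced without PyTorch. In this minimal sketch, the hypothetical `run_stack` mirrors the loop structure of `forward`, and `toy_cell` is a hypothetical stand-in for `RHNCell` that just averages its input and previous state, so the numbers are easy to follow by hand.

```python
from typing import Callable, List, Optional, Tuple

Vector = List[float]
Cell = Callable[[Vector, Vector], Vector]


def run_stack(cells: List[Cell], xs: List[Vector], hidden_size: int,
              state: Optional[List[Vector]] = None) -> Tuple[List[Vector], List[Vector]]:
    """Run each cell layer by layer at every time step, like the multilayer forward pass."""
    s = [[0.0] * hidden_size for _ in cells] if state is None else list(state)
    out = []
    for x_t in xs:                 # loop over time steps
        inp = x_t                  # the first layer sees the raw input
        for layer, cell in enumerate(cells):
            s[layer] = cell(inp, s[layer])
            inp = s[layer]         # the next layer sees this layer's new state
        out.append(s[-1])          # keep only the top layer's state as the output
    return out, s


def toy_cell(inp: Vector, s: Vector) -> Vector:
    """Hypothetical stand-in for RHNCell: average the input and the previous state."""
    return [0.5 * (a + b) for a, b in zip(inp, s)]


out, final = run_stack([toy_cell, toy_cell], xs=[[1.0], [1.0]], hidden_size=1)
print(out, final)  # [[0.25], [0.5]] [[0.75], [0.5]]
```

With two layers and two time steps, `out` holds only the top layer's state at each step, while `final` holds the last state of every layer, matching the leading `seq_len` and `n_layers` dimensions of the two stacked tensors.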