from typing import Optional

import torch
from torch import nn
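This implements the RHN cell equations from the paper:

$$s_d^t = h_d^t \odot g_d^t + s_{d-1}^t \odot c_d^t$$

where $s_{-1}^t = s_{D-1}^{t-1}$ is the output of the cell at the previous time step (the cell returns $s_{D-1}^t$),

$$
\begin{aligned}
h_0^t &= \tanh\big(lin_{hx}(x^t) + lin_{hs}^0(s_{-1}^t)\big) \\
g_0^t &= \sigma\big(lin_{gx}(x^t) + lin_{gs}^0(s_{-1}^t)\big)
\end{aligned}
$$

and for $0 < d < D$

$$
\begin{aligned}
h_d^t &= \tanh\big(lin_{hs}^d(s_{d-1}^t)\big) \\
g_d^t &= \sigma\big(lin_{gs}^d(s_{d-1}^t)\big)
\end{aligned}
$$

$\odot$ stands for element-wise multiplication.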
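As a worked expansion of the cell equations for a small case, with $D = 2$ one time step applies two highway steps. The general carry gate $c_d^t$ is kept here:

$$
\begin{aligned}
s_0^t &= \tanh\big(lin_{hx}(x^t) + lin_{hs}^0(s_1^{t-1})\big) \odot g_0^t + s_1^{t-1} \odot c_0^t \\
s_1^t &= \tanh\big(lin_{hs}^1(s_0^t)\big) \odot g_1^t + s_0^t \odot c_1^t
\end{aligned}
$$

Here $g_0^t$ depends on both $x^t$ and $s_1^{t-1}$, while $g_1^t$ depends only on $s_0^t$. The input enters only at $d = 0$, and $s_1^t$ is what the cell passes on to time step $t + 1$.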
Here we have made a couple of changes to the notation from the paper. To avoid confusion with time, the gate is represented with $g$, which was $t$ in the paper. To avoid confusion with multiple layers, we use $d$ for depth and $D$ for total depth instead of $\ell$ and $L$ from the paper.
We have also replaced the weight matrices and bias vectors from the equations with linear transforms, because that is how they appear in the implementation.
We implement weight tying, as described in the paper, by coupling the carry gate to the transform gate: $c_d^t = 1 - g_d^t$.
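The sketch below shows what the coupled carry gate does over one time step. It is a minimal scalar version in plain Python, with `hidden_size` fixed at 1.

The function name `scalar_rhn_cell` and its weight layout are hypothetical and chosen only for illustration:

- Each depth step $d$ gets scalar weights and biases for $lin_{hs}^d$ and $lin_{gs}^d$.
- The input weights for $lin_{hx}$ and $lin_{gx}$ have no bias, mirroring the module below.

```python
import math
from typing import List, Tuple


def sigmoid(z: float) -> float:
    """Logistic function, 1 / (1 + exp(-z))."""
    return 1.0 / (1.0 + math.exp(-z))


def scalar_rhn_cell(x: float, s: float,
                    input_w: Tuple[float, float],
                    depth_w: List[Tuple[float, float, float, float]]) -> float:
    """One time step of a hypothetical scalar RHN cell.

    input_w: (u_h, u_g), weights of lin_hx and lin_gx (no bias).
    depth_w: one (a_h, b_h, a_g, b_g) per depth step d, so that
             lin_hs^d(s) = a_h * s + b_h and lin_gs^d(s) = a_g * s + b_g.
    """
    u_h, u_g = input_w
    for d, (a_h, b_h, a_g, b_g) in enumerate(depth_w):
        pre_h = a_h * s + b_h
        pre_g = a_g * s + b_g
        if d == 0:
            # The input only enters at the first depth step
            pre_h += u_h * x
            pre_g += u_g * x
        h = math.tanh(pre_h)
        g = sigmoid(pre_g)
        # Coupled carry gate c = 1 - g: s moves from its old value toward h by a fraction g
        s = h * g + s * (1.0 - g)
    return s


# A gate bias of -20 nearly closes every gate (g is about 2e-9), so the state is carried through
closed = scalar_rhn_cell(0.5, 0.3, (1.0, 0.0), [(1.0, 0.0, 0.0, -20.0)] * 3)
# A gate bias of +20 nearly opens every gate, so each step replaces s with h
opened = scalar_rhn_cell(0.5, 0.3, (1.0, 0.0), [(1.0, 0.0, 0.0, 20.0)] * 3)
print(closed, opened)
```

With $g$ near 0 a depth step is close to an identity on $s$. With $g$ near 1 it behaves like a plain $\tanh$ layer, so the gate interpolates between carrying and transforming the state.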
class RHNCell(nn.Module):
`input_size` is the feature length of the input and `hidden_size` is the feature length of the cell. `depth` is $D$.
    def __init__(self, input_size: int, hidden_size: int, depth: int):
        super().__init__()

        self.hidden_size = hidden_size
        self.depth = depth
We combine $lin_{hs}^d$ and $lin_{gs}^d$ into a single linear layer, and then split its output to get the $lin_{hs}^d$ and $lin_{gs}^d$ components. There is one such layer for each $0 \leq d < D$.
        self.hidden_lin = nn.ModuleList([nn.Linear(hidden_size, 2 * hidden_size) for _ in range(depth)])
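This quick check shows why the combined layer is equivalent to two separate ones. `nn.Linear` stores its weight as `[out_features, in_features]`, so two row ranges of the combined layer do different jobs:

- The first `hidden_size` rows and bias entries form $lin_{hs}^d$.
- The remaining rows and bias entries form $lin_{gs}^d$.

The sketch copies those halves into two separate layers and confirms that the outputs match. The sizes and the two separate layers are a hypothetical setup for illustration. The combined version needs one matrix multiply instead of two.

```python
import torch
from torch import nn

torch.manual_seed(0)
hidden_size = 4

# One combined layer producing [lin_hs(s), lin_gs(s)] side by side
combined = nn.Linear(hidden_size, 2 * hidden_size)

# Two separate layers holding the corresponding halves of the combined parameters
lin_hs = nn.Linear(hidden_size, hidden_size)
lin_gs = nn.Linear(hidden_size, hidden_size)
with torch.no_grad():
    lin_hs.weight.copy_(combined.weight[:hidden_size])
    lin_hs.bias.copy_(combined.bias[:hidden_size])
    lin_gs.weight.copy_(combined.weight[hidden_size:])
    lin_gs.bias.copy_(combined.bias[hidden_size:])

s = torch.randn(3, hidden_size)  # [batch_size, hidden_size]
hg = combined(s)                 # [batch_size, 2 * hidden_size]

# Splitting the combined output recovers both transforms
same_h = torch.allclose(hg[:, :hidden_size], lin_hs(s), atol=1e-6)
same_g = torch.allclose(hg[:, hidden_size:], lin_gs(s), atol=1e-6)
print(same_h, same_g)
```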
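Similarly, we combine $lin_{hx}$ and $lin_{gx}$. This layer needs no bias, since $lin_{hs}^0$ and $lin_{gs}^0$ already add one at $d = 0$.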
        self.input_lin = nn.Linear(input_size, 2 * hidden_size, bias=False)
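With the combined layers, the parameter count of a cell follows directly from the shapes above. The helper `rhn_cell_param_count` is hypothetical, a small sketch of that arithmetic:

- `input_lin` has a `[2 * hidden_size, input_size]` weight and no bias.
- Each of the `depth` hidden layers has a `[2 * hidden_size, hidden_size]` weight plus a bias of length `2 * hidden_size`.

```python
def rhn_cell_param_count(input_size: int, hidden_size: int, depth: int) -> int:
    """Trainable parameters of one RHNCell as constructed above (hypothetical helper)."""
    # input_lin: weight [2 * hidden_size, input_size], bias=False
    input_params = 2 * hidden_size * input_size
    # hidden_lin[d]: weight [2 * hidden_size, hidden_size] plus bias [2 * hidden_size]
    per_depth = 2 * hidden_size * hidden_size + 2 * hidden_size
    return input_params + depth * per_depth


# 2*20*10 + 3 * (2*20*20 + 2*20) = 400 + 3 * 840
print(rhn_cell_param_count(input_size=10, hidden_size=20, depth=3))  # 2920
```

Increasing `depth` adds $2H^2 + 2H$ parameters per step for $H$ = `hidden_size`. It adds no input weights, since the input only enters at $d = 0$.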
`x` has shape `[batch_size, input_size]` and `s` has shape `[batch_size, hidden_size]`.
    def forward(self, x: torch.Tensor, s: torch.Tensor):
Iterate over $0 \leq d < D$
        for d in range(self.depth):
We calculate the concatenation of the linear transforms for $h$ and $g$
            if d == 0:
The input is used only when $d$ is $0$.
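                hg = self.input_lin(x) + self.hidden_lin[d](s)
            else:
                hg = self.hidden_lin[d](s)
            # First half of hg: candidate h = tanh(...)
            h = torch.tanh(hg[:, :self.hidden_size])
            # Second half of hg: transform gate g = sigmoid(...)
            g = torch.sigmoid(hg[:, self.hidden_size:])

            # Highway update with the coupled carry gate c = 1 - g
            s = h * g + s * (1 - g)

        return s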
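The split and the highway update in the loop body could also be written with `Tensor.chunk` and `torch.lerp`. `torch.lerp(start, end, weight)` computes `start + weight * (end - start)`. With `start = s`, `end = h` and `weight = g`, that equals `h * g + s * (1 - g)`.

The sketch below compares the two formulations on random tensors with arbitrary sizes. It is only an equivalent rewrite for illustration, not a change to the cell above.

```python
import torch

torch.manual_seed(0)
batch_size, hidden_size = 2, 3
s = torch.randn(batch_size, hidden_size)           # state from the previous depth step
hg = torch.randn(batch_size, 2 * hidden_size)      # concatenated pre-activations

# As in the cell: slice the two halves
h = torch.tanh(hg[:, :hidden_size])
g = torch.sigmoid(hg[:, hidden_size:])
update = h * g + s * (1 - g)

# Alternative: chunk into two halves along the feature dimension, then lerp
h_pre, g_pre = hg.chunk(2, dim=-1)
alt = torch.lerp(s, torch.tanh(h_pre), torch.sigmoid(g_pre))

print(torch.allclose(update, alt, atol=1e-6))  # True
```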
class RHN(nn.Module):
Create a network of `n_layers` recurrent highway network layers, each with depth `depth`, $D$.
    def __init__(self, input_size: int, hidden_size: int, depth: int, n_layers: int):
        super().__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
Create cells for each layer. Note that only the first layer gets the input directly. The rest of the layers get their input from the layer below.
        self.cells = nn.ModuleList([RHNCell(input_size, hidden_size, depth)] +
                                   [RHNCell(hidden_size, hidden_size, depth) for _ in range(n_layers - 1)])
`x` has shape `[seq_len, batch_size, input_size]` and `state` has shape `[n_layers, batch_size, hidden_size]`.
    def forward(self, x: torch.Tensor, state: Optional[torch.Tensor] = None):
        time_steps, batch_size = x.shape[:2]
Initialize the state with zeros if it is `None`
        if state is None:
            s = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
        else:
Unstack the state to get the state of each layer
📝 You could work with the tensor directly, but a list of per-layer tensors is easier to debug
            # A list, since the tuple returned by unbind cannot be updated in place
            s = list(torch.unbind(state))
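Two details of this state handling can be checked directly:

- `x.new_zeros(...)` allocates a zero tensor with the same dtype and device as `x`, so the initial state follows the input.
- `torch.unbind` returns a tuple of per-layer slices. Tuples do not support item assignment, so the per-layer states need to be in a list before the loop below can write `s[layer] = ...`.

The sizes in this sketch are arbitrary.

```python
import torch

n_layers, batch_size, hidden_size = 3, 2, 4
x = torch.randn(5, batch_size, 7, dtype=torch.float64)  # [seq_len, batch_size, input_size]

# new_zeros inherits dtype and device from x
init = x.new_zeros(batch_size, hidden_size)
print(init.dtype)  # torch.float64

state = torch.randn(n_layers, batch_size, hidden_size)  # [n_layers, batch_size, hidden_size]
layers = torch.unbind(state)  # tuple of n_layers tensors, each [batch_size, hidden_size]

try:
    layers[0] = torch.zeros(batch_size, hidden_size)
    tuple_assignable = True
except TypeError:  # 'tuple' object does not support item assignment
    tuple_assignable = False

s = list(layers)  # a list can be updated layer by layer
s[0] = torch.zeros(batch_size, hidden_size)
restacked = torch.stack(s)  # back to [n_layers, batch_size, hidden_size]
```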
List to collect the outputs of the final layer at each time step.
        out = []
Run through the network for each time step
        for t in range(time_steps):
Input to the first layer is the input itself
            inp = x[t]
Loop through the layers
            for layer in range(self.n_layers):
Compute the new state of the layer
                s[layer] = self.cells[layer](inp, s[layer])
Input to the next layer is the state of this layer
                inp = s[layer]
Collect the output of the final layer
            out.append(s[-1])
Stack the outputs and states
        out = torch.stack(out)
        s = torch.stack(s)

        return out, s
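This is a minimal end-to-end usage sketch. To keep it self-contained, it restates the two classes above in condensed form, with the per-layer state unpacked into a list. It then checks two things:

- The shape contract: `out` is `[seq_len, batch_size, hidden_size]` and `state` is `[n_layers, batch_size, hidden_size]`.
- Feeding a sequence in two chunks, and passing the returned `state` back in, gives the same result as a single pass.

The sizes are arbitrary assumptions.

```python
from typing import Optional

import torch
from torch import nn


class RHNCell(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, depth: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.depth = depth
        # One combined [h, g] transform per depth step, plus one for the input
        self.hidden_lin = nn.ModuleList([nn.Linear(hidden_size, 2 * hidden_size) for _ in range(depth)])
        self.input_lin = nn.Linear(input_size, 2 * hidden_size, bias=False)

    def forward(self, x: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
        for d in range(self.depth):
            if d == 0:
                hg = self.input_lin(x) + self.hidden_lin[d](s)
            else:
                hg = self.hidden_lin[d](s)
            h = torch.tanh(hg[:, :self.hidden_size])
            g = torch.sigmoid(hg[:, self.hidden_size:])
            s = h * g + s * (1 - g)
        return s


class RHN(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, depth: int, n_layers: int):
        super().__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.cells = nn.ModuleList([RHNCell(input_size, hidden_size, depth)] +
                                   [RHNCell(hidden_size, hidden_size, depth) for _ in range(n_layers - 1)])

    def forward(self, x: torch.Tensor, state: Optional[torch.Tensor] = None):
        time_steps, batch_size = x.shape[:2]
        if state is None:
            s = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
        else:
            s = list(torch.unbind(state))
        out = []
        for t in range(time_steps):
            inp = x[t]
            for layer in range(self.n_layers):
                s[layer] = self.cells[layer](inp, s[layer])
                inp = s[layer]
            out.append(s[-1])
        return torch.stack(out), torch.stack(s)


torch.manual_seed(0)
model = RHN(input_size=5, hidden_size=8, depth=3, n_layers=2)
x = torch.randn(6, 4, 5)  # [seq_len, batch_size, input_size]

out, state = model(x)
print(out.shape, state.shape)  # torch.Size([6, 4, 8]) torch.Size([2, 4, 8])

# Run the same sequence in two chunks, carrying the state across the boundary
out_a, state_a = model(x[:3])
out_b, state_b = model(x[3:], state_a)
print(torch.allclose(torch.cat([out_a, out_b]), out, atol=1e-6))  # True
```

Passing the returned `state` back in is what allows a long sequence to be processed chunk by chunk. When training this way (truncated backpropagation through time), the state is typically detached between chunks with `state.detach()`.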