```
from typing import Optional

import torch
from torch import nn

from labml_helpers.module import Module
```

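This implements equations (6) to (9).

$$s_d^t = h_d^t \odot g_d^t + s_{d-1}^t \odot c_d^t$$

where, for $d = 0$,

$$
\begin{aligned}
h_0^t &= \tanh\left(lin_{hx}(x^t) + lin_{hs}^0(s_{-1}^t)\right) \\
g_0^t &= \sigma\left(lin_{gx}(x^t) + lin_{gs}^0(s_{-1}^t)\right) \\
c_0^t &= \sigma\left(lin_{cx}(x^t) + lin_{cs}^0(s_{-1}^t)\right)
\end{aligned}
$$

and for $0 < d < D$

$$
\begin{aligned}
h_d^t &= \tanh\left(lin_{hs}^d(s_{d-1}^t)\right) \\
g_d^t &= \sigma\left(lin_{gs}^d(s_{d-1}^t)\right) \\
c_d^t &= \sigma\left(lin_{cs}^d(s_{d-1}^t)\right)
\end{aligned}
$$

Here $s_{-1}^t = s_{D-1}^{t-1}$ is the output of the previous time step, and $\odot$ stands for element-wise multiplication.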

Here we have made a couple of changes to the notation from the paper. To avoid confusion with time, the gate is represented by $g$, which was $t$ in the paper. To avoid confusion with multiple layers, we use $d$ for depth and $D$ for total depth instead of $l$ and $L$ from the paper.

We have also replaced the weight matrices and bias vectors in the equations with linear transforms, because that is how the implementation is structured.

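We implement weight tying of the carry gate, as described in the paper: $c_d = 1 - g_d$. This removes $lin_{cx}$ and $lin_{cs}$ entirely, so only $h_d$ and $g_d$ need to be computed.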
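With the tied carry gate, each micro-step is an element-wise interpolation between the candidate $h_d$ and the previous state $s_{d-1}$: $g_d = 1$ keeps only $h_d$, and $g_d = 0$ passes $s_{d-1}$ through unchanged. A small sketch with hand-picked tensors (the values are arbitrary and only for illustration):

```python
import torch

# Hand-picked candidate h_d, previous state s_{d-1}, and gate g_d (arbitrary values)
h = torch.tensor([0.5, 0.5, 0.5])
s_prev = torch.tensor([2.0, 2.0, 2.0])
g = torch.tensor([1.0, 0.25, 0.0])

# Highway update with the tied carry gate c_d = 1 - g_d
s = h * g + s_prev * (1 - g)
print(s)  # tensor([0.5000, 1.6250, 2.0000])
```

The middle entry is $0.5 \cdot 0.25 + 2 \cdot 0.75 = 1.625$, a mix of the two inputs weighted by the gate.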

`class RHNCell(Module):`

`input_size` is the feature length of the input, `hidden_size` is the feature length of the cell, and `depth` is $D$.

`def __init__(self, input_size: int, hidden_size: int, depth: int):`

```
super().__init__()

self.hidden_size = hidden_size
self.depth = depth
```

We combine $lin_{hs}$ and $lin_{gs}$ into a single linear layer, and then split its output to get the $lin_{hs}$ and $lin_{gs}$ components. There is one such layer for each $0 \le d < D$.

`self.hidden_lin = nn.ModuleList([nn.Linear(hidden_size, 2 * hidden_size) for _ in range(depth)])`
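To see why one layer of width `2 * hidden_size` can stand in for two separate transforms, here is a minimal check: it copies the halves of a fused layer's weights and biases into two stand-alone layers (named `lin_hs` and `lin_gs` here for illustration) and compares outputs. The sizes and seed are arbitrary.

```python
import torch
from torch import nn

torch.manual_seed(0)
hidden_size = 4

# One fused layer producing [lin_hs(s), lin_gs(s)] side by side
fused = nn.Linear(hidden_size, 2 * hidden_size)

# Two separate layers built from the halves of the fused weights and biases
lin_hs = nn.Linear(hidden_size, hidden_size)
lin_gs = nn.Linear(hidden_size, hidden_size)
with torch.no_grad():
    lin_hs.weight.copy_(fused.weight[:hidden_size])
    lin_hs.bias.copy_(fused.bias[:hidden_size])
    lin_gs.weight.copy_(fused.weight[hidden_size:])
    lin_gs.bias.copy_(fused.bias[hidden_size:])

s = torch.randn(3, hidden_size)
hg = fused(s)

# Slicing the fused output along the feature dimension recovers each transform
same_h = torch.allclose(hg[:, :hidden_size], lin_hs(s), atol=1e-6)
same_g = torch.allclose(hg[:, hidden_size:], lin_gs(s), atol=1e-6)
print(same_h, same_g)  # True True
```

The fused form runs one matrix multiplication per micro-step instead of two, which is the usual reason for this layout.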

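Similarly, we combine $lin_{hx}$ and $lin_{gx}$. This layer has no bias because its output is always added to the output of `hidden_lin[0]`, which already has one.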

`self.input_lin = nn.Linear(input_size, 2 * hidden_size, bias=False)`
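At $d = 0$ the cell adds `input_lin(x)` and `hidden_lin[0](s)`. The sketch below checks that this sum is a single affine map over the concatenation $[x, s]$, using an illustrative `joint` layer that is not part of the implementation; it also shows why one bias is enough. Sizes and seed are arbitrary.

```python
import torch
from torch import nn

torch.manual_seed(0)
input_size, hidden_size = 3, 4

input_lin = nn.Linear(input_size, 2 * hidden_size, bias=False)
hidden_lin0 = nn.Linear(hidden_size, 2 * hidden_size)

# Hypothetical single layer over [x, s]: the two weight matrices side by side,
# and only the bias from hidden_lin0
joint = nn.Linear(input_size + hidden_size, 2 * hidden_size)
with torch.no_grad():
    joint.weight.copy_(torch.cat([input_lin.weight, hidden_lin0.weight], dim=1))
    joint.bias.copy_(hidden_lin0.bias)

x = torch.randn(2, input_size)
s = torch.randn(2, hidden_size)

separate = input_lin(x) + hidden_lin0(s)
combined = joint(torch.cat([x, s], dim=1))
print(torch.allclose(separate, combined, atol=1e-6))  # True
```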

`x` has shape `[batch_size, input_size]` and `s` has shape `[batch_size, hidden_size]`.

`def forward(self, x: torch.Tensor, s: torch.Tensor):`

Iterate over $0 \le d < D$.

`for d in range(self.depth):`

We calculate the concatenation of the linear transforms for $h$ and $g$.

`if d == 0:`

The input is used only when $d$ is $0$.

```
    hg = self.input_lin(x) + self.hidden_lin[d](s)
else:
    hg = self.hidden_lin[d](s)
```

Use the first half of `hg` to get $h_d$.

`h = torch.tanh(hg[:, :self.hidden_size])`

Use the second half of `hg` to get $g_d$, then update the state with the tied carry gate $c_d = 1 - g_d$.

```
    g = torch.sigmoid(hg[:, self.hidden_size:])

    s = h * g + s * (1 - g)

return s
```
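For reference, here is a stand-alone sketch of the cell assembled from the fragments above. It subclasses `torch.nn.Module` directly so that it runs without `labml_helpers`; the class name `RHNCellSketch` is hypothetical, and the sizes in the usage lines are arbitrary.

```python
import torch
from torch import nn


class RHNCellSketch(nn.Module):
    """Hypothetical stand-alone RHN cell with the tied carry gate c_d = 1 - g_d."""

    def __init__(self, input_size: int, hidden_size: int, depth: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.depth = depth
        # Fused [lin_hs, lin_gs] for each depth 0 <= d < D
        self.hidden_lin = nn.ModuleList(
            [nn.Linear(hidden_size, 2 * hidden_size) for _ in range(depth)])
        # Fused [lin_hx, lin_gx]; no bias because hidden_lin[0] already has one
        self.input_lin = nn.Linear(input_size, 2 * hidden_size, bias=False)

    def forward(self, x: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
        for d in range(self.depth):
            hg = self.hidden_lin[d](s)
            # The input only enters at the first micro-step
            if d == 0:
                hg = hg + self.input_lin(x)
            h = torch.tanh(hg[:, :self.hidden_size])
            g = torch.sigmoid(hg[:, self.hidden_size:])
            # Highway update: interpolate between candidate h and previous state s
            s = h * g + s * (1 - g)
        return s


torch.manual_seed(0)
cell = RHNCellSketch(input_size=3, hidden_size=5, depth=4)
x = torch.randn(2, 3)
s0 = torch.zeros(2, 5)
s1 = cell(x, s0)
print(s1.shape)  # torch.Size([2, 5])
```

Because each micro-step is a convex combination of a `tanh` output and the previous state, a state that starts inside $(-1, 1)$, such as the zero state here, stays inside that interval.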

`class RHN(Module):`

Create a network of `n_layers` recurrent highway network layers, each with depth $D$ (`depth`).

`def __init__(self, input_size: int, hidden_size: int, depth: int, n_layers: int):`

```
super().__init__()
self.n_layers = n_layers
self.hidden_size = hidden_size
```

Create cells for each layer. Note that only the first layer gets the input directly; the rest of the layers get their input from the layer below.

```
self.cells = nn.ModuleList([RHNCell(input_size, hidden_size, depth)] +
                           [RHNCell(hidden_size, hidden_size, depth) for _ in range(n_layers - 1)])
```

`x` has shape `[seq_len, batch_size, input_size]` and `state`, if given, has shape `[n_layers, batch_size, hidden_size]`.

`def forward(self, x: torch.Tensor, state: Optional[torch.Tensor] = None):`

`time_steps, batch_size = x.shape[:2]`

Initialize the state to zeros if it is `None`.

```
if state is None:
    s = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
else:
```

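Unstack the state to get the state of each layer. The result is converted to a list because each layer's entry is reassigned inside the time loop, and `torch.unbind` returns a tuple, which does not support item assignment.

📝 You could work with the stacked tensor directly, but this is easier to debug.

`s = list(torch.unbind(state))`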
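A quick sketch of the difference between the tuple that `torch.unbind` returns and a list built from it, using a zero tensor with arbitrary sizes in the `[n_layers, batch_size, hidden_size]` layout:

```python
import torch

state = torch.zeros(3, 2, 4)  # [n_layers, batch_size, hidden_size]

layers = torch.unbind(state)
print(type(layers).__name__, len(layers), layers[0].shape)
# tuple 3 torch.Size([2, 4])

try:
    layers[0] = torch.ones(2, 4)
except TypeError:
    print("tuples do not support item assignment")

# Converting to a list makes per-layer updates possible
s = list(torch.unbind(state))
s[0] = torch.ones(2, 4)
print(torch.stack(s).shape)  # torch.Size([3, 2, 4])
```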

List to collect the outputs of the final layer at each time step.

`out = []`

Run through the network for each time step.

`for t in range(time_steps):`

Input to the first layer is the input itself.

`inp = x[t]`

Loop through the layers.

`for layer in range(self.n_layers):`

Update the state of the layer.

`s[layer] = self.cells[layer](inp, s[layer])`

Input to the next layer is the state of this layer.

`inp = s[layer]`

Collect the output of the final layer.

`out.append(s[-1])`

Stack the outputs and states.

```
out = torch.stack(out)
s = torch.stack(s)

return out, s
```
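The sketch below puts the pieces together in a self-contained form to show the tensor shapes and how the returned state can be fed back in. The classes `CellSketch` and `RHNSketch` are hypothetical re-implementations that subclass `torch.nn.Module` directly (not `labml_helpers`), and all sizes are arbitrary. Running a sequence in two chunks while passing the state along gives the same outputs as running it in one go.

```python
from typing import List, Optional

import torch
from torch import nn


class CellSketch(nn.Module):
    """Hypothetical compact RHN cell with the tied carry gate c_d = 1 - g_d."""

    def __init__(self, input_size: int, hidden_size: int, depth: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.hidden_lin = nn.ModuleList(
            [nn.Linear(hidden_size, 2 * hidden_size) for _ in range(depth)])
        self.input_lin = nn.Linear(input_size, 2 * hidden_size, bias=False)

    def forward(self, x: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
        for d, lin in enumerate(self.hidden_lin):
            # The input only contributes at d == 0
            hg = lin(s) + (self.input_lin(x) if d == 0 else 0)
            h = torch.tanh(hg[:, :self.hidden_size])
            g = torch.sigmoid(hg[:, self.hidden_size:])
            s = h * g + s * (1 - g)
        return s


class RHNSketch(nn.Module):
    """Hypothetical stack of n_layers cells over a [seq_len, batch, input] tensor."""

    def __init__(self, input_size: int, hidden_size: int, depth: int, n_layers: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.cells = nn.ModuleList(
            [CellSketch(input_size, hidden_size, depth)] +
            [CellSketch(hidden_size, hidden_size, depth) for _ in range(n_layers - 1)])

    def forward(self, x: torch.Tensor, state: Optional[torch.Tensor] = None):
        time_steps, batch_size = x.shape[:2]
        if state is None:
            s: List[torch.Tensor] = [x.new_zeros(batch_size, self.hidden_size)
                                     for _ in range(self.n_layers)]
        else:
            s = list(torch.unbind(state))
        out = []
        for t in range(time_steps):
            inp = x[t]
            for layer in range(self.n_layers):
                s[layer] = self.cells[layer](inp, s[layer])
                inp = s[layer]
            out.append(s[-1])
        return torch.stack(out), torch.stack(s)


torch.manual_seed(0)
model = RHNSketch(input_size=3, hidden_size=5, depth=2, n_layers=2)
x = torch.randn(6, 4, 3)  # [seq_len, batch_size, input_size]

out, state = model(x)
print(out.shape, state.shape)  # torch.Size([6, 4, 5]) torch.Size([2, 4, 5])

# Feeding the returned state back in continues the sequence where it left off
out_a, state_a = model(x[:4])
out_b, state_b = model(x[4:], state_a)
print(torch.allclose(torch.cat([out_a, out_b]), out, atol=1e-6))  # True
```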