This is an annotated PyTorch implementation of HyperLSTM, introduced in the paper HyperNetworks. This blog post by David Ha gives a good explanation of HyperNetworks.

We have an experiment that trains a HyperLSTM to predict text on the Shakespeare dataset. Here's the link to the code: `experiment.py`

HyperNetworks use a smaller network to generate the weights of a larger network. There are two variants: static HyperNetworks and dynamic HyperNetworks. Static HyperNetworks use a smaller network to generate the weights (kernels) of a convolutional network. Dynamic HyperNetworks generate the parameters of a recurrent neural network at each step. This is an implementation of the latter.

In an RNN, the parameters stay the same at every step. Dynamic HyperNetworks generate different parameters for each step. HyperLSTM has the structure of an LSTM, but the parameters of each step are changed by a smaller LSTM network.

In its basic form, a Dynamic HyperNetwork has a smaller recurrent network that generates a feature vector for each parameter tensor of the larger recurrent network. Let's say the larger network has some parameter $W_h$. The smaller network generates a feature vector $z_h$, and we dynamically compute $W_h$ as a linear transformation of $z_h$. For instance, $W_h = \langle W_{hz}, z_h \rangle$, where $W_{hz}$ is a 3-d tensor parameter and $\langle . \rangle$ is a tensor-vector multiplication. $z_h$ is usually a linear transformation of the output of the smaller recurrent network.
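The tensor-vector multiplication $\langle W_{hz}, z_h \rangle$ can be sketched with `torch.einsum`. This assumes $W_{hz}$ is stored with shape $N_h \times N_h \times N_z$, so contracting its last axis with $z_h$ leaves an $N_h \times N_h$ matrix. The sizes are arbitrary and only for illustration.

```python
import torch

n_h, n_z = 8, 4  # illustrative sizes

# 3-d parameter tensor W_hz with shape [N_h, N_h, N_z]
w_hz = torch.randn(n_h, n_h, n_z)
# Feature vector z_h produced by the smaller network
z_h = torch.randn(n_z)

# <W_hz, z_h>: contract the last axis of W_hz with z_h -> [N_h, N_h]
w_h = torch.einsum('ijk,k->ij', w_hz, z_h)

# Equivalent view: W_h is a z_h-weighted sum of N_z matrices of shape [N_h, N_h]
w_h_sum = sum(z_h[k] * w_hz[:, :, k] for k in range(n_z))
```

The second form makes the cost visible: every entry of $W_h$ needs $N_z$ parameters of its own.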

Large recurrent networks have large dynamically computed parameters. These are calculated using a linear transformation of the feature vector $z$, and this transformation requires an even larger weight tensor. That is, when $W_h$ has shape $N_h \times N_h$, $W_{hz}$ will be $N_h \times N_h \times N_z$.
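A quick count shows how fast this grows. The sizes below are arbitrary examples, not the values used in `experiment.py`:

```python
# Illustrative sizes (assumptions, not taken from experiment.py)
n_h, n_z = 512, 16

# Dynamically computed weight W_h: N_h x N_h
w_h_params = n_h * n_h
# Parameter tensor W_hz needed to produce W_h from z_h: N_h x N_h x N_z
w_hz_params = n_h * n_h * n_z

print(w_h_params)                 # 262144
print(w_hz_params)                # 4194304
print(w_hz_params // w_h_params)  # 16
```

Generating $W_h$ this way costs $N_z$ times more parameters than simply learning $W_h$ directly.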

To overcome this, we compute the weight parameters of the recurrent network by dynamically scaling each row of a matrix of the same size.

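$$
\begin{aligned}
d(z) &= W_{hz} z_h \\
W_h &= \begin{pmatrix}
d_0(z) W_{hd_0} \\
d_1(z) W_{hd_1} \\
\vdots \\
d_{N_h - 1}(z) W_{hd_{N_h - 1}}
\end{pmatrix}
\end{aligned}
$$

where $W_{hd}$ is an $N_h \times N_h$ parameter matrix and $W_{hd_i}$ is its $i$-th row. $W_{hz}$ is now only an $N_h \times N_z$ matrix.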

We can further optimize this when we compute $W_h h$, as $d(z) \odot (W_{hd} h)$, where $\odot$ stands for element-wise multiplication.
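Here is a small numerical check of the row-scaling trick, using illustrative sizes and random values. It builds $W_h$ explicitly by scaling the rows of $W_{hd}$. It then compares $W_h h$ against the cheaper $d(z) \odot (W_{hd} h)$, which never materializes $W_h$.

```python
import torch

n_h, n_z, batch = 6, 3, 2  # illustrative sizes

w_hz = torch.randn(n_h, n_z)   # maps z_h to the per-row scales d(z)
w_hd = torch.randn(n_h, n_h)   # static N_h x N_h parameter matrix
z_h = torch.randn(n_z)
h = torch.randn(batch, n_h)    # a batch of hidden states

# d(z) = W_hz z_h: one scale per row
d = w_hz @ z_h

# Explicit construction: scale row i of W_hd by d_i(z)
w_h = d[:, None] * w_hd        # same as torch.diag(d) @ w_hd

# Explicit W_h h for every item in the batch
y_explicit = h @ w_h.T

# Optimized form d(z) ⊙ (W_hd h): W_h is never built
y_fast = d * (h @ w_hd.T)
```

In `HyperLSTMCell.forward`, the scales `d_h` and `d_x` also carry a batch dimension, because each sequence gets its own $z$. The broadcast works the same way.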

```
from typing import Optional, Tuple

import torch
from torch import nn

from labml_helpers.module import Module
from labml_nn.lstm import LSTMCell
```

For HyperLSTM the smaller network and the larger network both have the LSTM structure. This is defined in Appendix A.2.2 in the paper.

`class HyperLSTMCell(Module):`

`input_size` is the size of the input $x_t$, `hidden_size` is the size of the LSTM, and `hyper_size` is the size of the smaller LSTM that alters the weights of the larger outer LSTM. `n_z` is the size of the feature vectors used to alter the LSTM weights.

We use the output of the smaller LSTM to compute $z_h^{i,f,g,o}$, $z_x$ and $z_b$ using linear transformations. From these, we calculate $d_h(z_h^{i,f,g,o})$, $d_x(z_x)$ and $d_b(z_b)$, again using linear transformations. $d_h$ and $d_x$ then scale the rows of the main LSTM's weight matrices, and $d_b$ is used as its bias.

📝 Computing $z$ and then $d$ is two sequential linear transformations, so they could be combined into a single linear transformation. However, we've implemented them separately so that the code matches the description in the paper.
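To illustrate the note above, this sketch folds two sequential `nn.Linear` layers into one and checks that the outputs agree. It covers a single gate: a biased layer plays the role of $lin_h$ and a bias-free one plays $lin_{dh}$. The sizes are arbitrary.

```python
import torch
from torch import nn

hyper_size, n_z, hidden_size = 8, 4, 6  # illustrative sizes

z = nn.Linear(hyper_size, n_z)               # like lin_h: has a bias
d = nn.Linear(n_z, hidden_size, bias=False)  # like lin_dh: no bias

# Fold d(z(v)) = W_d (W_z v + b_z) into one layer with W = W_d W_z, b = W_d b_z
combined = nn.Linear(hyper_size, hidden_size)
with torch.no_grad():
    combined.weight.copy_(d.weight @ z.weight)
    combined.bias.copy_(d.weight @ z.bias)

h_hat = torch.randn(5, hyper_size)
two_step = d(z(h_hat))
one_step = combined(h_hat)
```

Keeping the two layers separate has a side effect. The combined map has rank at most $N_z$, and for small `n_z` the factored form has fewer parameters than a full `hyper_size` by `hidden_size` layer.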

`def __init__(self, input_size: int, hidden_size: int, hyper_size: int, n_z: int):`

`super().__init__()`

The input to the smaller (hyper) LSTM is $\hat{x}_t = \begin{pmatrix} h_{t-1} \\ x_t \end{pmatrix}$, where $x_t$ is the input and $h_{t-1}$ is the output of the outer LSTM at the previous step. So the input size is `hidden_size + input_size`.

The outputs of the hyper LSTM are $\hat{h}_t$ and $\hat{c}_t$.

`self.hyper = LSTMCell(hidden_size + input_size, hyper_size, layer_norm=True)`

$z_h^{i,f,g,o} = lin_h(\hat{h}_t)$

🤔 The paper specifies this as $z_h^{i,f,g,o} = lin_h(\hat{h}_{t-1})$, which we believe is a typo.

`self.z_h = nn.Linear(hyper_size, 4 * n_z)`

$z_x = lin_x(\hat{h}_t)$

`self.z_x = nn.Linear(hyper_size, 4 * n_z)`

$z_b = lin_b(\hat{h}_t)$

`self.z_b = nn.Linear(hyper_size, 4 * n_z, bias=False)`

$d_h(z_h^{i,f,g,o}) = lin_{dh}(z_h^{i,f,g,o})$

```
d_h = [nn.Linear(n_z, hidden_size, bias=False) for _ in range(4)]
self.d_h = nn.ModuleList(d_h)
```

$d_{x}(z_{x})=lin_{dx}(z_{x})$

```
d_x = [nn.Linear(n_z, hidden_size, bias=False) for _ in range(4)]
self.d_x = nn.ModuleList(d_x)
```

$d_{b}(z_{b})=lin_{db}(z_{b})$

```
d_b = [nn.Linear(n_z, hidden_size) for _ in range(4)]
self.d_b = nn.ModuleList(d_b)
```

The weight matrices $W_{h}$

`self.w_h = nn.ParameterList([nn.Parameter(torch.zeros(hidden_size, hidden_size)) for _ in range(4)])`

The weight matrices $W_{x}$

`self.w_x = nn.ParameterList([nn.Parameter(torch.zeros(hidden_size, input_size)) for _ in range(4)])`

Layer normalization

```
self.layer_norm = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(4)])
self.layer_norm_c = nn.LayerNorm(hidden_size)
```

```
def forward(self, x: torch.Tensor,
            h: torch.Tensor, c: torch.Tensor,
            h_hat: torch.Tensor, c_hat: torch.Tensor):
```

$\hat{x}_t = \begin{pmatrix} h_{t-1} \\ x_t \end{pmatrix}$

`x_hat = torch.cat((h, x), dim=-1)`

$\hat{h}_t, \hat{c}_t = lstm(\hat{x}_t, \hat{h}_{t-1}, \hat{c}_{t-1})$

`h_hat, c_hat = self.hyper(x_hat, h_hat, c_hat)`

$z_h^{i,f,g,o} = lin_h(\hat{h}_t)$

`z_h = self.z_h(h_hat).chunk(4, dim=-1)`

$z_x = lin_x(\hat{h}_t)$

`z_x = self.z_x(h_hat).chunk(4, dim=-1)`

$z_b = lin_b(\hat{h}_t)$

`z_b = self.z_b(h_hat).chunk(4, dim=-1)`
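A single linear layer produces the feature vectors for all four gates at once, and `chunk(4, dim=-1)` splits them apart. A minimal sketch with illustrative sizes:

```python
import torch

batch_size, hyper_size, n_z = 2, 8, 4  # illustrative sizes
lin_h = torch.nn.Linear(hyper_size, 4 * n_z)
h_hat = torch.randn(batch_size, hyper_size)

# One linear layer produces all four gate feature vectors at once...
z_h = lin_h(h_hat)                   # [batch_size, 4 * n_z]
# ...and chunk splits the last dimension into z_h^i, z_h^f, z_h^g, z_h^o
z_i, z_f, z_g, z_o = z_h.chunk(4, dim=-1)
```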

We calculate $i$, $f$, $g$ and $o$ in a loop.

```
ifgo = []
for i in range(4):
```

$d_h(z_h^{i,f,g,o}) = lin_{dh}(z_h^{i,f,g,o})$

`d_h = self.d_h[i](z_h[i])`

$d_{x}(z_{x})=lin_{dx}(z_{x})$

`d_x = self.d_x[i](z_x[i])`

```
y = d_h * torch.einsum('ij,bj->bi', self.w_h[i], h) + \
    d_x * torch.einsum('ij,bj->bi', self.w_x[i], x) + \
    self.d_b[i](z_b[i])

ifgo.append(self.layer_norm[i](y))
```
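Written out for each gate $k \in \{i, f, g, o\}$, the loop body computes the expression below. This is our reading of the code above. `h` is $h_{t-1}$ and `x` is $x_t$; $d_h^k = lin_{dh}(z_h^k)$ and $d_x^k = lin_{dx}(z_x^k)$. $d_b^k = lin_{db}(z_b^k)$ acts as a dynamically generated bias.

$$
k_t = LN_k\Big(d_h^k \odot \big(W_h^k h_{t-1}\big) + d_x^k \odot \big(W_x^k x_t\big) + d_b^k\Big)
$$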

$i_{t},f_{t},g_{t},o_{t}$

`i, f, g, o = ifgo`

$c_t = \sigma(f_t) \odot c_{t-1} + \sigma(i_t) \odot \tanh(g_t)$

`c_next = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)`

$h_t = \sigma(o_t) \odot \tanh(LN(c_t))$

```
h_next = torch.sigmoid(o) * torch.tanh(self.layer_norm_c(c_next))

return h_next, c_next, h_hat, c_hat
```
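`HyperLSTMCell` depends on `labml_helpers` and `labml_nn`, so here is a hypothetical stand-alone sketch that runs with PyTorch alone. It mirrors the cell above with two assumptions. First, `torch.nn.LSTMCell`, which has no layer normalization, stands in for `labml_nn.lstm.LSTMCell(..., layer_norm=True)`. Second, the four $W_h$ (and $W_x$) matrices are stored as one stacked parameter. The name `HyperLSTMCellSketch` is ours.

```python
import torch
from torch import nn


class HyperLSTMCellSketch(nn.Module):
    """Hypothetical PyTorch-only variant of HyperLSTMCell.

    Assumption: nn.LSTMCell (no layer normalization) replaces
    labml_nn.lstm.LSTMCell(..., layer_norm=True) as the smaller LSTM.
    """

    def __init__(self, input_size: int, hidden_size: int, hyper_size: int, n_z: int):
        super().__init__()
        # Smaller LSTM; its input is the concatenation (h_{t-1}, x_t)
        self.hyper = nn.LSTMCell(hidden_size + input_size, hyper_size)
        # z_h, z_x, z_b for all four gates at once
        self.z_h = nn.Linear(hyper_size, 4 * n_z)
        self.z_x = nn.Linear(hyper_size, 4 * n_z)
        self.z_b = nn.Linear(hyper_size, 4 * n_z, bias=False)
        # d_h, d_x, d_b: one linear map per gate
        self.d_h = nn.ModuleList([nn.Linear(n_z, hidden_size, bias=False) for _ in range(4)])
        self.d_x = nn.ModuleList([nn.Linear(n_z, hidden_size, bias=False) for _ in range(4)])
        self.d_b = nn.ModuleList([nn.Linear(n_z, hidden_size) for _ in range(4)])
        # W_h and W_x for the four gates, stacked along dim 0 (zero init, as above)
        self.w_h = nn.Parameter(torch.zeros(4, hidden_size, hidden_size))
        self.w_x = nn.Parameter(torch.zeros(4, hidden_size, input_size))
        self.layer_norm = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(4)])
        self.layer_norm_c = nn.LayerNorm(hidden_size)

    def forward(self, x, h, c, h_hat, c_hat):
        # One step of the smaller LSTM
        h_hat, c_hat = self.hyper(torch.cat((h, x), dim=-1), (h_hat, c_hat))
        z_h = self.z_h(h_hat).chunk(4, dim=-1)
        z_x = self.z_x(h_hat).chunk(4, dim=-1)
        z_b = self.z_b(h_hat).chunk(4, dim=-1)

        ifgo = []
        for k in range(4):
            # Row scaling d ⊙ (W v); W_h and W_x are never rebuilt per step
            y = (self.d_h[k](z_h[k]) * (h @ self.w_h[k].T)
                 + self.d_x[k](z_x[k]) * (x @ self.w_x[k].T)
                 + self.d_b[k](z_b[k]))
            ifgo.append(self.layer_norm[k](y))
        i, f, g, o = ifgo

        c_next = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
        h_next = torch.sigmoid(o) * torch.tanh(self.layer_norm_c(c_next))
        return h_next, c_next, h_hat, c_hat


# One step on a batch of 2, starting from zero states
cell = HyperLSTMCellSketch(input_size=3, hidden_size=5, hyper_size=4, n_z=2)
x = torch.randn(2, 3)
h, c = torch.zeros(2, 5), torch.zeros(2, 5)
h_hat, c_hat = torch.zeros(2, 4), torch.zeros(2, 4)
h, c, h_hat, c_hat = cell(x, h, c, h_hat, c_hat)
```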

`class HyperLSTM(Module):`

Create a network of `n_layers` of HyperLSTM.

`def __init__(self, input_size: int, hidden_size: int, hyper_size: int, n_z: int, n_layers: int):`

`super().__init__()`

Store sizes to initialize state

```
self.n_layers = n_layers
self.hidden_size = hidden_size
self.hyper_size = hyper_size
```

Create cells for each layer. Note that only the first layer gets the input directly; the rest of the layers get their input from the layer below.

```
self.cells = nn.ModuleList([HyperLSTMCell(input_size, hidden_size, hyper_size, n_z)] +
                           [HyperLSTMCell(hidden_size, hidden_size, hyper_size, n_z)
                            for _ in range(n_layers - 1)])
```

`x` has shape `[n_steps, batch_size, input_size]` and `state` is a tuple of $h, c, \hat{h}, \hat{c}$. $h, c$ have shape `[n_layers, batch_size, hidden_size]` and $\hat{h}, \hat{c}$ have shape `[n_layers, batch_size, hyper_size]`.

```
def forward(self, x: torch.Tensor,
            state: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = None):
```

`n_steps, batch_size = x.shape[:2]`

Initialize the state with zeros if `None`

```
if state is None:
    h = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
    c = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
    h_hat = [x.new_zeros(batch_size, self.hyper_size) for _ in range(self.n_layers)]
    c_hat = [x.new_zeros(batch_size, self.hyper_size) for _ in range(self.n_layers)]
```

```
else:
    (h, c, h_hat, c_hat) = state
```

Unstack the tensors to get the state of each layer.

📝 You could work with the stacked tensors directly, but a list of per-layer tensors is easier to debug.

```
h, c = list(torch.unbind(h)), list(torch.unbind(c))
h_hat, c_hat = list(torch.unbind(h_hat)), list(torch.unbind(c_hat))
```
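The unstack/stack round trip can be sketched as follows, with illustrative sizes. `torch.unbind` splits the stacked state along dimension 0 into one tensor per layer, and `torch.stack` reassembles them.

```python
import torch

n_layers, batch_size, hidden_size = 3, 2, 5  # illustrative sizes

# Stacked state, as returned by HyperLSTM.forward: [n_layers, batch_size, hidden_size]
h = torch.randn(n_layers, batch_size, hidden_size)

# Unstack along dim 0 into one [batch_size, hidden_size] tensor per layer
h_layers = list(torch.unbind(h))

# A per-layer list is easy to update one layer at a time
h_layers[1] = torch.zeros(batch_size, hidden_size)

# Stack back into a single tensor at the end
h_new = torch.stack(h_layers)
```

Reassigning a list entry does not modify the original stacked `h`; only `h_new` reflects the update.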

Collect the outputs of the final layer at each step

```
out = []
for t in range(n_steps):
```

Input to the first layer is the input itself

`inp = x[t]`

Loop through the layers

`for layer in range(self.n_layers):`

Run the cell for this layer and update its state

```
h[layer], c[layer], h_hat[layer], c_hat[layer] = \
    self.cells[layer](inp, h[layer], c[layer], h_hat[layer], c_hat[layer])
```

The input to the next layer is the output $h$ of this layer

`inp = h[layer]`

Collect the output $h$ of the final layer

`out.append(h[-1])`

Stack the outputs and states

```
out = torch.stack(out)
h = torch.stack(h)
c = torch.stack(c)
h_hat = torch.stack(h_hat)
c_hat = torch.stack(c_hat)
```

`return out, (h, c, h_hat, c_hat)`
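This sketch shows how the time and layer loops fit together and what shapes come out. It uses `torch.nn.LSTMCell` as a stand-in for `HyperLSTMCell`, an assumption made only to keep it self-contained, so the hyper state $\hat{h}, \hat{c}$ is omitted.

```python
import torch
from torch import nn

n_steps, batch_size, input_size, hidden_size, n_layers = 7, 2, 3, 5, 2  # illustrative

# Stand-in cells: nn.LSTMCell instead of HyperLSTMCell (assumption, for brevity)
cells = nn.ModuleList([nn.LSTMCell(input_size, hidden_size)] +
                      [nn.LSTMCell(hidden_size, hidden_size) for _ in range(n_layers - 1)])

x = torch.randn(n_steps, batch_size, input_size)
h = [x.new_zeros(batch_size, hidden_size) for _ in range(n_layers)]
c = [x.new_zeros(batch_size, hidden_size) for _ in range(n_layers)]

out = []
for t in range(n_steps):
    inp = x[t]
    for layer in range(n_layers):
        h[layer], c[layer] = cells[layer](inp, (h[layer], c[layer]))
        # The output of this layer feeds the layer above
        inp = h[layer]
    # Keep only the top layer's output at each step
    out.append(h[-1])

out = torch.stack(out)  # [n_steps, batch_size, hidden_size]
h = torch.stack(h)      # [n_layers, batch_size, hidden_size]
c = torch.stack(c)
```

The last time step of `out` equals the top layer of the returned `h`, which is why the final state can be fed straight back in to continue a sequence.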