# HyperNetworks - HyperLSTM

We have implemented HyperLSTM, introduced in the paper HyperNetworks, in PyTorch with annotations. This blog post by David Ha gives a good explanation of HyperNetworks.

We have an experiment that trains a HyperLSTM to predict text on the Shakespeare dataset. Here's the link to the code: experiment.py

HyperNetworks use a smaller network to generate weights of a larger network. There are two variants: static hyper-networks and dynamic hyper-networks. Static HyperNetworks have smaller networks that generate weights (kernels) of a convolutional network. Dynamic HyperNetworks generate parameters of a recurrent neural network for each step. This is an implementation of the latter.

## Dynamic HyperNetworks

In an RNN the parameters stay constant across steps. Dynamic HyperNetworks generate different parameters for each step. HyperLSTM has the structure of an LSTM, but the parameters of each step are changed by a smaller LSTM network.

In the basic form, a Dynamic HyperNetwork has a smaller recurrent network that generates a feature vector corresponding to each parameter tensor of the larger recurrent network. Let's say the larger network has some parameter $W_h$. The smaller network generates a feature vector $z_h$, and we dynamically compute $W_h$ as a linear transformation of $z_h$. For instance, $W_h = \langle W_{hz}, z_h \rangle$ where $W_{hz}$ is a 3-d tensor parameter and $\langle . \rangle$ is a tensor-vector multiplication. $z_h$ is usually a linear transformation of the output of the smaller recurrent network.
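To make the tensor-vector product concrete, here is a minimal pure-Python sketch of this basic form, assuming tiny hypothetical sizes ($N_h = 2$, $N_z = 3$) and nested lists in place of tensors. The helper name `tensor_vector` is illustrative and not part of the implementation; it contracts the last axis of $W_{hz}$ with $z_h$ to produce a full $N_h \times N_h$ matrix $W_h$.

```python
from typing import List

Matrix = List[List[float]]
Tensor3 = List[List[List[float]]]


def tensor_vector(w_hz: Tensor3, z: List[float]) -> Matrix:
    """Contract the last axis of a 3-d tensor with a vector: W[i][j] = sum_k w_hz[i][j][k] * z[k]."""
    return [[sum(w_ijk * z_k for w_ijk, z_k in zip(w_ij, z)) for w_ij in w_i] for w_i in w_hz]


# Hypothetical sizes: N_h = 2 (rows and columns of W_h), N_z = 3 (feature size)
w_hz = [
    [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
    [[0.0, 0.0, 1.0], [1.0, 1.0, 1.0]],
]
z_h = [2.0, 3.0, 4.0]  # feature vector produced by the smaller network at this step

w_h = tensor_vector(w_hz, z_h)  # a fresh N_h x N_h weight matrix for this step
print(w_h)  # [[2.0, 3.0], [4.0, 9.0]]
```

Every step gets a different $W_h$ because $z_h$ changes, but storing $W_{hz}$ costs $N_h \times N_h \times N_z$ numbers.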

### Weight scaling instead of computing the weights

Large recurrent networks have large dynamically computed parameters. These are calculated using a linear transformation of the feature vector $z$, and this transformation requires an even larger weight tensor. That is, when $W_h$ has shape $N_h \times N_h$, $W_{hz}$ will be $N_h \times N_h \times N_z$.

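To overcome this, we compute the weight parameters of the recurrent network by dynamically scaling each row of a matrix of the same size.

$$
\begin{align}
d(z) &= W_{hz} z_h \\
W_h &=
\begin{pmatrix}
d_0(z) W_{hd_0} \\
d_1(z) W_{hd_1} \\
\vdots \\
d_{N_h - 1}(z) W_{hd_{N_h - 1}}
\end{pmatrix}
\end{align}
$$

where $W_{hd}$ is an $N_h \times N_h$ parameter matrix, $W_{hd_i}$ is its $i$-th row, and here $W_{hz}$ is an $N_h \times N_z$ matrix.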
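A quick parameter count shows why the row-scaling form is cheaper. This is a sketch under stated assumptions: it counts a single weight matrix $W_h$, ignores biases, and uses hypothetical sizes $N_h = 512$ and $N_z = 16$; the helper names are illustrative.

```python
def full_tensor_params(n_h: int, n_z: int) -> int:
    """Parameters in a 3-d W_hz that maps z (size n_z) to a full n_h x n_h matrix W_h."""
    return n_h * n_h * n_z


def row_scaling_params(n_h: int, n_z: int) -> int:
    """Parameters when a fixed n_h x n_h matrix W_hd has its rows scaled by d(z) = W_hz z.

    W_hd contributes n_h * n_h entries and the n_h x n_z matrix W_hz contributes n_h * n_z.
    """
    return n_h * n_h + n_h * n_z


# Hypothetical sizes for illustration only
n_h, n_z = 512, 16
print(full_tensor_params(n_h, n_z))  # 4194304
print(row_scaling_params(n_h, n_z))  # 270336
```

The full tensor grows with $N_h^2 N_z$, while row scaling grows with $N_h^2 + N_h N_z$, so the saving is roughly a factor of $N_z$.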

We can further optimize this when we compute $W_h h$, as

$$d(z) \odot (W_{hd} h)$$

where $\odot$ stands for element-wise multiplication.
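The following pure-Python sketch checks, on small hypothetical values, that materializing the row-scaled matrix and multiplying by $h$ gives the same result as the element-wise shortcut $d(z) \odot (W_{hd} h)$. The shortcut only ever forms a length-$N_h$ vector, so it avoids building a new $N_h \times N_h$ matrix at every step. The helper names (`matvec`, `scale_rows`) are illustrative.

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matvec(w: Matrix, v: Vector) -> Vector:
    """Matrix-vector product W v."""
    return [sum(w_ij * v_j for w_ij, v_j in zip(row, v)) for row in w]


def scale_rows(d: Vector, w: Matrix) -> Matrix:
    """Build W_h explicitly: row i of W_hd multiplied by d_i."""
    return [[d_i * w_ij for w_ij in row] for d_i, row in zip(d, w)]


w_hz = [[1.0, 0.5], [0.0, 2.0], [1.0, -1.0]]                # N_h x N_z, maps z to row scales
w_hd = [[1.0, 2.0, 0.0], [0.0, 1.0, 1.0], [3.0, 0.0, 1.0]]  # N_h x N_h
z = [2.0, 1.0]
h = [1.0, -1.0, 2.0]

d = matvec(w_hz, z)                                             # d(z) = W_hz z
explicit = matvec(scale_rows(d, w_hd), h)                       # materializes an N_h x N_h matrix
shortcut = [d_i * y_i for d_i, y_i in zip(d, matvec(w_hd, h))]  # d ⊙ (W_hd h), vectors only

print(d)                     # [2.5, 2.0, 1.0]
print(explicit == shortcut)  # True
```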

```python
from typing import Optional, Tuple

import torch
from torch import nn

from labml_helpers.module import Module
from labml_nn.lstm import LSTMCell
```

## HyperLSTM Cell

For HyperLSTM the smaller network and the larger network both have the LSTM structure. This is defined in Appendix A.2.2 in the paper.

```python
class HyperLSTMCell(Module):
```

input_size is the size of the input $x_t$, hidden_size is the size of the LSTM, and hyper_size is the size of the smaller LSTM that alters the weights of the larger outer LSTM. n_z is the size of the feature vectors used to alter the LSTM weights.

We use the output of the smaller LSTM to compute $z_h^{i,f,g,o}$, $z_x^{i,f,g,o}$ and $z_b^{i,f,g,o}$ using linear transformations. We calculate $d_h^{i,f,g,o}(z_h)$, $d_x^{i,f,g,o}(z_x)$, and $d_b^{i,f,g,o}(z_b)$ from these, using linear transformations again. These are then used to scale the rows of the weight and bias tensors of the main LSTM.

📝 Since the computation of $z$ and $d$ consists of two sequential linear transformations, these can be combined into a single linear transformation. However, we've implemented them separately so that the code matches the description in the paper.
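To illustrate the note above, here is a small pure-Python check, with hypothetical weights, that two stacked affine maps $W_2 (W_1 v + b_1) + b_2$ equal the single map $(W_2 W_1) v + (W_2 b_1 + b_2)$. In the cell the two maps would play the role of the $z$ and $d$ projections; the names and values here are illustrative.

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def linear(w: Matrix, b: Vector, v: Vector) -> Vector:
    """Affine map y = W v + b."""
    return [sum(w_ij * v_j for w_ij, v_j in zip(row, v)) + b_i for row, b_i in zip(w, b)]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Matrix product A B."""
    cols = list(zip(*b))
    return [[sum(a_ik * b_kj for a_ik, b_kj in zip(row, col)) for col in cols] for row in a]


# Hypothetical stand-ins: first map h_hat -> z (3 x 2), second map z -> d (2 x 3)
w1, b1 = [[1.0, 2.0], [0.0, 1.0], [1.0, 1.0]], [0.5, 0.0, -1.0]
w2, b2 = [[1.0, 0.0, 2.0], [0.0, 3.0, 1.0]], [1.0, 0.0]
h_hat = [1.0, 2.0]

two_step = linear(w2, b2, linear(w1, b1, h_hat))
# Fold both maps into one: W = W2 W1 and b = W2 b1 + b2
w = matmul(w2, w1)
b = linear(w2, b2, b1)
one_step = linear(w, b, h_hat)
print(two_step, one_step)  # [10.5, 8.0] [10.5, 8.0]
```

Folding the maps saves one matrix multiply per step, at the cost of no longer mirroring the paper's $z$ and $d$ notation.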

```python
    def __init__(self, input_size: int, hidden_size: int, hyper_size: int, n_z: int):
        super().__init__()
```

The input to the hyperLSTM is

$$\hat{x}_t = \begin{pmatrix} h_{t-1} \\ x_t \end{pmatrix}$$

where $x_t$ is the input and $h_{t-1}$ is the output of the outer LSTM at the previous step. So the input size is hidden_size + input_size.

The output of the hyperLSTM is $\hat{h}_t$ and $\hat{c}_t$.

```python
        self.hyper = LSTMCell(hidden_size + input_size, hyper_size, layer_norm=True)
```

$$z_h^{i,f,g,o} = lin_h^{i,f,g,o}(\hat{h}_t)$$

🤔 In the paper it was specified as $z_h^{i,f,g,o} = lin_h^{i,f,g,o}(\hat{h}_{t-1})$. I feel that it's a typo.

```python
        # z_h, z_x, z_b: feature vectors for the four gates (i, f, g, o), computed from h_hat
        self.z_h = nn.Linear(hyper_size, 4 * n_z)
        self.z_x = nn.Linear(hyper_size, 4 * n_z)
        self.z_b = nn.Linear(hyper_size, 4 * n_z, bias=False)
        # d_h, d_x: one linear map per gate giving per-row scales of W_h and W_x
        d_h = [nn.Linear(n_z, hidden_size, bias=False) for _ in range(4)]
        self.d_h = nn.ModuleList(d_h)
        d_x = [nn.Linear(n_z, hidden_size, bias=False) for _ in range(4)]
        self.d_x = nn.ModuleList(d_x)
        # d_b: one linear map per gate giving the bias term directly
        d_b = [nn.Linear(n_z, hidden_size) for _ in range(4)]
        self.d_b = nn.ModuleList(d_b)
```

The weight matrices $W_h^{i,f,g,o}$

```python
        self.w_h = nn.ParameterList([nn.Parameter(torch.zeros(hidden_size, hidden_size)) for _ in range(4)])
```

The weight matrices $W_x^{i,f,g,o}$

```python
        self.w_x = nn.ParameterList([nn.Parameter(torch.zeros(hidden_size, input_size)) for _ in range(4)])
```

Layer normalization

```python
        self.layer_norm = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(4)])
        self.layer_norm_c = nn.LayerNorm(hidden_size)

    def forward(self, x: torch.Tensor,
                h: torch.Tensor, c: torch.Tensor,
                h_hat: torch.Tensor, c_hat: torch.Tensor):
        # x_hat = [h_{t-1}; x_t], the input to the smaller (hyper) LSTM
        x_hat = torch.cat((h, x), dim=-1)
        # Step the hyper LSTM to get h_hat_t and c_hat_t
        h_hat, c_hat = self.hyper(x_hat, h_hat, c_hat)
        # Project h_hat_t and split each projection into the four gates (i, f, g, o)
        z_h = self.z_h(h_hat).chunk(4, dim=-1)
        z_x = self.z_x(h_hat).chunk(4, dim=-1)
        z_b = self.z_b(h_hat).chunk(4, dim=-1)
```

We calculate $i$, $f$, $g$ and $o$ in a loop

```python
        ifgo = []
        for i in range(4):
            # Per-gate row scales d_h(z_h) and d_x(z_x)
            d_h = self.d_h[i](z_h[i])
            d_x = self.d_x[i](z_x[i])
            # LN(d_h ⊙ (W_h h_{t-1}) + d_x ⊙ (W_x x_t) + d_b(z_b))
            y = d_h * torch.einsum('ij,bj->bi', self.w_h[i], h) + \
                d_x * torch.einsum('ij,bj->bi', self.w_x[i], x) + \
                self.d_b[i](z_b[i])

            ifgo.append(self.layer_norm[i](y))
        # Gate pre-activations i_t, f_t, g_t, o_t
        i, f, g, o = ifgo
        # c_t = σ(f_t) ⊙ c_{t-1} + σ(i_t) ⊙ tanh(g_t)
        c_next = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
        # h_t = σ(o_t) ⊙ tanh(LN(c_t))
        h_next = torch.sigmoid(o) * torch.tanh(self.layer_norm_c(c_next))

        return h_next, c_next, h_hat, c_hat
```
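The final state update can be sketched in pure Python for a single example, assuming the four gate pre-activations have already been computed and layer-normalized as in the loop above. The `layer_norm` helper here is a simplification: it omits the learnable gain and bias of nn.LayerNorm, but uses the same default epsilon (1e-5) and biased variance. The helper names are illustrative.

```python
import math
from typing import List, Tuple

Vector = List[float]


def sigmoid(v: float) -> float:
    return 1.0 / (1.0 + math.exp(-v))


def layer_norm(v: Vector, eps: float = 1e-5) -> Vector:
    """Layer normalization without learnable gain and bias (simplification)."""
    mean = sum(v) / len(v)
    var = sum((x - mean) ** 2 for x in v) / len(v)  # biased variance
    return [(x - mean) / math.sqrt(var + eps) for x in v]


def lstm_update(i: Vector, f: Vector, g: Vector, o: Vector, c: Vector) -> Tuple[Vector, Vector]:
    """c_t = σ(f) ⊙ c_{t-1} + σ(i) ⊙ tanh(g);  h_t = σ(o) ⊙ tanh(LN(c_t))"""
    c_next = [sigmoid(f_k) * c_k + sigmoid(i_k) * math.tanh(g_k)
              for i_k, f_k, g_k, c_k in zip(i, f, g, c)]
    h_next = [sigmoid(o_k) * math.tanh(n_k) for o_k, n_k in zip(o, layer_norm(c_next))]
    return h_next, c_next


# Hypothetical hidden size of 2, all gate pre-activations zero
zeros = [0.0, 0.0]
h_next, c_next = lstm_update(zeros, zeros, zeros, zeros, c=[1.0, -1.0])
print(c_next)  # [0.5, -0.5]
print(h_next)  # approximately [0.3808, -0.3808]
```

With zero gates every sigmoid is 0.5 and tanh(g) is 0, so the cell simply halves the previous cell state; layer normalization then rescales it before the output gate.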

## HyperLSTM module

```python
class HyperLSTM(Module):
```

Create a network of n_layers HyperLSTM layers.

```python
    def __init__(self, input_size: int, hidden_size: int, hyper_size: int, n_z: int, n_layers: int):
        super().__init__()
```

Store the sizes needed to initialize the state

```python
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.hyper_size = hyper_size
```

Create cells for each layer. Note that only the first layer gets the input directly. The rest of the layers get their input from the layer below.

```python
        self.cells = nn.ModuleList([HyperLSTMCell(input_size, hidden_size, hyper_size, n_z)] +
                                   [HyperLSTMCell(hidden_size, hidden_size, hyper_size, n_z) for _ in
                                    range(n_layers - 1)])
```
• x has shape [n_steps, batch_size, input_size] and
• state is a tuple of $h, c, \hat{h}, \hat{c}$. $h, c$ have shape [n_layers, batch_size, hidden_size] and $\hat{h}, \hat{c}$ have shape [n_layers, batch_size, hyper_size].
```python
    def forward(self, x: torch.Tensor,
                state: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = None):
        n_steps, batch_size = x.shape[:2]
```

Initialize the state with zeros if None

```python
        if state is None:
            h = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
            c = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
            h_hat = [x.new_zeros(batch_size, self.hyper_size) for _ in range(self.n_layers)]
            c_hat = [x.new_zeros(batch_size, self.hyper_size) for _ in range(self.n_layers)]
        else:
            (h, c, h_hat, c_hat) = state
```

Unstack the tensors to get the states of each layer

📝 You could work with the stacked tensors directly, but lists are easier to debug.

```python
            h, c = list(torch.unbind(h)), list(torch.unbind(c))
            h_hat, c_hat = list(torch.unbind(h_hat)), list(torch.unbind(c_hat))
```

Collect the outputs of the final layer at each step

```python
        out = []
        for t in range(n_steps):
```

Input to the first layer is the input itself

```python
            inp = x[t]
```

Loop through the layers

```python
            for layer in range(self.n_layers):
```

Compute the new state of the layer

```python
                h[layer], c[layer], h_hat[layer], c_hat[layer] = \
                    self.cells[layer](inp, h[layer], c[layer], h_hat[layer], c_hat[layer])
```

Input to the next layer is the hidden state of this layer

```python
                inp = h[layer]
```

Collect the output of the final layer

```python
            out.append(h[-1])
```

Stack the outputs and states

```python
        out = torch.stack(out)
        h = torch.stack(h)
        c = torch.stack(c)
        h_hat = torch.stack(h_hat)
        c_hat = torch.stack(c_hat)

        return out, (h, c, h_hat, c_hat)
```
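Stripped of the hyper-network details, the nested time and layer loop in forward can be sketched in pure Python. This is a hypothetical simplification: `toy_cell` stands in for HyperLSTMCell, keeps a single hidden vector instead of the four state tensors $(h, c, \hat{h}, \hat{c})$, and every layer has the same width.

```python
import math
from typing import Callable, List, Tuple

Vector = List[float]
# Hypothetical cell signature: (input, previous hidden state) -> new hidden state
Cell = Callable[[Vector, Vector], Vector]


def toy_cell(x: Vector, h: Vector) -> Vector:
    """Stand-in for HyperLSTMCell (hypothetical): h_t = tanh(x_t + h_{t-1})."""
    return [math.tanh(x_k + h_k) for x_k, h_k in zip(x, h)]


def run_stack(cells: List[Cell], xs: List[Vector], h: List[Vector]) -> Tuple[List[Vector], List[Vector]]:
    """Run a stack of cells over a sequence; h holds one state per layer and is updated in place."""
    out = []
    for x in xs:                      # loop over time steps
        inp = x                       # the first layer sees the raw input
        for layer, cell in enumerate(cells):
            h[layer] = cell(inp, h[layer])
            inp = h[layer]            # the next layer sees this layer's hidden state
        out.append(h[-1])             # keep only the final layer's output
    return out, h


# Two layers, hidden size 1, two time steps, zero initial state
out, h = run_stack([toy_cell, toy_cell], [[1.0], [0.0]], [[0.0], [0.0]])
print(len(out))  # 2
```

As in the module, only the top layer's hidden state is collected per step, while the returned per-layer states can be fed back in to continue the sequence.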