HyperNetworks - HyperLSTM

We have implemented the HyperLSTM, introduced in the paper HyperNetworks, in PyTorch with annotations. This blog post by David Ha gives a good explanation of HyperNetworks.

We have an experiment that trains a HyperLSTM to predict text on the Shakespeare dataset. Here's the link to the code: experiment.py


HyperNetworks use a smaller network to generate the weights of a larger network. There are two variants: static hyper-networks and dynamic hyper-networks. Static HyperNetworks have smaller networks that generate the weights (kernels) of a convolutional network. Dynamic HyperNetworks generate the parameters of a recurrent neural network for each step. This is an implementation of the latter.

Dynamic HyperNetworks

In an RNN, the parameters stay constant across steps. Dynamic HyperNetworks generate different parameters for each step. HyperLSTM has the structure of an LSTM, but the parameters of each step are changed by a smaller LSTM network.

In the basic form, a Dynamic HyperNetwork has a smaller recurrent network that generates a feature vector corresponding to each parameter tensor of the larger recurrent network. Let's say the larger network has some parameter $W_h$; the smaller network generates a feature vector $z_h$, and we dynamically compute $W_h$ as a linear transformation of $z_h$. For instance, $W_h = \langle W_{hz}, z_h \rangle$, where $W_{hz}$ is a 3-d tensor parameter and $\langle \cdot, \cdot \rangle$ is a tensor-vector multiplication. $z_h$ is usually a linear transformation of the output of the smaller recurrent network.
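As a minimal sketch of this basic (unscaled) form, assuming small illustrative sizes `n_h` and `n_z` that are not from the source, the tensor-vector product $\langle W_{hz}, z_h \rangle$ contracts the feature axis of the 3-d tensor with $z_h$ and yields an $N_h \times N_h$ matrix:

```python
import numpy as np

# Illustrative sizes (assumed, not from the source)
n_h, n_z = 4, 3
rng = np.random.default_rng(0)

# 3-d parameter tensor W_hz with shape [n_h, n_h, n_z]
w_hz = rng.standard_normal((n_h, n_h, n_z))
# Feature vector z_h produced by the smaller network at this step
z_h = rng.standard_normal(n_z)

# Tensor-vector product <W_hz, z_h>: contract over the feature axis
w_h = np.einsum('ijk,k->ij', w_hz, z_h)
print(w_h.shape)  # (4, 4)
```

A new $z_h$ at every step gives a new $W_h$ at every step, which is what makes the hypernetwork dynamic.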

Weight scaling instead of computing full weights

Large recurrent networks have large dynamically computed parameters. These are calculated using a linear transformation of the feature vector $z$, and this transformation requires an even larger weight tensor. That is, when $W_h$ has shape $N_h \times N_h$, $W_{hz}$ will be $N_h \times N_h \times N_z$.
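To make the size argument concrete, here is a quick count with illustrative sizes $N_h = 512$ and $N_z = 16$ (assumed for this example, not taken from the paper):

```python
n_h, n_z = 512, 16  # illustrative sizes, assumed

w_h_entries = n_h * n_h          # the dynamically computed matrix W_h
w_hz_entries = n_h * n_h * n_z   # the 3-d tensor W_hz that generates it

print(w_h_entries, w_hz_entries)  # 262144 4194304
```

The generator is $N_z$ times larger than the matrix it produces, and that is for a single weight matrix of a single gate.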

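To overcome this, we compute the weight parameters of the recurrent network by dynamically scaling each row of a matrix of the same size.

$$d(z) = W_{hz} z_h$$
$$W_h = \begin{pmatrix} d_0(z) W_{hd_0} \\ d_1(z) W_{hd_1} \\ \vdots \\ d_{N_h - 1}(z) W_{hd_{N_h - 1}} \end{pmatrix}$$

where $W_{hd}$ is an $N_h \times N_h$ parameter matrix, $W_{hd_k}$ is its $k$-th row, and $W_{hz}$ is now only an $N_h \times N_z$ matrix.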
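A minimal NumPy sketch of the row scaling, assuming illustrative sizes: $d(z)$ is a length-$N_h$ vector, and scaling row $k$ of $W_{hd}$ by $d_k(z)$ is the same as left-multiplying $W_{hd}$ by $\operatorname{diag}(d(z))$.

```python
import numpy as np

n_h, n_z = 5, 3  # illustrative sizes, assumed
rng = np.random.default_rng(1)

w_hz = rng.standard_normal((n_h, n_z))   # maps z_h to one scale per row
w_hd = rng.standard_normal((n_h, n_h))   # static matrix whose rows get scaled
z_h = rng.standard_normal(n_z)

d = w_hz @ z_h                           # d(z), shape [n_h]
w_h = d[:, None] * w_hd                  # row k of w_hd multiplied by d[k]

# Same thing written as a diagonal matrix product
assert np.allclose(w_h, np.diag(d) @ w_hd)
# Parameters needed: n_h * n_z + n_h * n_h instead of n_h * n_h * n_z
print(w_hz.size + w_hd.size)  # 40
```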

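We can further optimize this when we compute $W_h h$, since

$$W_h h = d(z) \odot (W_{hd} h)$$

where $\odot$ stands for element-wise multiplication. This way the scaled matrix $W_h$ never has to be built explicitly.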
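The identity behind this optimization can be checked numerically. The sketch assumes a batch of hidden states with shape [batch_size, N_h] and one scaling vector per sample; the einsum pattern `'ij,bj->bi'` applies a shared matrix to every row of a batch, as the cell code does.

```python
import numpy as np

n_h, batch_size = 6, 2  # illustrative sizes, assumed
rng = np.random.default_rng(2)

w_hd = rng.standard_normal((n_h, n_h))
d = rng.standard_normal((batch_size, n_h))   # one scaling vector per sample
h = rng.standard_normal((batch_size, n_h))

# Naive: build a row-scaled matrix per sample, then multiply
naive = np.einsum('bij,bj->bi', d[:, :, None] * w_hd, h)
# Optimized: multiply by the shared matrix once, then scale element-wise
fast = d * np.einsum('ij,bj->bi', w_hd, h)

assert np.allclose(naive, fast)
```

The naive version materializes `batch_size` separate $N_h \times N_h$ matrices; the optimized one only ever stores the shared $W_{hd}$.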

from typing import Optional, Tuple

import torch
from torch import nn

from labml_helpers.module import Module
from labml_nn.lstm import LSTMCell

HyperLSTM Cell

For HyperLSTM the smaller network and the larger network both have the LSTM structure. This is defined in Appendix A.2.2 in the paper.

class HyperLSTMCell(Module):

input_size is the size of the input $x_t$, hidden_size is the size of the LSTM, and hyper_size is the size of the smaller LSTM that alters the weights of the larger outer LSTM. n_z is the size of the feature vectors used to alter the LSTM weights.

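We use the output of the smaller LSTM to compute $z_h^{i,f,g,o}$, $z_x^{i,f,g,o}$ and $z_b^{i,f,g,o}$ using linear transformations. We calculate $d_h^{i,f,g,o}(z_h^{i,f,g,o})$, $d_x^{i,f,g,o}(z_x^{i,f,g,o})$ and $d_b^{i,f,g,o}(z_b^{i,f,g,o})$ from these, using linear transformations again. $d_h$ and $d_x$ are then used to scale the rows of the weight matrices of the main LSTM, and $d_b$ gives its biases.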

📝 Since $z$ and $d$ are computed by two sequential linear transformations, they can be combined into a single linear transformation. However, we've implemented them separately so that the code matches the description in the paper.
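To illustrate the note, here is a NumPy sketch, assuming arbitrary small sizes and the same bias layout as `z_h` and `d_h` in the code ($z_h$ has a bias, $d_h$ does not). Composing the two maps gives one affine map with weight $W_d W_z$ and bias $W_d b_z$:

```python
import numpy as np

hyper_size, n_z, hidden_size = 8, 4, 6  # illustrative sizes, assumed
rng = np.random.default_rng(3)

# First map: z = W_z h_hat + b_z (like nn.Linear(hyper_size, n_z))
w_z = rng.standard_normal((n_z, hyper_size))
b_z = rng.standard_normal(n_z)
# Second map: d = W_d z (like nn.Linear(n_z, hidden_size, bias=False))
w_d = rng.standard_normal((hidden_size, n_z))

h_hat = rng.standard_normal(hyper_size)
two_step = w_d @ (w_z @ h_hat + b_z)

# Fold both maps into one affine transformation
w_combined = w_d @ w_z        # shape [hidden_size, hyper_size]
b_combined = w_d @ b_z        # shape [hidden_size]
one_step = w_combined @ h_hat + b_combined

assert np.allclose(two_step, one_step)
```

Folding is exact for a given set of weights. Note that `w_combined` has rank at most `n_z`; training an unconstrained `hidden_size × hyper_size` matrix instead would drop that low-rank structure.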

    def __init__(self, input_size: int, hidden_size: int, hyper_size: int, n_z: int):
        super().__init__()

The input to the hyperLSTM is

$$\hat{x}_t = \begin{pmatrix} h_{t-1} \\ x_t \end{pmatrix}$$

where $x_t$ is the input and $h_{t-1}$ is the output of the outer LSTM at the previous step. So the input size is hidden_size + input_size.

The output of the hyperLSTM is $\hat{h}_t$ and $\hat{c}_t$.

        self.hyper = LSTMCell(hidden_size + input_size, hyper_size, layer_norm=True)

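We compute $z_h^{i,f,g,o} = lin_h^{i,f,g,o}(\hat{h}_t)$.

🤔 In the paper it was specified as $z_h^{i,f,g,o} = lin_h^{i,f,g,o}(\hat{h}_{t-1})$. I feel that it's a typo.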

        self.z_h = nn.Linear(hyper_size, 4 * n_z)

        # z_x^{i,f,g,o} = lin_x^{i,f,g,o}(h_hat_t)
        self.z_x = nn.Linear(hyper_size, 4 * n_z)

        # z_b^{i,f,g,o} = lin_b^{i,f,g,o}(h_hat_t)
        self.z_b = nn.Linear(hyper_size, 4 * n_z, bias=False)

        d_h = [nn.Linear(n_z, hidden_size, bias=False) for _ in range(4)]
        self.d_h = nn.ModuleList(d_h)

        d_x = [nn.Linear(n_z, hidden_size, bias=False) for _ in range(4)]
        self.d_x = nn.ModuleList(d_x)

        d_b = [nn.Linear(n_z, hidden_size) for _ in range(4)]
        self.d_b = nn.ModuleList(d_b)

The weight matrices $W_h^{i,f,g,o}$

        self.w_h = nn.ParameterList([nn.Parameter(torch.zeros(hidden_size, hidden_size)) for _ in range(4)])

The weight matrices $W_x^{i,f,g,o}$

        self.w_x = nn.ParameterList([nn.Parameter(torch.zeros(hidden_size, input_size)) for _ in range(4)])

Layer normalization

        self.layer_norm = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(4)])
        self.layer_norm_c = nn.LayerNorm(hidden_size)

    def forward(self, x: torch.Tensor,
                h: torch.Tensor, c: torch.Tensor,
                h_hat: torch.Tensor, c_hat: torch.Tensor):

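        # x_hat_t = [h_{t-1}; x_t]
        x_hat = torch.cat((h, x), dim=-1)

        # h_hat_t, c_hat_t = LSTM(x_hat_t, h_hat_{t-1}, c_hat_{t-1})
        h_hat, c_hat = self.hyper(x_hat, h_hat, c_hat)

        # z_h^{i,f,g,o} = lin_h^{i,f,g,o}(h_hat_t), split into one chunk per gate
        z_h = self.z_h(h_hat).chunk(4, dim=-1)

        # z_x^{i,f,g,o} = lin_x^{i,f,g,o}(h_hat_t)
        z_x = self.z_x(h_hat).chunk(4, dim=-1)

        # z_b^{i,f,g,o} = lin_b^{i,f,g,o}(h_hat_t)
        z_b = self.z_b(h_hat).chunk(4, dim=-1)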
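One linear layer with `4 * n_z` outputs followed by `chunk(4, dim=-1)` is equivalent to four separate layers with `n_z` outputs each, one per gate. A NumPy sketch of that equivalence, with `np.split` standing in for `torch.chunk` and illustrative sizes assumed:

```python
import numpy as np

batch_size, hyper_size, n_z = 2, 8, 3  # illustrative sizes, assumed
rng = np.random.default_rng(4)

h_hat = rng.standard_normal((batch_size, hyper_size))
# Weight of a single Linear(hyper_size, 4 * n_z), stored as [out, in]
w = rng.standard_normal((4 * n_z, hyper_size))
b = rng.standard_normal(4 * n_z)

# One matrix multiply, then split the last axis into 4 equal chunks (i, f, g, o)
z_i, z_f, z_g, z_o = np.split(h_hat @ w.T + b, 4, axis=-1)

# Equivalent: gate k uses rows k*n_z:(k+1)*n_z of the weight and bias
for k, z in enumerate((z_i, z_f, z_g, z_o)):
    rows = slice(k * n_z, (k + 1) * n_z)
    assert np.allclose(z, h_hat @ w[rows].T + b[rows])
print(z_i.shape)  # (2, 3)
```

Using one wide layer lets all four gates share a single matrix multiply.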

We calculate $i$, $f$, $g$ and $o$ in a loop

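        ifgo = []
        for i in range(4):

            # d_h^{i,f,g,o}(z_h^{i,f,g,o}): per-row scales for W_h
            d_h = self.d_h[i](z_h[i])

            # d_x^{i,f,g,o}(z_x^{i,f,g,o}): per-row scales for W_x
            d_x = self.d_x[i](z_x[i])

            # d_h ⊙ (W_h h_{t-1}) + d_x ⊙ (W_x x_t) + d_b(z_b)
            y = d_h * torch.einsum('ij,bj->bi', self.w_h[i], h) + \
                d_x * torch.einsum('ij,bj->bi', self.w_x[i], x) + \
                self.d_b[i](z_b[i])

            # Layer normalization of each gate's pre-activation
            ifgo.append(self.layer_norm[i](y))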

        i, f, g, o = ifgo

        # c_t = sigmoid(f) ⊙ c_{t-1} + sigmoid(i) ⊙ tanh(g)
        c_next = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)

        # h_t = sigmoid(o) ⊙ tanh(LN(c_t))
        h_next = torch.sigmoid(o) * torch.tanh(self.layer_norm_c(c_next))

        return h_next, c_next, h_hat, c_hat
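The whole cell step can be summarized as a self-contained NumPy sketch. It is a simplification under stated assumptions: randomly initialized weights instead of trained ones, a plain LSTM for the hyper network (the code above uses labml's `LSTMCell` with `layer_norm=True`), and layer normalization without learnable gain and bias. The helpers `lin`, `layer_norm`, `lstm_step` and `hyper_lstm_step` are hypothetical names for this sketch only.

```python
import numpy as np

rng = np.random.default_rng(5)
input_size, hidden_size, hyper_size, n_z, batch = 3, 6, 4, 2, 2  # assumed sizes


def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))


def layer_norm(v, eps=1e-5):
    # Normalize over the last axis; no learnable gain or bias in this sketch
    return (v - v.mean(-1, keepdims=True)) / np.sqrt(v.var(-1, keepdims=True) + eps)


def lin(n_in, n_out, bias=True):
    # Random affine map standing in for nn.Linear(n_in, n_out, bias)
    w = rng.standard_normal((n_out, n_in)) * 0.1
    b = rng.standard_normal(n_out) * 0.1 if bias else np.zeros(n_out)
    return lambda v: v @ w.T + b


def lstm_step(x, h, c, gates):
    # Plain LSTM cell used here as the smaller hyper network
    i, f, g, o = np.split(gates(np.concatenate((x, h), -1)), 4, -1)
    c = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
    return sigmoid(o) * np.tanh(c), c


hyper_gates = lin(hidden_size + input_size + hyper_size, 4 * hyper_size)
z_h, z_x = lin(hyper_size, 4 * n_z), lin(hyper_size, 4 * n_z)
z_b = lin(hyper_size, 4 * n_z, bias=False)
d_h = [lin(n_z, hidden_size, bias=False) for _ in range(4)]
d_x = [lin(n_z, hidden_size, bias=False) for _ in range(4)]
d_b = [lin(n_z, hidden_size) for _ in range(4)]
w_h = [rng.standard_normal((hidden_size, hidden_size)) * 0.1 for _ in range(4)]
w_x = [rng.standard_normal((hidden_size, input_size)) * 0.1 for _ in range(4)]


def hyper_lstm_step(x, h, c, h_hat, c_hat):
    # The hyper network sees the previous outer hidden state and the input
    x_hat = np.concatenate((h, x), -1)
    h_hat, c_hat = lstm_step(x_hat, h_hat, c_hat, hyper_gates)
    zh, zx, zb = [np.split(fn(h_hat), 4, -1) for fn in (z_h, z_x, z_b)]
    ifgo = []
    for k in range(4):
        # Row-scaled recurrent and input weights plus a generated bias
        y = (d_h[k](zh[k]) * (h @ w_h[k].T)
             + d_x[k](zx[k]) * (x @ w_x[k].T)
             + d_b[k](zb[k]))
        ifgo.append(layer_norm(y))
    i, f, g, o = ifgo
    c_next = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
    h_next = sigmoid(o) * np.tanh(layer_norm(c_next))
    return h_next, c_next, h_hat, c_hat


x = rng.standard_normal((batch, input_size))
state = [np.zeros((batch, n)) for n in (hidden_size, hidden_size, hyper_size, hyper_size)]
h, c, h_hat, c_hat = hyper_lstm_step(x, *state)
print(h.shape, c.shape, h_hat.shape, c_hat.shape)  # (2, 6) (2, 6) (2, 4) (2, 4)
```

`h @ w_h[k].T` computes the same product as `torch.einsum('ij,bj->bi', self.w_h[i], h)` in the cell above.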

HyperLSTM module

class HyperLSTM(Module):

Create a network of n_layers of HyperLSTM.

    def __init__(self, input_size: int, hidden_size: int, hyper_size: int, n_z: int, n_layers: int):
        super().__init__()

Store sizes to initialize state

        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.hyper_size = hyper_size

Create cells for each layer. Note that only the first layer gets the input directly; the rest of the layers get their input from the layer below.

        self.cells = nn.ModuleList([HyperLSTMCell(input_size, hidden_size, hyper_size, n_z)] +
                                   [HyperLSTMCell(hidden_size, hidden_size, hyper_size, n_z) for _ in
                                    range(n_layers - 1)])
  • x has shape [n_steps, batch_size, input_size] and
  • state is a tuple of $(h, c, \hat{h}, \hat{c})$. $h$ and $c$ have shape [n_layers, batch_size, hidden_size], and $\hat{h}$ and $\hat{c}$ have shape [n_layers, batch_size, hyper_size].
    def forward(self, x: torch.Tensor,
                state: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = None):
        n_steps, batch_size = x.shape[:2]

Initialize the state with zeros if None

        if state is None:
            h = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
            c = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
            h_hat = [x.new_zeros(batch_size, self.hyper_size) for _ in range(self.n_layers)]
            c_hat = [x.new_zeros(batch_size, self.hyper_size) for _ in range(self.n_layers)]

        else:
            (h, c, h_hat, c_hat) = state

Unstack the tensors to get the states of each layer

📝 You can just work with the tensor itself, but this is easier to debug

            h, c = list(torch.unbind(h)), list(torch.unbind(c))
            h_hat, c_hat = list(torch.unbind(h_hat)), list(torch.unbind(c_hat))

Collect the outputs of the final layer at each step

        out = []
        for t in range(n_steps):

Input to the first layer is the input itself

            inp = x[t]

Loop through the layers

            for layer in range(self.n_layers):

Compute the next state of the layer

                h[layer], c[layer], h_hat[layer], c_hat[layer] = \
                    self.cells[layer](inp, h[layer], c[layer], h_hat[layer], c_hat[layer])

Input to the next layer is the hidden state $h$ of this layer

                inp = h[layer]

Collect the output of the final layer

            out.append(h[-1])

Stack the outputs and states

        out = torch.stack(out)
        h = torch.stack(h)
        c = torch.stack(c)
        h_hat = torch.stack(h_hat)
        c_hat = torch.stack(c_hat)

        return out, (h, c, h_hat, c_hat)
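The layer and time loop of `HyperLSTM.forward` can be sketched in NumPy to check the shapes it produces. As an assumption for brevity, `toy_cell` is a hypothetical stand-in that only mimics the call signature and output shapes of `HyperLSTMCell`, not its computation; `np.stack` plays the role of `torch.stack`.

```python
import numpy as np

n_steps, batch_size, input_size = 5, 2, 3     # illustrative sizes, assumed
hidden_size, hyper_size, n_layers = 6, 4, 3
rng = np.random.default_rng(6)

# One random projection per layer; the first layer sees input_size features
proj = [rng.standard_normal((input_size if layer == 0 else hidden_size, hidden_size))
        for layer in range(n_layers)]


def toy_cell(layer, inp, h, c, h_hat, c_hat):
    # Hypothetical stand-in with the same output shapes as HyperLSTMCell
    h_next = np.tanh(inp @ proj[layer] + h)
    return h_next, c + h_next, h_hat, c_hat


x = rng.standard_normal((n_steps, batch_size, input_size))
# Zero state, one entry per layer (the list form the forward pass works with)
h = [np.zeros((batch_size, hidden_size)) for _ in range(n_layers)]
c = [np.zeros((batch_size, hidden_size)) for _ in range(n_layers)]
h_hat = [np.zeros((batch_size, hyper_size)) for _ in range(n_layers)]
c_hat = [np.zeros((batch_size, hyper_size)) for _ in range(n_layers)]

out = []
for t in range(n_steps):
    inp = x[t]
    for layer in range(n_layers):
        h[layer], c[layer], h_hat[layer], c_hat[layer] = \
            toy_cell(layer, inp, h[layer], c[layer], h_hat[layer], c_hat[layer])
        inp = h[layer]          # the next layer consumes this layer's hidden state
    out.append(h[-1])           # keep only the top layer's output

out, h, h_hat = np.stack(out), np.stack(h), np.stack(h_hat)
print(out.shape, h.shape, h_hat.shape)  # (5, 2, 6) (3, 2, 6) (3, 2, 4)
```

The returned state tensors carry a leading `n_layers` axis, which is why the forward pass unbinds them along dimension 0 before the loop.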