We have implemented HyperLSTM, introduced in the paper HyperNetworks, in PyTorch with annotations. This blog post by David Ha gives a good explanation of HyperNetworks.
We have an experiment that trains a HyperLSTM to predict text on the Shakespeare dataset.
Here's the link to the code: experiment.py
HyperNetworks use a smaller network to generate the weights of a larger network. There are two variants: static HyperNetworks and dynamic HyperNetworks. Static HyperNetworks are smaller networks that generate the weights (kernels) of a convolutional network. Dynamic HyperNetworks generate the parameters of a recurrent neural network at each step. This is an implementation of the latter.
In an RNN, the parameters stay the same at every step. Dynamic HyperNetworks generate different parameters for each step. HyperLSTM has the structure of an LSTM, but the parameters at each step are changed by a smaller LSTM network.
In the basic form, a Dynamic HyperNetwork has a smaller recurrent network that generates a feature vector corresponding to each parameter tensor of the larger recurrent network. Say the larger network has some parameter $\color{cyan}{W_h}$; the smaller network generates a feature vector $z_h$, and we dynamically compute $\color{cyan}{W_h}$ as a linear transformation of $z_h$. For instance, $\color{cyan}{W_h} = \langle W_{hz}, z_h \rangle$, where $W_{hz}$ is a 3-d tensor parameter and $\langle \cdot \rangle$ is a tensor-vector multiplication. $z_h$ is usually a linear transformation of the output of the smaller recurrent network.
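Here is a minimal sketch of this basic form (the sizes and the tensors W_hz and z_h below are made up for illustration; they are not part of the implementation that follows):
import torch

n_h, n_z = 8, 4                             # illustrative sizes only
W_hz = torch.randn(n_h, n_h, n_z)           # 3-d tensor parameter
z_h = torch.randn(n_z)                      # feature vector from the smaller network

# tensor-vector multiplication: contract the last axis of W_hz with z_h
W_h = torch.einsum('ijk,k->ij', W_hz, z_h)  # dynamically generated weight, shape [n_h, n_h]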
Large recurrent networks have large dynamically computed parameters. These are calculated using a linear transformation of the feature vector $z$, and this transformation requires an even larger weight tensor. That is, when $\color{cyan}{W_h}$ has shape $N_h \times N_h$, $W_{hz}$ will be $N_h \times N_h \times N_z$. For example, with $N_h = 512$ and $N_z = 64$ that is $512 \times 512 \times 64 \approx 16.8$ million parameters for a single weight matrix.
To overcome this, we compute the weight parameters of the recurrent network by dynamically scaling each row of a matrix of the same size,
$$\color{cyan}{W_h} = \begin{pmatrix} d_1(z_h) W_{hd_1} \\ d_2(z_h) W_{hd_2} \\ \vdots \\ d_{N_h}(z_h) W_{hd_{N_h}} \end{pmatrix}$$
where $W_{hd}$ is a $N_h \times N_h$ parameter matrix, $W_{hd_i}$ are its rows, and $d(z_h)$ is a linear transformation of $z_h$ to a vector of size $N_h$.
We can further optimize this when we compute $\color{cyan}{W_h} h$, as
$$d(z_h) \odot (W_{hd} h)$$
where $\odot$ stands for element-wise multiplication.
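A small sketch of this optimization (the sizes and the name W_dz below are made up for illustration): scaling each row of a static matrix by d(z_h) gives the same result as multiplying by the full dynamically generated matrix, without ever materializing it.
import torch

n_h, n_z = 8, 4                   # illustrative sizes only
W_hd = torch.randn(n_h, n_h)      # static parameter matrix
W_dz = torch.randn(n_h, n_z)      # maps z_h to the row scales d(z_h); name is illustrative
z_h = torch.randn(n_z)
h = torch.randn(n_h)

d = W_dz @ z_h                    # one scale per row of W_hd
W_h = d[:, None] * W_hd           # explicit dynamically scaled weight matrix

# the optimized form gives the same result without building W_h
assert torch.allclose(W_h @ h, d * (W_hd @ h), atol=1e-5)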
from typing import Optional, Tuple

import torch
from torch import nn

from labml_helpers.module import Module
from labml_nn.lstm import LSTMCell
For HyperLSTM, both the smaller network and the larger network have the LSTM structure. This is defined in Appendix A.2.2 of the paper.
class HyperLSTMCell(Module):
input_size is the size of the input $x_t$, hidden_size is the size of the LSTM, and hyper_size is the size of the smaller LSTM that alters the weights of the larger outer LSTM. n_z is the size of the feature vectors used to alter the LSTM weights.
We use the output of the smaller LSTM to compute $z_h^{i,f,g,o}$, $z_x^{i,f,g,o}$ and $z_b^{i,f,g,o}$ using linear transformations. We calculate $d_h^{i,f,g,o}(z_h^{i,f,g,o})$, $d_x^{i,f,g,o}(z_x^{i,f,g,o})$, and $d_b^{i,f,g,o}(z_b^{i,f,g,o})$ from these, using linear transformations again. These are then used to scale the rows of weight and bias tensors of the main LSTM.
📝 Since the computations of $z$ and $d$ are two sequential linear transformations, they can be combined into a single linear transformation. However, we've implemented them separately so that it matches the description in the paper.
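As a side note, here's a small sketch (with arbitrary sizes, separate from the class below) of how two such stacked linear layers could be fused into a single one:
import torch
from torch import nn

hyper_size, n_z, hidden_size = 16, 4, 32     # arbitrary sizes for illustration
z = nn.Linear(hyper_size, n_z)               # computes z from the hyper LSTM output
d = nn.Linear(n_z, hidden_size, bias=False)  # computes d from z

fused = nn.Linear(hyper_size, hidden_size)   # a single equivalent linear transformation
with torch.no_grad():
    fused.weight.copy_(d.weight @ z.weight)  # combined weight
    fused.bias.copy_(d.weight @ z.bias)      # combined bias

h_hat = torch.randn(3, hyper_size)
assert torch.allclose(d(z(h_hat)), fused(h_hat), atol=1e-5)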
    def __init__(self, input_size: int, hidden_size: int, hyper_size: int, n_z: int):
        super().__init__()
The input to the hyperLSTM is
$$\hat{x}_t = \begin{pmatrix} h_{t-1} \\ x_t \end{pmatrix}$$
where $x_t$ is the input and $h_{t-1}$ is the output of the outer LSTM at the previous step. So the input size is hidden_size + input_size.
The output of the hyperLSTM is $\hat{h}_t$ and $\hat{c}_t$.
        self.hyper = LSTMCell(hidden_size + input_size, hyper_size, layer_norm=True)
🤔 In the paper the $z$ vectors were specified in terms of $\hat{h}_{t-1}$, the hyper LSTM output at the previous step, rather than $\hat{h}_t$; I feel that it's a typo.
Linear transformations to compute $z_h^{i,f,g,o}$, $z_x^{i,f,g,o}$ and $z_b^{i,f,g,o}$ from $\hat{h}_t$
        self.z_h = nn.Linear(hyper_size, 4 * n_z)
        self.z_x = nn.Linear(hyper_size, 4 * n_z)
        self.z_b = nn.Linear(hyper_size, 4 * n_z, bias=False)
Linear transformations to compute $d_h^{i,f,g,o}$, $d_x^{i,f,g,o}$ and $d_b^{i,f,g,o}$ from the $z$ vectors, one per gate
        d_h = [nn.Linear(n_z, hidden_size, bias=False) for _ in range(4)]
        self.d_h = nn.ModuleList(d_h)
        d_x = [nn.Linear(n_z, hidden_size, bias=False) for _ in range(4)]
        self.d_x = nn.ModuleList(d_x)
        d_b = [nn.Linear(n_z, hidden_size) for _ in range(4)]
        self.d_b = nn.ModuleList(d_b)
The weight matrices $W_h^{i,f,g,o}$
        self.w_h = nn.ParameterList([nn.Parameter(torch.zeros(hidden_size, hidden_size)) for _ in range(4)])
The weight matrices $W_x^{i,f,g,o}$
        self.w_x = nn.ParameterList([nn.Parameter(torch.zeros(hidden_size, input_size)) for _ in range(4)])
Layer normalization
        self.layer_norm = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(4)])
        self.layer_norm_c = nn.LayerNorm(hidden_size)
    def __call__(self, x: torch.Tensor,
                 h: torch.Tensor, c: torch.Tensor,
                 h_hat: torch.Tensor, c_hat: torch.Tensor):
Concatenate $h_{t-1}$ and $x_t$ to get $\hat{x}_t$ and run one step of the smaller (hyper) LSTM
        x_hat = torch.cat((h, x), dim=-1)
        h_hat, c_hat = self.hyper(x_hat, h_hat, c_hat)
Compute $z_h^{i,f,g,o}$, $z_x^{i,f,g,o}$ and $z_b^{i,f,g,o}$ from $\hat{h}_t$, splitting each into the four gate components
        z_h = self.z_h(h_hat).chunk(4, dim=-1)
        z_x = self.z_x(h_hat).chunk(4, dim=-1)
        z_b = self.z_b(h_hat).chunk(4, dim=-1)
We calculate $i$, $f$, $g$ and $o$ in a loop
        ifgo = []
        for i in range(4):
            d_h = self.d_h[i](z_h[i])
            d_x = self.d_x[i](z_x[i])

            y = d_h * torch.einsum('ij,bj->bi', self.w_h[i], h) + \
                d_x * torch.einsum('ij,bj->bi', self.w_x[i], x) + \
                self.d_b[i](z_b[i])

            ifgo.append(self.layer_norm[i](y))
        i, f, g, o = ifgo

        c_next = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
        h_next = torch.sigmoid(o) * torch.tanh(self.layer_norm_c(c_next))

        return h_next, c_next, h_hat, c_hat
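A quick usage sketch of HyperLSTMCell (the sizes below are arbitrary), running a single step with zero-initialized states:
batch_size, input_size, hidden_size, hyper_size, n_z = 2, 10, 16, 8, 4
cell = HyperLSTMCell(input_size, hidden_size, hyper_size, n_z)

x = torch.randn(batch_size, input_size)
h = torch.zeros(batch_size, hidden_size)      # outer LSTM state
c = torch.zeros(batch_size, hidden_size)
h_hat = torch.zeros(batch_size, hyper_size)   # hyper LSTM state
c_hat = torch.zeros(batch_size, hyper_size)

h, c, h_hat, c_hat = cell(x, h, c, h_hat, c_hat)   # one step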
class HyperLSTM(Module):
Create a network of n_layers of HyperLSTM.
    def __init__(self, input_size: int, hidden_size: int, hyper_size: int, n_z: int, n_layers: int):
        super().__init__()
Store sizes to initialize state
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.hyper_size = hyper_size
Create cells for each layer. Note that only the first layer gets the input directly; the rest of the layers get their input from the layer below.
        self.cells = nn.ModuleList([HyperLSTMCell(input_size, hidden_size, hyper_size, n_z)] +
                                   [HyperLSTMCell(hidden_size, hidden_size, hyper_size, n_z) for _ in
                                    range(n_layers - 1)])
x has shape [n_steps, batch_size, input_size] and state is a tuple of $h, c, \hat{h}, \hat{c}$. $h, c$ have shape [n_layers, batch_size, hidden_size] and $\hat{h}, \hat{c}$ have shape [n_layers, batch_size, hyper_size].
    def __call__(self, x: torch.Tensor,
                 state: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = None):
        n_steps, batch_size = x.shape[:2]
Initialize the state with zeros if None
        if state is None:
            h = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
            c = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
            h_hat = [x.new_zeros(batch_size, self.hyper_size) for _ in range(self.n_layers)]
            c_hat = [x.new_zeros(batch_size, self.hyper_size) for _ in range(self.n_layers)]
        else:
            (h, c, h_hat, c_hat) = state
Unstack the tensors to get the states of each layer.
📝 You could work with the stacked tensors directly, but this is easier to debug.
            h, c = list(torch.unbind(h)), list(torch.unbind(c))
            h_hat, c_hat = list(torch.unbind(h_hat)), list(torch.unbind(c_hat))
Collect the outputs of the final layer at each step
        out = []
        for t in range(n_steps):
Input to the first layer is the input itself
            inp = x[t]
Loop through the layers
            for layer in range(self.n_layers):
Get the state of the layer
                h[layer], c[layer], h_hat[layer], c_hat[layer] = \
                    self.cells[layer](inp, h[layer], c[layer], h_hat[layer], c_hat[layer])
Input to the next layer is the state of this layer
                inp = h[layer]
Collect the output $h$ of the final layer
            out.append(h[-1])
Stack the outputs and states
        out = torch.stack(out)
        h = torch.stack(h)
        c = torch.stack(c)
        h_hat = torch.stack(h_hat)
        c_hat = torch.stack(c_hat)

        return out, (h, c, h_hat, c_hat)
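And a quick usage sketch of the full HyperLSTM (arbitrary sizes), running it over a random sequence and feeding the returned state back in for the next chunk:
n_steps, batch_size, input_size, hidden_size, hyper_size, n_z, n_layers = 5, 2, 10, 16, 8, 4, 2
model = HyperLSTM(input_size, hidden_size, hyper_size, n_z, n_layers)

x = torch.randn(n_steps, batch_size, input_size)
out, (h, c, h_hat, c_hat) = model(x)               # state defaults to zeros

assert out.shape == (n_steps, batch_size, hidden_size)
assert h.shape == (n_layers, batch_size, hidden_size)

# the returned state can be passed back in for the next chunk of a long sequence
out2, state2 = model(x, (h, c, h_hat, c_hat))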