We have implemented HyperLSTM, introduced in the paper HyperNetworks, in PyTorch with annotations. This blog post by David Ha gives a good explanation of HyperNetworks.
We have an experiment that trains a HyperLSTM to predict text on the Shakespeare dataset. Here's the link to code:
HyperNetworks use a smaller network to generate weights of a larger network. There are two variants: static hyper-networks and dynamic hyper-networks. Static HyperNetworks have smaller networks that generate weights (kernels) of a convolutional network. Dynamic HyperNetworks generate parameters of a recurrent neural network for each step. This is an implementation of the latter.
In an RNN the parameters stay constant for each step. Dynamic HyperNetworks generate different parameters for each step. HyperLSTM has the structure of an LSTM but the parameters of each step are changed by a smaller LSTM network.
In the basic form, a Dynamic HyperNetwork has a smaller recurrent network that generates a feature vector corresponding to each parameter tensor of the larger recurrent network. Let's say the larger network has some parameter $W_h$; the smaller network generates a feature vector $z_h$ and we dynamically compute $W_h$ as a linear transformation of $z_h$. For instance $W_h = \langle W_{hz}, z_h \rangle$ where $W_{hz}$ is a 3-d tensor parameter and $\langle \cdot \rangle$ is a tensor-vector multiplication. $z_h$ is usually a linear transformation of the output of the smaller recurrent network.
Large recurrent networks have large dynamically computed parameters. These are calculated using a linear transformation of the feature vector $z$, and this transformation requires an even larger weight tensor. That is, when $W_h$ has shape $N_h \times N_h$, $W_{hz}$ will be $N_h \times N_h \times N_z$.
To overcome this, we compute the weight parameters of the recurrent network by dynamically scaling each row of a matrix of the same size.
$$d(z) = W_{hz} z_h$$
$$W_h = \begin{pmatrix} d_0(z) W_{hd_0} \\ d_1(z) W_{hd_1} \\ \vdots \\ d_{N_h}(z) W_{hd_{N_h}} \end{pmatrix}$$
where $W_{hd}$ is an $N_h \times N_h$ parameter matrix.
We can further optimize this when we compute $W_h h$, as $d(z) \odot (W_{hd} h)$ where $\odot$ stands for element-wise multiplication.
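As a quick numerical check of the row-scaling idea, the sketch below compares the parameter counts of the full 3-d tensor and the row-scaling variant, and confirms that scaling the rows of the matrix first gives the same result as scaling the matrix-vector product afterwards. The sizes and variable names are arbitrary choices for illustration; they are not taken from the paper or from the implementation.

```python
import torch

torch.manual_seed(0)

# Illustrative sizes (assumptions, chosen only for this example)
n_h, n_z, batch = 8, 4, 3

# Parameter counts for producing an n_h x n_h weight matrix from a feature vector z
full_tensor_params = n_h * n_h * n_z        # 3-d tensor applied to z
row_scaling_params = n_h * n_h + n_h * n_z  # fixed matrix + projection of z to row scales
print(full_tensor_params, row_scaling_params)  # 256 96

w = torch.randn(n_h, n_h)    # fixed weight matrix whose rows get scaled
d = torch.randn(batch, n_h)  # per-sample row scales
h = torch.randn(batch, n_h)  # per-sample hidden state

# Option 1: build a scaled weight matrix per sample, then multiply
w_scaled = d.unsqueeze(-1) * w  # [batch, n_h, n_h]
slow = torch.einsum('bij,bj->bi', w_scaled, h)

# Option 2: multiply once with the shared matrix, then scale element-wise
fast = d * torch.einsum('ij,bj->bi', w, h)

same = torch.allclose(slow, fast, atol=1e-5)
```

The second option never materializes a per-sample weight matrix, which is why the element-wise form is the one worth using in a forward pass.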
from typing import Optional, Tuple

import torch
from torch import nn

from labml_helpers.module import Module
from labml_nn.lstm import LSTMCell
For HyperLSTM the smaller network and the larger network both have the LSTM structure. This is defined in Appendix A.2.2 in the paper.
`input_size` is the size of the input $x_t$,
`hidden_size` is the size of the LSTM, and
`hyper_size` is the size of the smaller LSTM that alters the weights of the larger outer LSTM.
`n_z` is the size of the feature vectors used to alter the LSTM weights.
We use the output of the smaller LSTM to compute $z_h^{i,f,g,o}$, $z_x^{i,f,g,o}$ and $z_b^{i,f,g,o}$ using linear transformations. We calculate $d_h^{i,f,g,o}(z_h^{i,f,g,o})$, $d_x^{i,f,g,o}(z_x^{i,f,g,o})$, and $d_b^{i,f,g,o}(z_b^{i,f,g,o})$ from these, using linear transformations again. These are then used to scale the rows of the weight and bias tensors of the main LSTM.
📝 Since the computations of $z$ and $d$ are two sequential linear transformations, they can be combined into a single linear transformation. However, we've implemented them separately so that the code matches the description in the paper.
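To make the note about combining the two transformations concrete, here is a minimal sketch that folds a linear layer with bias followed by a bias-free linear layer into one layer and checks that the outputs agree. The layer names `to_z`, `to_d`, and `combined` and the sizes are hypothetical, chosen only for this example.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Illustrative sizes (assumptions)
hyper_size, n_z, hidden_size, batch = 6, 4, 5, 3

to_z = nn.Linear(hyper_size, n_z)               # first transformation (with bias)
to_d = nn.Linear(n_z, hidden_size, bias=False)  # second transformation (no bias)

# Fold both into a single layer: weight = W_d W_z, bias = W_d b_z
combined = nn.Linear(hyper_size, hidden_size)
with torch.no_grad():
    combined.weight.copy_(to_d.weight @ to_z.weight)
    combined.bias.copy_(to_d.weight @ to_z.bias)

h_hat = torch.randn(batch, hyper_size)
two_step = to_d(to_z(h_hat))
one_step = combined(h_hat)
matches = torch.allclose(two_step, one_step, atol=1e-5)
```

Note that the folded layer holds `hyper_size * hidden_size` weights, while the two-step form holds `hyper_size * n_z + n_z * hidden_size`, so the separate form can be the smaller one when `n_z` is small.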
    def __init__(self, input_size: int, hidden_size: int, hyper_size: int, n_z: int):
The input to the hyperLSTM is $\hat{x}_t = \begin{pmatrix} h_{t-1} \\ x_t \end{pmatrix}$ where $x_t$ is the input and $h_{t-1}$ is the output of the outer LSTM at the previous step. So the input size is `hidden_size + input_size`.
The output of the hyperLSTM is $\hat{h}_t$ and $\hat{c}_t$.
        self.hyper = LSTMCell(hidden_size + input_size, hyper_size, layer_norm=True)
🤔 In the paper it was specified as $z_h^{i,f,g,o} = lin_h^{i,f,g,o}(\hat{h}_{t-1})$. I feel that it's a typo.
        self.z_h = nn.Linear(hyper_size, 4 * n_z)
        self.z_x = nn.Linear(hyper_size, 4 * n_z)
        self.z_b = nn.Linear(hyper_size, 4 * n_z, bias=False)
        d_h = [nn.Linear(n_z, hidden_size, bias=False) for _ in range(4)]
        self.d_h = nn.ModuleList(d_h)
        d_x = [nn.Linear(n_z, hidden_size, bias=False) for _ in range(4)]
        self.d_x = nn.ModuleList(d_x)
        d_b = [nn.Linear(n_z, hidden_size) for _ in range(4)]
        self.d_b = nn.ModuleList(d_b)
The weight matrices $W_h^{i,f,g,o}$
        self.w_h = nn.ParameterList([nn.Parameter(torch.zeros(hidden_size, hidden_size)) for _ in range(4)])
The weight matrices $W_x^{i,f,g,o}$
        self.w_x = nn.ParameterList([nn.Parameter(torch.zeros(hidden_size, input_size)) for _ in range(4)])
        self.layer_norm = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(4)])
        self.layer_norm_c = nn.LayerNorm(hidden_size)
    def forward(self, x: torch.Tensor,
                h: torch.Tensor, c: torch.Tensor,
                h_hat: torch.Tensor, c_hat: torch.Tensor):
        x_hat = torch.cat((h, x), dim=-1)
        h_hat, c_hat = self.hyper(x_hat, h_hat, c_hat)
        z_h = self.z_h(h_hat).chunk(4, dim=-1)
        z_x = self.z_x(h_hat).chunk(4, dim=-1)
        z_b = self.z_b(h_hat).chunk(4, dim=-1)
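The `chunk(4, dim=-1)` calls split one wide projection into four per-gate feature vectors. A small sketch of that pattern, with a hypothetical layer `proj` and illustrative sizes, shows the resulting shapes and that the first chunk is the same as applying only the first `n_z` rows of the projection:

```python
import torch
from torch import nn

torch.manual_seed(0)

# Illustrative sizes (assumptions)
hyper_size, n_z, batch = 6, 4, 3

proj = nn.Linear(hyper_size, 4 * n_z)  # one projection covering all four gates
h_hat = torch.randn(batch, hyper_size)

chunks = proj(h_hat).chunk(4, dim=-1)  # tuple of four [batch, n_z] tensors
shapes = [tuple(c.shape) for c in chunks]

# The first chunk equals a projection that uses only the first n_z output rows
first_only = h_hat @ proj.weight[:n_z].T + proj.bias[:n_z]
first_matches = torch.allclose(chunks[0], first_only, atol=1e-5)
```

Using one layer of width `4 * n_z` is therefore equivalent to four independent layers of width `n_z`, with a single matrix multiplication instead of four.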
We calculate $i$, $f$, $g$ and $o$ in a loop
        ifgo = []
        for i in range(4):
            d_h = self.d_h[i](z_h[i])
            d_x = self.d_x[i](z_x[i])
            y = d_h * torch.einsum('ij,bj->bi', self.w_h[i], h) + \
                d_x * torch.einsum('ij,bj->bi', self.w_x[i], x) + \
                self.d_b[i](z_b[i])

            ifgo.append(self.layer_norm[i](y))
        i, f, g, o = ifgo
        c_next = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
        h_next = torch.sigmoid(o) * torch.tanh(self.layer_norm_c(c_next))

        return h_next, c_next, h_hat, c_hat
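For readers less familiar with the `einsum` notation used in the loop, the following sketch checks that `'ij,bj->bi'` is an ordinary batched matrix-vector product, and then applies the gate update to random stand-in pre-activations. Layer normalization and the hyper-network scaling are left out here, and all tensors are random placeholders rather than values from a trained model.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Illustrative sizes (assumptions)
batch, hidden_size = 3, 5

w = torch.randn(hidden_size, hidden_size)
h = torch.randn(batch, hidden_size)

# Three equivalent ways of computing W h for every sample in the batch
via_einsum = torch.einsum('ij,bj->bi', w, h)
via_matmul = h @ w.T
via_linear = F.linear(h, w)
all_equal = (torch.allclose(via_einsum, via_matmul, atol=1e-5)
             and torch.allclose(via_einsum, via_linear, atol=1e-5))

# Gate update with stand-in pre-activations for i, f, g, o
i, f, g, o = torch.randn(4, batch, hidden_size).unbind(0)
c = torch.randn(batch, hidden_size)
c_next = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
h_next = torch.sigmoid(o) * torch.tanh(c_next)  # layer norm on c_next omitted in this sketch
```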
Create a network of `n_layers` HyperLSTM layers.
    def __init__(self, input_size: int, hidden_size: int, hyper_size: int, n_z: int, n_layers: int):
Store sizes to initialize state
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.hyper_size = hyper_size
Create cells for each layer. Note that only the first layer gets the input directly. The rest of the layers get their input from the layer below
        self.cells = nn.ModuleList([HyperLSTMCell(input_size, hidden_size, hyper_size, n_z)] +
                                   [HyperLSTMCell(hidden_size, hidden_size, hyper_size, n_z) for _ in
                                    range(n_layers - 1)])
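`x` has shape `[n_steps, batch_size, input_size]` and `state` is a tuple of $(h, c, \hat{h}, \hat{c})$. For each layer, $h, c$ have shape `[batch_size, hidden_size]` and $\hat{h}, \hat{c}$ have shape `[batch_size, hyper_size]`; the tensors in `state` stack these along a leading `n_layers` dimension.
    def forward(self, x: torch.Tensor,
                state: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = None):
        n_steps, batch_size = x.shape[:2]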
Initialize the state with zeros if `state` is `None`
        if state is None:
            h = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
            c = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
            h_hat = [x.new_zeros(batch_size, self.hyper_size) for _ in range(self.n_layers)]
            c_hat = [x.new_zeros(batch_size, self.hyper_size) for _ in range(self.n_layers)]
        else:
            (h, c, h_hat, c_hat) = state
Reverse stack the tensors to get the states of each layer
📝 You can just work with the tensor itself but this is easier to debug
            h, c = list(torch.unbind(h)), list(torch.unbind(c))
            h_hat, c_hat = list(torch.unbind(h_hat)), list(torch.unbind(c_hat))
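The unbind-then-stack pattern can be tried in isolation. In this sketch (sizes and values are placeholders), a stacked state tensor is split into a Python list with one entry per layer, one entry is replaced, and the list is stacked back into a single tensor:

```python
import torch

# Illustrative sizes (assumptions)
n_layers, batch, hidden_size = 2, 3, 5

h = torch.arange(n_layers * batch * hidden_size, dtype=torch.float32)
h = h.view(n_layers, batch, hidden_size)

per_layer = list(torch.unbind(h))  # n_layers tensors of shape [batch, hidden_size]
per_layer[0] = per_layer[0] + 1.0  # replace one layer's state; h itself is not modified
restacked = torch.stack(per_layer) # back to [n_layers, batch, hidden_size]
```

Working with a list makes it easy to assign a new state to a single layer inside the time loop without in-place writes into the stacked tensor.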
Collect the outputs of the final layer at each step
        out = []
        for t in range(n_steps):
Input to the first layer is the input itself
            inp = x[t]
Loop through the layers
            for layer in range(self.n_layers):
Get the state of the layer
                h[layer], c[layer], h_hat[layer], c_hat[layer] = \
                    self.cells[layer](inp, h[layer], c[layer], h_hat[layer], c_hat[layer])
Input to the next layer is the state of this layer
                inp = h[layer]
Collect the output of the final layer
            out.append(h[-1])
Stack the outputs and states
        out = torch.stack(out)
        h = torch.stack(h)
        c = torch.stack(c)
        h_hat = torch.stack(h_hat)
        c_hat = torch.stack(c_hat)
        return out, (h, c, h_hat, c_hat)
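The time and layer loops can be exercised on their own to see the output shapes. The sketch below uses PyTorch's built-in `nn.LSTMCell` as a stand-in for `HyperLSTMCell` (an assumption made so the example is self-contained), which means it only tracks $h$ and $c$ and has no hyper-network state. The sizes are arbitrary.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Illustrative sizes (assumptions)
n_steps, batch, input_size, hidden_size, n_layers = 7, 3, 4, 5, 2

# Stand-in cells: plain LSTM cells instead of HyperLSTM cells
cells = nn.ModuleList(
    [nn.LSTMCell(input_size, hidden_size)] +
    [nn.LSTMCell(hidden_size, hidden_size) for _ in range(n_layers - 1)]
)

x = torch.randn(n_steps, batch, input_size)
h = [x.new_zeros(batch, hidden_size) for _ in range(n_layers)]
c = [x.new_zeros(batch, hidden_size) for _ in range(n_layers)]

out = []
for t in range(n_steps):
    inp = x[t]  # the first layer sees the input itself
    for layer in range(n_layers):
        h[layer], c[layer] = cells[layer](inp, (h[layer], c[layer]))
        inp = h[layer]  # the next layer sees this layer's output
    out.append(h[-1])   # keep the final layer's output for this step

out = torch.stack(out)                 # [n_steps, batch, hidden_size]
h, c = torch.stack(h), torch.stack(c)  # [n_layers, batch, hidden_size]
```

The last entry of `out` equals the final layer's slice of the returned `h`, which is a handy sanity check when wiring the module into a training loop.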