超网络-HyperLSTM

我们已经实现了论文 Hyper Networks 中介绍的 Hyper LSTM,并使用 PyTorch 进行了注释。David Ha的这篇博客文章很好地解释了HyperNetworks。

我们有一个实验可以训练 HyperLSTM 来预测莎士比亚数据集上的文本。以下是代码链接:experiment.py

Open In Colab

HyperNetworks 使用较小的网络来生成较大网络的权重。有两种变体:静态超网络和动态超网络。静态超网络具有较小的网络,用于生成卷积网络的权重(内核)。动态超网络为每个步骤生成循环神经网络的参数。这是后者的实现。

动态超网络

在 RNN 中,每个步骤的参数保持不变。动态超网络为每个步骤生成不同的参数。HyperLSTM 具有 LSTM 的结构,但每个步骤的参数都由较小的 LSTM 网络更改。

在基本形式中,Dynamic HyperNetwork 具有较小的循环网络,该网络生成与较大循环网络的每个参数张量对应的特征向量。假设较大的网络有一些参数,较小的网络生成一个特征向量,我们动态计算为的线性变换。例如,其中是三维张量参数,是张量向量乘法。通常是较小的循环网络输出的线性变换。

按重量缩放而不是计算

大型循环网络具有大量的动态计算参数。这些是使用特征向量的线性变换计算的。而且这种变换需要更大的权重张量。也就是说,当有形状时将是

为了克服这个问题,我们通过动态缩放相同大小的矩阵的每一行来计算循环网络的权重参数。

其中参数矩阵。

我们可以在计算时进一步对其进行优化,因为其中代表逐元素乘法。

72from typing import Optional, Tuple
73
74import torch
75from torch import nn
76
77from labml_helpers.module import Module
78from labml_nn.lstm import LSTMCell

HyperLSTM Cell

对于 HyperLSTM,较小的网络和较大的网络都具有 LSTM 结构。这在白皮书的附录A.2.2中进行了定义。

81class HyperLSTMCell(Module):

input_size 是输入的大小hidden_size 是 LSTM 的大小,hyper_size 是较小的 LSTM 的大小,它会改变更大的外部 LSTM。n_z 是用于改变 LSTM 权重的特征向量的大小。

我们使用较小的 LSTM 的输出进行计算使用线性变换。我们再次使用线性变换进行计算、和计算。然后使用它们来缩放主 LSTM 的权重和偏置张量的行。

📝 由于和的计算是两个连续的线性变换,因此可以将它们组合成单个线性变换。但是,我们已经单独实现了这一点,以便它与论文中的描述相匹配。

89    def __init__(self, input_size: int, hidden_size: int, hyper_size: int, n_z: int):
107        super().__init__()

HyperLSTM 的输入是上一步中外部 LSTM 的输入,也是外部 LSTM 的输出。因此,输入大小为hidden_size + input_size

HyperLSTM 的输出为

120        self.hyper = LSTMCell(hidden_size + input_size, hyper_size, layer_norm=True)

🤔 在报纸上指定了它,因为我觉得这是一个错字。

126        self.z_h = nn.Linear(hyper_size, 4 * n_z)

128        self.z_x = nn.Linear(hyper_size, 4 * n_z)

130        self.z_b = nn.Linear(hyper_size, 4 * n_z, bias=False)

133        d_h = [nn.Linear(n_z, hidden_size, bias=False) for _ in range(4)]
134        self.d_h = nn.ModuleList(d_h)

136        d_x = [nn.Linear(n_z, hidden_size, bias=False) for _ in range(4)]
137        self.d_x = nn.ModuleList(d_x)

139        d_b = [nn.Linear(n_z, hidden_size) for _ in range(4)]
140        self.d_b = nn.ModuleList(d_b)

权重矩阵

143        self.w_h = nn.ParameterList([nn.Parameter(torch.zeros(hidden_size, hidden_size)) for _ in range(4)])

权重矩阵

145        self.w_x = nn.ParameterList([nn.Parameter(torch.zeros(hidden_size, input_size)) for _ in range(4)])

层规范化

148        self.layer_norm = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(4)])
149        self.layer_norm_c = nn.LayerNorm(hidden_size)
151    def forward(self, x: torch.Tensor,
152                h: torch.Tensor, c: torch.Tensor,
153                h_hat: torch.Tensor, c_hat: torch.Tensor):

160        x_hat = torch.cat((h, x), dim=-1)

162        h_hat, c_hat = self.hyper(x_hat, h_hat, c_hat)

165        z_h = self.z_h(h_hat).chunk(4, dim=-1)

167        z_x = self.z_x(h_hat).chunk(4, dim=-1)

169        z_b = self.z_b(h_hat).chunk(4, dim=-1)

我们循环计算

172        ifgo = []
173        for i in range(4):

175            d_h = self.d_h[i](z_h[i])

177            d_x = self.d_x[i](z_x[i])

184            y = d_h * torch.einsum('ij,bj->bi', self.w_h[i], h) + \
185                d_x * torch.einsum('ij,bj->bi', self.w_x[i], x) + \
186                self.d_b[i](z_b[i])
187
188            ifgo.append(self.layer_norm[i](y))

191        i, f, g, o = ifgo

194        c_next = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)

197        h_next = torch.sigmoid(o) * torch.tanh(self.layer_norm_c(c_next))
198
199        return h_next, c_next, h_hat, c_hat

HyperLSTM 模块

202class HyperLSTM(Module):

创建一个由 HyperLSTMn_layers 组成的网络。

207    def __init__(self, input_size: int, hidden_size: int, hyper_size: int, n_z: int, n_layers: int):
212        super().__init__()

存储大小以初始化状态

215        self.n_layers = n_layers
216        self.hidden_size = hidden_size
217        self.hyper_size = hyper_size

为每层创建单元。请注意,只有第一层直接获得输入。其余图层从下面的图层获取输入

221        self.cells = nn.ModuleList([HyperLSTMCell(input_size, hidden_size, hyper_size, n_z)] +
222                                   [HyperLSTMCell(hidden_size, hidden_size, hyper_size, n_z) for _ in
223                                    range(n_layers - 1)])
  • x 有形状[n_steps, batch_size, input_size]
  • state 是的元组有形状[batch_size, hidden_size]形状[batch_size, hyper_size]
225    def forward(self, x: torch.Tensor,
226                state: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = None):
233        n_steps, batch_size = x.shape[:2]

使用零初始化状态如果None

236        if state is None:
237            h = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
238            c = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
239            h_hat = [x.new_zeros(batch_size, self.hyper_size) for _ in range(self.n_layers)]
240            c_hat = [x.new_zeros(batch_size, self.hyper_size) for _ in range(self.n_layers)]

242        else:
243            (h, c, h_hat, c_hat) = state

反向堆叠张量以获得每层的状态

📝 你可以只使用张量本身,但这更容易调试

247            h, c = list(torch.unbind(h)), list(torch.unbind(c))
248            h_hat, c_hat = list(torch.unbind(h_hat)), list(torch.unbind(c_hat))

在每一步收集最后一层的输出

251        out = []
252        for t in range(n_steps):

第一层的输入是输入本身

254            inp = x[t]

循环穿过图层

256            for layer in range(self.n_layers):

获取图层的状态

258                h[layer], c[layer], h_hat[layer], c_hat[layer] = \
259                    self.cells[layer](inp, h[layer], c[layer], h_hat[layer], c_hat[layer])

下一层的输入是该图层的状态

261                inp = h[layer]

收集最后一层的输出

263            out.append(h[-1])

堆叠输出和状态

266        out = torch.stack(out)
267        h = torch.stack(h)
268        c = torch.stack(c)
269        h_hat = torch.stack(h_hat)
270        c_hat = torch.stack(c_hat)

273        return out, (h, c, h_hat, c_hat)