```
from typing import Optional, Tuple

import torch
from torch import nn

from labml_helpers.module import Module
```

An LSTM cell computes $c$ and $h$. $c$ acts as the long-term memory, and $h$ as the short-term memory. We use the input $x$ and $h$ to update the long-term memory. In the update, some features of $c$ are cleared by a forget gate $f$, and new candidate features $g$ are added through an input gate $i$.

The new short-term memory is the $tanh$ of the long-term memory, multiplied by the output gate $o$.

Note that the cell doesn't read the long-term memory $c$ when computing the update; it only modifies it. Also, $c$ never goes through a linear transformation. This is what helps against vanishing and exploding gradients.

Here's the update rule.

$$\begin{align}
c_{t} &= \sigma(f_{t}) \odot c_{t-1} + \sigma(i_{t}) \odot \tanh(g_{t}) \\
h_{t} &= \sigma(o_{t}) \odot \tanh(c_{t})
\end{align}$$

$\odot$ stands for element-wise multiplication.
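To make the role of the gates concrete, here is a small sketch of the cell-state update with saturated gates. The helper name `cell_update` and the numeric values are illustrative assumptions, not part of the implementation described here.

```python
import torch


def cell_update(c_prev, f, i, g):
    # c_t = sigmoid(f_t) * c_{t-1} + sigmoid(i_t) * tanh(g_t), all element-wise
    return torch.sigmoid(f) * c_prev + torch.sigmoid(i) * torch.tanh(g)


c_prev = torch.tensor([1.0, -2.0, 0.5])
g = torch.tensor([3.0, 3.0, 3.0])  # candidate pre-activation
large = torch.full((3,), 20.0)     # sigmoid(20) is approximately 1

# Forget gate open, input gate closed: the memory is carried over unchanged
kept = cell_update(c_prev, f=large, i=-large, g=g)

# Forget gate closed, input gate open: the memory is replaced by tanh(g)
replaced = cell_update(c_prev, f=-large, i=large, g=g)
```

With intermediate gate values the new memory is a per-feature blend of the old memory and the candidate.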

Intermediate values and gates are computed as linear transformations of the hidden state and input.

$$\begin{align}
i_{t} &= lin_{x}^{i}(x_{t}) + lin_{h}^{i}(h_{t-1}) \\
f_{t} &= lin_{x}^{f}(x_{t}) + lin_{h}^{f}(h_{t-1}) \\
g_{t} &= lin_{x}^{g}(x_{t}) + lin_{h}^{g}(h_{t-1}) \\
o_{t} &= lin_{x}^{o}(x_{t}) + lin_{h}^{o}(h_{t-1})
\end{align}$$

`class LSTMCell(Module):`

```
    def __init__(self, input_size: int, hidden_size: int, layer_norm: bool = False):
        super().__init__()
```

These are the linear layers that transform the `input` and `hidden` vectors. One of them doesn't need a bias since we add the transformations.

This combines the $lin_{h}^{i}$, $lin_{h}^{f}$, $lin_{h}^{g}$, and $lin_{h}^{o}$ transformations.

`self.hidden_lin = nn.Linear(hidden_size, 4 * hidden_size)`

This combines the $lin_{x}^{i}$, $lin_{x}^{f}$, $lin_{x}^{g}$, and $lin_{x}^{o}$ transformations.

`self.input_lin = nn.Linear(input_size, 4 * hidden_size, bias=False)`
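The following sketch illustrates why one layer with `4 * hidden_size` outputs can stand in for four separate transformations. It copies the four row blocks of a fused weight matrix into four smaller layers and compares the outputs. The names `fused` and `separate` and the sizes are assumptions made for this example.

```python
import torch
from torch import nn

torch.manual_seed(0)
input_size, hidden_size, batch_size = 3, 5, 2

# One fused layer producing the four pre-activations side by side
fused = nn.Linear(input_size, 4 * hidden_size, bias=False)

# Four separate layers whose weights are the four row blocks of the fused weight
separate = [nn.Linear(input_size, hidden_size, bias=False) for _ in range(4)]
with torch.no_grad():
    for layer, block in zip(separate, fused.weight.chunk(4, dim=0)):
        layer.weight.copy_(block)

x = torch.randn(batch_size, input_size)
with torch.no_grad():
    fused_parts = fused(x).chunk(4, dim=-1)          # four [batch_size, hidden_size] tensors
    separate_parts = [layer(x) for layer in separate]
```

The fused form needs a single matrix multiplication per input instead of four, which is the usual reason for writing it this way.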

Whether to apply layer normalizations.

Applying layer normalization gives better results. The $i$, $f$, $g$ and $o$ pre-activations are normalized, and $c_{t}$ is normalized in $h_{t}=σ(o_{t})⊙tanh(LN(c_{t}))$

```
        if layer_norm:
            self.layer_norm = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(4)])
            self.layer_norm_c = nn.LayerNorm(hidden_size)
        else:
            self.layer_norm = nn.ModuleList([nn.Identity() for _ in range(4)])
            self.layer_norm_c = nn.Identity()
```
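As a quick illustration of what `nn.LayerNorm(hidden_size)` does to a pre-activation, the sketch below normalizes a batch of badly scaled vectors and measures the per-sample statistics over the feature dimension. The sizes and the input scale are arbitrary assumptions.

```python
import torch
from torch import nn

torch.manual_seed(0)
batch_size, hidden_size = 3, 8

layer_norm = nn.LayerNorm(hidden_size)

# Pre-activations with a large offset and scale
z = 5.0 + 10.0 * torch.randn(batch_size, hidden_size)

with torch.no_grad():
    normalized = layer_norm(z)
    # Statistics are taken over the last (feature) dimension, separately per sample
    mean = normalized.mean(dim=-1)
    var = normalized.var(dim=-1, unbiased=False)
```

At initialization the learnable scale is one and the shift is zero, so each sample comes out with approximately zero mean and unit variance.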

`def forward(self, x: torch.Tensor, h: torch.Tensor, c: torch.Tensor):`

We compute the linear transformations for $i_{t}$, $f_{t}$, $g_{t}$ and $o_{t}$ using the same linear layers.

`ifgo = self.hidden_lin(h) + self.input_lin(x)`

Each layer produces an output of 4 times the `hidden_size`, which we split into four chunks of size `hidden_size`.

`ifgo = ifgo.chunk(4, dim=-1)`

Apply layer normalization (not in the original paper, but gives better results)

`ifgo = [self.layer_norm[i](ifgo[i]) for i in range(4)]`

$i_{t},f_{t},g_{t},o_{t}$

`i, f, g, o = ifgo`

$c_{t}=σ(f_{t})⊙c_{t−1}+σ(i_{t})⊙tanh(g_{t})$

`c_next = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)`
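Because this update is element-wise and $c$ passes through no linear layer, the direct gradient from $c_{t}$ back to $c_{t-1}$ is just $σ(f_{t})$. The sketch below checks this with autograd. It assumes the gate pre-activations are held fixed, so it covers only the direct path through the cell state; in the full network the gates also depend on $c_{t-1}$ indirectly through $h_{t-1}$.

```python
import torch

torch.manual_seed(0)
hidden_size = 4

# Previous cell state; we want the gradient of c_next with respect to it
c_prev = torch.randn(hidden_size, requires_grad=True)

# Gate pre-activations, treated as constants in this sketch
f = torch.randn(hidden_size)
i = torch.randn(hidden_size)
g = torch.randn(hidden_size)

c_next = torch.sigmoid(f) * c_prev + torch.sigmoid(i) * torch.tanh(g)

# The update is element-wise, so summing yields d c_next[k] / d c_prev[k] for each k
(grad_c_prev,) = torch.autograd.grad(c_next.sum(), c_prev)
```

When the forget gate stays close to one, this factor stays close to one as well, which is how information and gradients can survive many time steps.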

$h_{t}=σ(o_{t})⊙tanh(c_{t})$

Optionally, layer normalization is applied to $c_{t}$ before the $tanh$.

```
        h_next = torch.sigmoid(o) * torch.tanh(self.layer_norm_c(c_next))

        return h_next, c_next
```
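As a sanity check on the update rule, here is a functional sketch of one step without layer normalization, compared against PyTorch's built-in `torch.nn.LSTMCell`, which uses the same $i$, $f$, $g$, $o$ ordering for its stacked weights. The function `lstm_step` is a hypothetical helper written for this example; it is not the class defined above.

```python
import torch
from torch import nn


def lstm_step(x, h, c, w_ih, w_hh, bias):
    # All four pre-activations in one go: [batch_size, 4 * hidden_size]
    ifgo = x @ w_ih.T + h @ w_hh.T + bias
    i, f, g, o = ifgo.chunk(4, dim=-1)
    c_next = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
    h_next = torch.sigmoid(o) * torch.tanh(c_next)
    return h_next, c_next


torch.manual_seed(0)
batch_size, input_size, hidden_size = 2, 3, 5

reference = nn.LSTMCell(input_size, hidden_size)
x = torch.randn(batch_size, input_size)
h = torch.randn(batch_size, hidden_size)
c = torch.randn(batch_size, hidden_size)

with torch.no_grad():
    # The built-in cell keeps two biases; their sum acts as a single bias
    h_next, c_next = lstm_step(x, h, c, reference.weight_ih, reference.weight_hh,
                               reference.bias_ih + reference.bias_hh)
    ref_h, ref_c = reference(x, (h, c))
```

Summing the two biases of the built-in cell into one also shows why only one of the two linear layers needs a bias.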

`class LSTM(Module):`

Create a network of `n_layers` stacked LSTM layers.

`def __init__(self, input_size: int, hidden_size: int, n_layers: int):`

```
        super().__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
```

Create a cell for each layer. Note that only the first layer gets the input directly. The rest of the layers get their input from the layer below.

```
        self.cells = nn.ModuleList([LSTMCell(input_size, hidden_size)] +
                                   [LSTMCell(hidden_size, hidden_size) for _ in range(n_layers - 1)])
```

`x` has shape `[n_steps, batch_size, input_size]` and `state` is a tuple of $h$ and $c$, each with a shape of `[n_layers, batch_size, hidden_size]`, since the states of all layers are stacked along the first dimension.

`def forward(self, x: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):`

`n_steps, batch_size = x.shape[:2]`

Initialize the state if `None`

```
        if state is None:
            h = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
            c = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
        else:
            (h, c) = state
```
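A short sketch of why `new_zeros` is convenient for the initial state: the new tensors inherit the dtype and device of `x`, so no explicit arguments are needed. The sizes and the `float64` dtype are assumptions chosen to make the effect visible.

```python
import torch

n_layers, hidden_size = 2, 6

# Input of shape [n_steps, batch_size, input_size] in a non-default dtype
x = torch.randn(5, 3, 4, dtype=torch.float64)
batch_size = x.shape[1]

# One zero state per layer, matching the dtype and device of x
h = [x.new_zeros(batch_size, hidden_size) for _ in range(n_layers)]
c = [x.new_zeros(batch_size, hidden_size) for _ in range(n_layers)]
```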

When a state is passed in, unstack the tensors along the first dimension to get the states of each layer

📝 You can just work with the tensor itself but this is easier to debug

`h, c = list(torch.unbind(h)), list(torch.unbind(c))`
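The round trip between a stacked state tensor and a per-layer list can be sketched as follows. `torch.unbind` removes the first dimension and returns one tensor per layer, and `torch.stack` puts them back together. The sizes are arbitrary assumptions.

```python
import torch

n_layers, batch_size, hidden_size = 2, 3, 4

# A stacked state of shape [n_layers, batch_size, hidden_size]
h = torch.arange(n_layers * batch_size * hidden_size, dtype=torch.float32)
h = h.view(n_layers, batch_size, hidden_size)

# One [batch_size, hidden_size] tensor per layer, in a mutable list
per_layer = list(torch.unbind(h))

# Stacking restores the original layout
restacked = torch.stack(per_layer)
```

Working with a list lets each layer's state be replaced by assignment, without in-place writes into a shared tensor.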

List to collect the outputs of the final layer at each time step.

```
        out = []
        for t in range(n_steps):
```

Input to the first layer is the input itself

`inp = x[t]`

Loop through the layers

`for layer in range(self.n_layers):`

Compute the new state of the layer

`h[layer], c[layer] = self.cells[layer](inp, h[layer], c[layer])`

Input to the next layer is the hidden state $h$ of this layer

`inp = h[layer]`

Collect the output $h$ of the final layer

`out.append(h[-1])`

Stack the outputs and states

```
        out = torch.stack(out)
        h = torch.stack(h)
        c = torch.stack(c)

        return out, (h, c)
```
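The time-step and layer loops above can be sketched end to end and checked against PyTorch's built-in `torch.nn.LSTM`. This example uses `torch.nn.LSTMCell` as a stand-in for the cell defined in this file (so there is no layer normalization, and the state is passed as a tuple), with weights copied from the reference model. All sizes are assumptions.

```python
import torch
from torch import nn

torch.manual_seed(0)
n_steps, batch_size, input_size, hidden_size, n_layers = 5, 3, 4, 6, 2

reference = nn.LSTM(input_size, hidden_size, num_layers=n_layers)

# Only the first cell sees the raw input; the others take the hidden state from below
cells = nn.ModuleList([nn.LSTMCell(input_size, hidden_size)] +
                      [nn.LSTMCell(hidden_size, hidden_size) for _ in range(n_layers - 1)])

# Copy the reference weights so both models compute the same function
with torch.no_grad():
    for layer, cell in enumerate(cells):
        cell.weight_ih.copy_(getattr(reference, f"weight_ih_l{layer}"))
        cell.weight_hh.copy_(getattr(reference, f"weight_hh_l{layer}"))
        cell.bias_ih.copy_(getattr(reference, f"bias_ih_l{layer}"))
        cell.bias_hh.copy_(getattr(reference, f"bias_hh_l{layer}"))

x = torch.randn(n_steps, batch_size, input_size)
h = [x.new_zeros(batch_size, hidden_size) for _ in range(n_layers)]
c = [x.new_zeros(batch_size, hidden_size) for _ in range(n_layers)]

out = []
with torch.no_grad():
    for t in range(n_steps):
        inp = x[t]
        for layer in range(n_layers):
            h[layer], c[layer] = cells[layer](inp, (h[layer], c[layer]))
            inp = h[layer]
        # The output at each step is the hidden state of the top layer
        out.append(h[-1])

    out = torch.stack(out)                      # [n_steps, batch_size, hidden_size]
    h_n, c_n = torch.stack(h), torch.stack(c)   # [n_layers, batch_size, hidden_size]
    ref_out, (ref_h, ref_c) = reference(x)
```

The returned state can be passed back in as `state` to continue a sequence from where the previous call stopped.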