Fixed Positional Encodings

The positional encoding encodes the position along the sequence into a vector of size d_model.

$$PE_{p,2i} = \sin\Bigg(\frac{p}{10000^{2i/d_{model}}}\Bigg)$$

$$PE_{p,2i+1} = \cos\Bigg(\frac{p}{10000^{2i/d_{model}}}\Bigg)$$

where $2i$ and $2i + 1$ are the feature indexes in the encoding, and $p$ is the position.
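As a quick illustration (not from the original source; the helper pe_value and its argument names are hypothetical), the formula can be evaluated directly for a single position and feature index:

import math

def pe_value(p: int, i: int, d_model: int) -> float:
    # Hypothetical helper: even feature indexes use sine, odd ones use cosine,
    # and the exponent always uses the even index 2i (i.e. i - i % 2)
    angle = p / (10000 ** ((i - i % 2) / d_model))
    return math.sin(angle) if i % 2 == 0 else math.cos(angle)

# For example, position p=3, feature index 4, d_model=512: sin(3 / 10000^(4/512))
print(pe_value(3, 4, 512))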

import math

import numpy as np
import torch
import torch.nn as nn

from labml_helpers.module import Module


class PositionalEncoding(Module):
    def __init__(self, d_model: int, dropout_prob: float, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(dropout_prob)

        # Pre-compute the encodings once and register them as a non-persistent buffer
        self.register_buffer('positional_encodings', get_positional_encoding(d_model, max_len), False)

    def forward(self, x: torch.Tensor):
        # Take the encodings for the current sequence length; they are constants, so no gradients
        pe = self.positional_encodings[:x.shape[0]].detach().requires_grad_(False)
        # Add the positional encodings to the embeddings
        x = x + pe
        # Apply dropout
        x = self.dropout(x)
        return x
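A minimal usage sketch (not part of the original source; the shapes and names below are illustrative). The module expects sequence-first input of shape [seq_len, batch_size, d_model], which is why the buffer is sliced with x.shape[0]:

seq_len, batch_size, d_model = 10, 2, 512
pe_layer = PositionalEncoding(d_model, dropout_prob=0.1)
embeddings = torch.zeros(seq_len, batch_size, d_model)
out = pe_layer(embeddings)
print(out.shape)  # torch.Size([10, 2, 512])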
def get_positional_encoding(d_model: int, max_len: int = 5000):

Empty tensor for the encodings

    encodings = torch.zeros(max_len, d_model)

Position indexes $p$

    position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)

Feature indexes $2i$

    two_i = torch.arange(0, d_model, 2, dtype=torch.float32)

Divisor term $\frac{1}{10000^{2i/d_{model}}}$

    div_term = torch.exp(two_i * -(math.log(10000.0) / d_model))

$PE_{p,2i} = \sin\Big(\frac{p}{10000^{2i/d_{model}}}\Big)$ for even feature indexes

    encodings[:, 0::2] = torch.sin(position * div_term)

$PE_{p,2i+1} = \cos\Big(\frac{p}{10000^{2i/d_{model}}}\Big)$ for odd feature indexes

    encodings[:, 1::2] = torch.cos(position * div_term)

Add a batch dimension so the encodings broadcast over the batch

    encodings = encodings.unsqueeze(1).requires_grad_(False)

    return encodings
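A quick numerical spot check (not from the original source; the chosen position and feature index are arbitrary) that the tensor built above matches the closed-form formula:

pe = get_positional_encoding(d_model=512, max_len=100)
p, two_i = 7, 20
expected = math.sin(p / 10000 ** (two_i / 512))
# The encodings are float32, so allow a small tolerance
assert abs(pe[p, 0, two_i].item() - expected) < 1e-4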
def _test_positional_encoding():
    import matplotlib.pyplot as plt

    plt.figure(figsize=(15, 5))
    pe = get_positional_encoding(20, 100)
    # Plot four encoding dimensions over the first 100 positions
    plt.plot(np.arange(100), pe[:, 0, 4:8].numpy())
    plt.legend(["dim %d" % p for p in [4, 5, 6, 7]])
    plt.title("Positional encoding")
    plt.show()


if __name__ == '__main__':
    _test_positional_encoding()