import math

import numpy as np
import torch
import torch.nn as nn


class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal positional encodings to token embeddings, then applies dropout."""

    def __init__(self, d_model: int, dropout_prob: float, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(dropout_prob)

        # Precompute the encodings once and store them as a non-persistent buffer
        # (not saved in the state dict, never trained).
        self.register_buffer('positional_encodings', get_positional_encoding(d_model, max_len), False)

    def forward(self, x: torch.Tensor):
        # x has shape [seq_len, batch_size, d_model]; take the first seq_len
        # encodings and broadcast them over the batch dimension.
        pe = self.positional_encodings[:x.shape[0]].detach().requires_grad_(False)
        x = x + pe
        x = self.dropout(x)
        return x
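

# Minimal usage sketch (an illustrative addition, not part of the original module).
# Inputs are assumed to be shaped [seq_len, batch_size, d_model]; the output has
# the same shape, with the sinusoidal encodings added and dropout applied.
def _example_usage():
    pe_layer = PositionalEncoding(d_model=512, dropout_prob=0.1)
    x = torch.zeros(100, 32, 512)  # [seq_len, batch_size, d_model]
    y = pe_layer(x)
    assert y.shape == (100, 32, 512)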


def get_positional_encoding(d_model: int, max_len: int = 5000):
    # Empty encoding matrix of shape [max_len, d_model] (d_model is assumed to be even).
    encodings = torch.zeros(max_len, d_model)
    # Position indices 0, 1, ..., max_len - 1, as a column vector.
    position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
    # Even dimension indices 2i.
    two_i = torch.arange(0, d_model, 2, dtype=torch.float32)
    # 1 / 10000^(2i / d_model)
    div_term = torch.exp(two_i * -(math.log(10000.0) / d_model))
    # PE(pos, 2i) = sin(pos / 10000^(2i / d_model))
    encodings[:, 0::2] = torch.sin(position * div_term)
    # PE(pos, 2i + 1) = cos(pos / 10000^(2i / d_model))
    encodings[:, 1::2] = torch.cos(position * div_term)
    # Add a batch dimension: [max_len, 1, d_model].
    encodings = encodings.unsqueeze(1).requires_grad_(False)

    return encodings
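

# Quick sanity check (an illustrative addition): for 2i = 0 the divisor is
# 10000^0 = 1, so the first even dimension should equal sin(position) and the
# first odd dimension cos(position).
def _check_first_dims():
    pe = get_positional_encoding(d_model=16, max_len=10)
    pos = torch.arange(10, dtype=torch.float32)
    assert torch.allclose(pe[:, 0, 0], torch.sin(pos), atol=1e-6)
    assert torch.allclose(pe[:, 0, 1], torch.cos(pos), atol=1e-6)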


def _test_positional_encoding():
    import matplotlib.pyplot as plt

    plt.figure(figsize=(15, 5))
    pe = get_positional_encoding(20, 100)
    # Plot dimensions 4-7 of the encoding for the first 100 positions.
    plt.plot(np.arange(100), pe[:, 0, 4:8].numpy())
    plt.legend(["dim %d" % p for p in [4, 5, 6, 7]])
    plt.title("Positional encoding")
    plt.show()


if __name__ == '__main__':
    _test_positional_encoding()