import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn

try:
    from labml_helpers.module import Module
except ImportError:
    # labml_helpers is optional tooling; nothing below uses more than what
    # nn.Module provides, so fall back to it when the package is absent.
    Module = nn.Module


class LabelSmoothingLoss(Module):
    """KL-divergence loss against a label-smoothed target distribution.

    The smoothed target puts ``1 - smoothing`` probability mass on the true
    class and spreads ``smoothing`` uniformly over the remaining non-padding
    classes. The padding column is always zero, and any row whose target IS
    the padding index is zeroed entirely (it contributes no loss).

    Args:
        size: number of classes (dimension 1 of the input logits).
        padding_idx: class index used for padding; excluded from smoothing.
        smoothing: total probability mass moved off the true class.
    """

    def __init__(self, size: int, padding_idx: int, smoothing: float = 0.0):
        super().__init__()
        # Sum-reduced KL divergence; the input must be log-probabilities.
        self.loss = nn.KLDivLoss(reduction='sum')
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        # Last smoothed target distribution, kept for inspection/plotting.
        self.true_dist = None

    def __call__(self, x: torch.Tensor, target: torch.Tensor):
        """Compute the loss for log-probabilities ``x`` of shape
        ``(batch, size)`` against class indices ``target`` of shape ``(batch,)``.
        """
        assert x.shape[1] == self.size
        true_dist = x.clone()
        # size - 2 classes (all but the target and the padding class) share
        # the smoothing mass uniformly.
        true_dist.fill_(self.smoothing / (self.size - 2))
        true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        true_dist[:, self.padding_idx] = 0
        mask = torch.nonzero(target == self.padding_idx, as_tuple=False)
        # BUG FIX: the original guard `mask.dim() > 0` was always true, since
        # nonzero(..., as_tuple=False) returns a 2-D (N, 1) tensor even when
        # empty.  Check the match count instead, and squeeze only the last
        # dimension so a single match still yields a 1-D index tensor for
        # index_fill_.
        if mask.shape[0] > 0:
            # Zero out entire rows whose target is the padding index.
            true_dist.index_fill_(0, mask.squeeze(-1), 0.0)
        self.true_dist = true_dist
        # Detach: the smoothed target is a constant, not a learnable signal.
        return self.loss(x, true_dist.detach())
def _test_label_smoothing():
    """Visual sanity checks for :class:`LabelSmoothingLoss`.

    First plots the smoothed target distributions produced for a small batch,
    then plots how the loss varies as the predicted probability of the true
    class grows. Opens matplotlib windows; intended for interactive use only.
    """
    smooth_loss = LabelSmoothingLoss(5, 0, 0.4)
    predict = torch.tensor([[0, 0.2, 0.7, 0.1, 0],
                            [0, 0.2, 0.7, 0.1, 0],
                            [0, 0.2, 0.7, 0.1, 0]], dtype=torch.float)
    _ = smooth_loss(predict.log(),
                    torch.tensor([2, 1, 0], dtype=torch.long))

    # Show the target distributions expected by the system.
    plt.imshow(smooth_loss.true_dist)
    plt.show()

    smooth_loss = LabelSmoothingLoss(5, 0, 0.1)

    def loss_sample(x):
        # Predicted distribution [0, x/d, 1/d, 1/d, 1/d]: the mass on the
        # true class (index 1) grows with x while the rest stays uniform.
        d = x + 3 * 1
        predict2 = torch.tensor([[0, x / d, 1 / d, 1 / d, 1 / d],
                                 ], dtype=torch.float)
        # print(predict)
        return smooth_loss(predict2.log(),
                           torch.tensor([1], dtype=torch.long)).item()

    plt.plot(np.arange(1, 100), [loss_sample(x) for x in range(1, 100)])
    plt.show()
66
67
if __name__ == '__main__':
    # Run the interactive visual checks when executed as a script.
    _test_label_smoothing()