11import matplotlib.pyplot as plt
12import numpy as np
13
14import torch
from torch import nn


class LabelSmoothingLoss(nn.Module):
    """KL-divergence loss against a label-smoothed target distribution.

    For each row, the target distribution:

    - puts ``1 - smoothing`` on the true class,
    - puts 0 on the padding class, and
    - spreads ``smoothing`` evenly over the remaining ``size - 2`` classes.

    Rows whose target class is ``padding_idx`` get an all-zero target and
    therefore contribute nothing to the loss. The loss is summed, not averaged
    (``nn.KLDivLoss(reduction='sum')``).

    Args:
        size: Number of classes; ``x.shape[1]`` must equal this. Must be > 2,
            because the smoothing mass is divided by ``size - 2``.
        padding_idx: Class index of the padding class.
        smoothing: Probability mass spread over the non-target, non-padding
            classes.

    Raises:
        ValueError: If ``size <= 2``.
    """

    def __init__(self, size: int, padding_idx: int, smoothing: float = 0.0):
        super().__init__()
        if size <= 2:
            # forward() divides the smoothing mass by ``size - 2``; with two or
            # fewer classes that division would fail on every call.
            raise ValueError(f"size must be greater than 2, got {size}")
        self.loss = nn.KLDivLoss(reduction='sum')
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        # Target distribution built by the most recent forward() call; kept so
        # it can be inspected or plotted (see _test_label_smoothing).
        self.true_dist = None

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the summed KL divergence between ``x`` and the smoothed target.

        Args:
            x: ``(batch, size)`` log-probabilities, as ``nn.KLDivLoss`` expects.
            target: ``(batch,)`` int64 class indices.

        Returns:
            A scalar tensor holding the summed loss.

        Raises:
            ValueError: If ``x.shape[1] != size``.
        """
        if x.shape[1] != self.size:
            raise ValueError(
                f"expected x.shape[1] == {self.size}, got {x.shape[1]}")

        # The target is a constant with respect to x. A factory function gives
        # a fresh tensor with x's dtype/device and no autograd history.
        true_dist = torch.full_like(x, self.smoothing / (self.size - 2))
        true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        # The padding class never receives probability mass...
        true_dist[:, self.padding_idx] = 0
        # ...and rows whose target is the padding class are ignored entirely.
        true_dist[target == self.padding_idx] = 0.0
        self.true_dist = true_dist

        # Zero-weight positions contribute nothing to the KL divergence
        # (0 * log 0 := 0), but kl_div can evaluate 0 * (-inf) = nan there when
        # x holds log(0). Masking those inputs leaves every finite loss value
        # and every gradient unchanged (the gradient at such positions is -0).
        x = x.masked_fill(true_dist == 0, 0.0)
        return self.loss(x, true_dist)


def _test_label_smoothing():
    """Plot the smoothed target distributions and a loss curve (opens plot windows)."""
    smooth_loss = LabelSmoothingLoss(5, 0, 0.4)
    predict = torch.tensor([[0, 0.2, 0.7, 0.1, 0],
                            [0, 0.2, 0.7, 0.1, 0],
                            [0, 0.2, 0.7, 0.1, 0]], dtype=torch.float)
    _ = smooth_loss(predict.log(),
                    torch.tensor([2, 1, 0], dtype=torch.long))

    # Show the target distributions expected by the system.
    plt.imshow(smooth_loss.true_dist)
    plt.show()

    smooth_loss = LabelSmoothingLoss(5, 0, 0.1)

    def loss_sample(x):
        """Loss for one prediction whose weight on the true class 1 grows with x."""
        # Weight x on the true class, weight 1 on each of the 3 other
        # non-padding classes, normalized by d.
        d = x + 3 * 1
        predict2 = torch.tensor([[0, x / d, 1 / d, 1 / d, 1 / d],
                                 ], dtype=torch.float)
        return smooth_loss(predict2.log(),
                           torch.tensor([1], dtype=torch.long)).item()

    plt.plot(np.arange(1, 100), [loss_sample(x) for x in range(1, 100)])
    plt.show()


if __name__ == '__main__':
    _test_label_smoothing()