Label Smoothing Loss

11import matplotlib.pyplot as plt
12import numpy as np
13import torch
14import torch.nn as nn
15
16from labml_helpers.module import Module
class LabelSmoothingLoss(Module):
    """KL-divergence loss against a label-smoothed target distribution.

    For each row, the true class receives ``1 - smoothing`` probability and the
    padding class receives 0. The remaining ``smoothing`` mass is split evenly
    over the other ``size - 2`` classes. Rows whose target is the padding index
    are zeroed out entirely, so they contribute nothing to the loss.

    Args:
        size: number of classes. Must equal ``x.shape[1]`` at call time and be
            greater than 2, because the smoothing mass is divided by ``size - 2``.
        padding_idx: class index that never receives probability mass; rows
            whose target equals it are ignored.
        smoothing: total probability mass moved off the true class.
    """

    def __init__(self, size: int, padding_idx: int, smoothing: float = 0.0):
        super().__init__()
        # Summed (not averaged) KL divergence over all elements.
        self.loss = nn.KLDivLoss(reduction='sum')
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        # Target distribution built by the most recent call, kept for inspection.
        self.true_dist = None

    def __call__(self, x: torch.Tensor, target: torch.Tensor):
        """Return the summed KL divergence between ``x`` and the smoothed targets.

        Args:
            x: log-probabilities, assumed shape ``(batch, size)``.
                ``nn.KLDivLoss`` interprets its input in log space.
            target: integer class indices, assumed shape ``(batch,)``.

        Returns:
            A scalar tensor holding the summed loss.
        """
        assert x.shape[1] == self.size
        # Build the target fresh rather than cloning x, so it never carries
        # x's autograd history.
        true_dist = torch.full_like(x, self.smoothing / (self.size - 2))
        true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        true_dist[:, self.padding_idx] = 0
        # Zero whole rows whose target is the padding index.
        true_dist[target == self.padding_idx] = 0
        self.true_dist = true_dist
        # KLDivLoss evaluates target * (log(target) - input). Where the target is
        # 0 and the input is -inf (log(0), or a masked logit), that is
        # 0 * -inf = NaN. Zero-mass entries contribute nothing mathematically,
        # so give them a finite input.
        x = x.masked_fill(true_dist == 0, 0.0)
        return self.loss(x, true_dist)
def _test_label_smoothing():
    """Visual check of ``LabelSmoothingLoss``.

    First shows the smoothed target distribution for a small batch as an image.
    Then plots the loss for a single row as the weight on the true class grows.
    Each plot is displayed with ``plt.show()``.
    """
    smooth_loss = LabelSmoothingLoss(5, 0, 0.4)
    predict = torch.tensor([[0, 0.2, 0.7, 0.1, 0],
                            [0, 0.2, 0.7, 0.1, 0],
                            [0, 0.2, 0.7, 0.1, 0]], dtype=torch.float)
    # The loss value is discarded; this call only populates ``true_dist``.
    _ = smooth_loss(predict.log(),
                    torch.tensor([2, 1, 0], dtype=torch.long))

    # Show the target distributions expected by the system.
    plt.imshow(smooth_loss.true_dist)
    plt.show()

    smooth_loss = LabelSmoothingLoss(5, 0, 0.1)

    def loss_sample(x):
        # Classes 1..4 get weights x : 1 : 1 : 1. Class 0 (padding) gets 0.
        d = x + 3 * 1
        predict2 = torch.tensor([[0, x / d, 1 / d, 1 / d, 1 / d],
                                 ], dtype=torch.float)
        # Clamp before taking the log. log(0) = -inf in the padding column, and
        # 0 * -inf inside KLDivLoss is NaN. The padding column has zero target
        # mass, so the clamped value does not change the loss.
        log_probs = predict2.clamp_min(1e-9).log()
        return smooth_loss(log_probs,
                           torch.tensor([1], dtype=torch.long)).item()

    plt.plot(np.arange(1, 100), [loss_sample(x) for x in range(1, 100)])
    plt.show()
67
if __name__ == "__main__":
    # Run the visual demo only when this file is executed as a script.
    _test_label_smoothing()