| import torch |
| import torch.nn.functional as F |
|
|
| __all__ = ["label_smooth", "CrossEntropyWithSoftTarget", "CrossEntropyWithLabelSmooth"] |
|
|
|
|
def label_smooth(
    target: torch.Tensor, n_classes: int, smooth_factor: float = 0.1
) -> torch.Tensor:
    """Convert hard class indices into label-smoothed one-hot targets.

    Each row of the result places ``1 - smooth_factor`` on the true class and
    spreads ``smooth_factor`` uniformly over all ``n_classes`` classes, so
    every row still sums to 1.

    Args:
        target: 1-D tensor of class indices, shape ``(batch,)``.
        n_classes: total number of classes (width of the output).
        smooth_factor: probability mass redistributed uniformly over classes.

    Returns:
        Float tensor of shape ``(batch, n_classes)`` on ``target.device``.
    """
    batch_size = target.shape[0]
    # scatter_ requires an int64 index tensor; .long() is a no-op for int64
    # targets and makes other integer dtypes work too.
    index = target.long().unsqueeze(1)
    soft_target = torch.zeros((batch_size, n_classes), device=target.device)
    soft_target.scatter_(1, index, 1)
    # Shrink the one-hot peak and add the uniform smoothing mass.
    return soft_target * (1 - smooth_factor) + smooth_factor / n_classes
|
|
|
|
class CrossEntropyWithSoftTarget:
    """Cross-entropy loss against a soft (probability-distribution) target.

    Usable either via the static ``get_loss`` or by calling an instance.
    """

    @staticmethod
    def get_loss(pred: torch.Tensor, soft_target: torch.Tensor) -> torch.Tensor:
        """Return mean over the batch of ``-sum(soft_target * log_softmax(pred))``.

        Args:
            pred: raw logits, shape ``(batch, n_classes)``.
            soft_target: target distributions, same shape as ``pred``.
        """
        # Dropped the private `_stacklevel` argument: it is an internal detail
        # of torch.nn.functional (used only for warning messages) and not part
        # of the stable public API.
        log_probs = F.log_softmax(pred, dim=-1)
        return torch.mean(torch.sum(-soft_target * log_probs, dim=1))

    def __call__(self, pred: torch.Tensor, soft_target: torch.Tensor) -> torch.Tensor:
        """Allow instances to be used directly as a loss function."""
        return self.get_loss(pred, soft_target)
|
|
|
|
class CrossEntropyWithLabelSmooth:
    """Cross-entropy over integer labels, smoothed before the loss is taken.

    Combines ``label_smooth`` with ``CrossEntropyWithSoftTarget.get_loss``.
    """

    def __init__(self, smooth_ratio=0.1):
        super().__init__()
        # Probability mass that label_smooth() spreads uniformly over classes.
        self.smooth_ratio = smooth_ratio

    def __call__(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Smooth ``target`` into a distribution, then apply the soft-target loss."""
        n_classes = pred.shape[1]
        smoothed = label_smooth(target, n_classes, self.smooth_ratio)
        return CrossEntropyWithSoftTarget.get_loss(pred, smoothed)
|
|