# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import functools
from typing import Callable

import torch
from torch import Tensor


def reduce_loss(loss: Tensor, reduction: str) -> Tensor:
    """Reduce the loss tensor based on reduction type.

    Args:
        loss: Loss tensor to reduce.
        reduction: Reduction type ('none', 'mean', or 'sum').

    Returns:
        Reduced loss tensor.
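
    Example (illustrative; input values are assumed):
        >>> reduce_loss(torch.tensor([1.0, 2.0, 3.0]), "sum")
        tensor(6.)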
    """
    match reduction:
        case "none":
            return loss
        case "mean":
            return loss.mean()
        case "sum":
            return loss.sum()
        case _:
            raise ValueError(f"Unknown reduction type: {reduction}")


def weight_reduce_loss(
    loss: Tensor,
    weight: Tensor | None = None,
    reduction: str = "mean",
    avg_factor: float | None = None,
) -> Tensor:
    """Apply weight and reduction to loss tensor.

    Args:
        loss: Loss tensor.
        weight: Optional element-wise weight.
        reduction: Reduction type ('none', 'mean', or 'sum').
        avg_factor: Optional averaging factor.

    Returns:
        Weighted and reduced loss tensor.
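
    Example (illustrative; input values are assumed, result shown approximately):
        >>> loss = torch.tensor([1.0, 2.0, 3.0, 4.0])
        >>> weight = torch.tensor([1.0, 1.0, 0.0, 0.0])
        >>> # only the first two elements contribute; sum 3.0 divided by avg_factor 2.0
        >>> weight_reduce_loss(loss, weight, reduction="mean", avg_factor=2.0)
        tensor(1.5000)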
    """
    # if weight is specified, apply element-wise weight
    if weight is not None:
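        # weight must match loss in rank; along dim 1 it may either match or broadcast (size 1)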
        assert weight.dim() == loss.dim()
        if weight.dim() > 1:
            assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
        loss = loss * weight

    # if avg_factor is not specified, just reduce the loss
    if avg_factor is None:
        loss = reduce_loss(loss, reduction)
    else:
        # if reduction is mean, then average the loss by avg_factor
        if reduction == "mean":
            # Avoid a ZeroDivisionError when avg_factor is 0.0,
            # i.e., when all labels of an image belong to the ignore index.
            eps = torch.finfo(torch.float32).eps
            loss = loss.sum() / (avg_factor + eps)
        # if reduction is 'none', then do nothing, otherwise raise an error
        elif reduction != "none":
            raise ValueError('avg_factor cannot be used with reduction="sum"')
    return loss


def weighted_loss(
    loss_func: Callable[..., Tensor],
) -> Callable[..., Tensor]:
    """Decorator to add weight and reduction support to a loss function.

    Args:
        loss_func: Loss function to wrap.

    Returns:
        Wrapped loss function with weight and reduction support.
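
    Example (illustrative; ``my_l1_loss`` is a hypothetical user-defined loss):
        >>> @weighted_loss
        ... def my_l1_loss(pred: Tensor, target: Tensor) -> Tensor:
        ...     return torch.abs(pred - target)
        >>> pred = torch.tensor([0.0, 2.0])
        >>> target = torch.tensor([1.0, 1.0])
        >>> my_l1_loss(pred, target, reduction="sum")
        tensor(2.)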
    """

    @functools.wraps(loss_func)
    def wrapper(
        pred: Tensor,
        target: Tensor,
        weight: Tensor | None = None,
        reduction: str = "mean",
        avg_factor: float | None = None,
        **kwargs: object,
    ) -> Tensor:
        # get element-wise loss
        loss = loss_func(pred, target, **kwargs)
        loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
        return loss

    return wrapper