Spaces:
Running
Running
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. | |
| # pyre-strict | |
| import functools | |
| from typing import Callable | |
| import torch | |
| from torch import Tensor | |
def reduce_loss(loss: Tensor, reduction: str) -> Tensor:
    """Collapse an element-wise loss tensor according to ``reduction``.

    Args:
        loss: Element-wise loss values.
        reduction: ``"none"`` returns ``loss`` unchanged, ``"mean"``
            averages all elements, ``"sum"`` adds all elements.

    Returns:
        The input tensor itself for ``"none"``; otherwise the scalar
        result of ``loss.mean()`` / ``loss.sum()``.

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """
    if reduction == "none":
        return loss
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    # Anything else is a caller error.
    raise ValueError(f"Unknown reduction type: {reduction}")
| def weight_reduce_loss( | |
| loss: Tensor, | |
| weight: Tensor | None = None, | |
| reduction: str = "mean", | |
| avg_factor: float | None = None, | |
| ) -> Tensor: | |
| """Apply weight and reduction to loss tensor. | |
| Args: | |
| loss: Loss tensor. | |
| weight: Optional element-wise weight. | |
| reduction: Reduction type ('none', 'mean', or 'sum'). | |
| avg_factor: Optional averaging factor. | |
| Returns: | |
| Weighted and reduced loss tensor. | |
| """ | |
| # if weight is specified, apply element-wise weight | |
| if weight is not None: | |
| assert weight.dim() == loss.dim() | |
| if weight.dim() > 1: | |
| assert weight.size(1) == 1 or weight.size(1) == loss.size(1) | |
| loss = loss * weight | |
| # if avg_factor is not specified, just reduce the loss | |
| if avg_factor is None: | |
| loss = reduce_loss(loss, reduction) | |
| else: | |
| # if reduction is mean, then average the loss by avg_factor | |
| if reduction == "mean": | |
| # Avoid causing ZeroDivisionError when avg_factor is 0.0, | |
| # i.e., all labels of an image belong to ignore index. | |
| eps = torch.finfo(torch.float32).eps | |
| loss = loss.sum() / (avg_factor + eps) | |
| # if reduction is 'none', then do nothing, otherwise raise an error | |
| elif reduction != "none": | |
| raise ValueError('avg_factor can not be used with reduction="sum"') | |
| return loss | |
| def weighted_loss( | |
| loss_func: Callable[..., Tensor], | |
| ) -> Callable[..., Tensor]: | |
| """Decorator to add weight and reduction support to a loss function. | |
| Args: | |
| loss_func: Loss function to wrap. | |
| Returns: | |
| Wrapped loss function with weight and reduction support. | |
| """ | |
| def wrapper( | |
| pred: Tensor, | |
| target: Tensor, | |
| weight: Tensor | None = None, | |
| reduction: str = "mean", | |
| avg_factor: float | None = None, | |
| **kwargs: object, | |
| ) -> Tensor: | |
| # get element-wise loss | |
| loss = loss_func(pred, target, **kwargs) | |
| loss = weight_reduce_loss(loss, weight, reduction, avg_factor) | |
| return loss | |
| return wrapper | |