Spaces:
Runtime error
Runtime error
| """Loss functions for binary change detection. | |
| Provides BCEDiceLoss (default) and FocalLoss, both operating on raw logits. | |
| A factory function ``get_loss`` reads the project config and returns the | |
| selected loss module. | |
| """ | |
| from typing import Any, Dict | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| class BCEDiceLoss(nn.Module): | |
| """Combined Binary Cross-Entropy and Dice Loss. | |
| Both components operate on raw logits — sigmoid is applied internally so | |
| the caller should **not** pre-apply it. | |
| Args: | |
| bce_weight: Scalar weight for the BCE component. | |
| dice_weight: Scalar weight for the Dice component. | |
| smooth: Smoothing constant for Dice to avoid division by zero. | |
| """ | |
| def __init__( | |
| self, | |
| bce_weight: float = 0.5, | |
| dice_weight: float = 0.5, | |
| smooth: float = 1.0, | |
| ) -> None: | |
| super().__init__() | |
| self.bce_weight = bce_weight | |
| self.dice_weight = dice_weight | |
| self.smooth = smooth | |
| def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: | |
| """Compute the combined BCE + Dice loss. | |
| Args: | |
| logits: Raw model output of shape ``[B, 1, H, W]``. | |
| targets: Binary ground-truth masks of shape ``[B, 1, H, W]`` | |
| with values in {0, 1}. | |
| Returns: | |
| Scalar loss tensor on the same device as the inputs. | |
| """ | |
| # --- BCE component (numerically stable, operates on logits) --- | |
| bce_loss = F.binary_cross_entropy_with_logits(logits, targets) | |
| # --- Dice component --- | |
| probs = torch.sigmoid(logits) | |
| # Flatten spatial dims per sample for stable dice computation | |
| probs_flat = probs.view(probs.size(0), -1) | |
| targets_flat = targets.view(targets.size(0), -1) | |
| intersection = (probs_flat * targets_flat).sum(dim=1) | |
| union = probs_flat.sum(dim=1) + targets_flat.sum(dim=1) | |
| dice_score = (2.0 * intersection + self.smooth) / (union + self.smooth) | |
| dice_loss = 1.0 - dice_score.mean() | |
| return self.bce_weight * bce_loss + self.dice_weight * dice_loss | |
| class FocalLoss(nn.Module): | |
| """Focal Loss for addressing class imbalance in change detection. | |
| Down-weights well-classified (easy) pixels so the model focuses on hard | |
| examples near the decision boundary. Operates on raw logits. | |
| Args: | |
| alpha: Balancing factor for the positive class (1 − alpha for negative). | |
| gamma: Focusing exponent — higher values down-weight easy examples more. | |
| """ | |
| def __init__(self, alpha: float = 0.25, gamma: float = 2.0) -> None: | |
| super().__init__() | |
| self.alpha = alpha | |
| self.gamma = gamma | |
| def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: | |
| """Compute focal loss. | |
| Args: | |
| logits: Raw model output of shape ``[B, 1, H, W]``. | |
| targets: Binary ground-truth masks of shape ``[B, 1, H, W]`` | |
| with values in {0, 1}. | |
| Returns: | |
| Scalar loss tensor on the same device as the inputs. | |
| """ | |
| # Per-pixel BCE (unreduced) | |
| bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none") | |
| probs = torch.sigmoid(logits) | |
| # p_t = probability of the true class | |
| p_t = probs * targets + (1.0 - probs) * (1.0 - targets) | |
| # alpha_t = alpha for positives, (1-alpha) for negatives | |
| alpha_t = self.alpha * targets + (1.0 - self.alpha) * (1.0 - targets) | |
| focal_weight = alpha_t * (1.0 - p_t) ** self.gamma | |
| return (focal_weight * bce).mean() | |
| def get_loss(config: Dict[str, Any]) -> nn.Module: | |
| """Factory function — instantiate a loss module from the project config. | |
| Reads ``config["loss"]["name"]`` to select the loss type and extracts | |
| the matching sub-key for constructor arguments. | |
| Args: | |
| config: Full project config dict (as loaded from ``config.yaml``). | |
| Returns: | |
| An ``nn.Module`` loss function ready for ``loss(logits, targets)``. | |
| Raises: | |
| ValueError: If the requested loss name is not recognised. | |
| """ | |
| loss_cfg = config.get("loss", {}) | |
| name = loss_cfg.get("name", "bce_dice") | |
| if name == "bce_dice": | |
| params = loss_cfg.get("bce_dice", {}) | |
| return BCEDiceLoss( | |
| bce_weight=params.get("bce_weight", 0.5), | |
| dice_weight=params.get("dice_weight", 0.5), | |
| ) | |
| elif name == "focal": | |
| params = loss_cfg.get("focal", {}) | |
| return FocalLoss( | |
| alpha=params.get("alpha", 0.25), | |
| gamma=params.get("gamma", 2.0), | |
| ) | |
| else: | |
| raise ValueError( | |
| f"Unknown loss '{name}'. Choose from: bce_dice, focal" | |
| ) | |