# BaseChange/utils/losses.py
# Author: Vedant Jigarbhai Mehta
# Initial scaffolding for military base change detection project
# Commit: b25c087
"""Loss functions for binary change detection.
Provides BCEDiceLoss (default) and FocalLoss, both operating on raw logits.
A factory function ``get_loss`` reads the project config and returns the
selected loss module.
"""
from typing import Any, Dict
import torch
import torch.nn as nn
import torch.nn.functional as F
class BCEDiceLoss(nn.Module):
    """Weighted sum of Binary Cross-Entropy and soft Dice loss.

    Both components expect raw logits — sigmoid is applied internally,
    so callers must **not** pre-apply it.

    Args:
        bce_weight: Scalar weight for the BCE component.
        dice_weight: Scalar weight for the Dice component.
        smooth: Smoothing constant for Dice to avoid division by zero.
    """

    def __init__(
        self,
        bce_weight: float = 0.5,
        dice_weight: float = 0.5,
        smooth: float = 1.0,
    ) -> None:
        super().__init__()
        self.bce_weight = bce_weight
        self.dice_weight = dice_weight
        self.smooth = smooth

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Compute the combined BCE + Dice loss.

        Args:
            logits: Raw model output of shape ``[B, 1, H, W]``.
            targets: Binary ground-truth masks of shape ``[B, 1, H, W]``
                with values in {0, 1}.

        Returns:
            Scalar loss tensor on the same device as the inputs.
        """
        # BCE straight on logits — the fused form is numerically stable.
        bce = F.binary_cross_entropy_with_logits(logits, targets)

        # Soft Dice on sigmoid probabilities, one score per sample over
        # flattened pixels, then averaged across the batch.
        batch = logits.size(0)
        p = torch.sigmoid(logits).reshape(batch, -1)
        t = targets.reshape(batch, -1)
        overlap = (p * t).sum(dim=1)
        totals = p.sum(dim=1) + t.sum(dim=1)
        per_sample_dice = (2.0 * overlap + self.smooth) / (totals + self.smooth)
        dice = 1.0 - per_sample_dice.mean()

        return self.bce_weight * bce + self.dice_weight * dice
class FocalLoss(nn.Module):
    """Focal Loss for class-imbalanced binary change detection.

    Scales per-pixel BCE by ``alpha_t * (1 - p_t)**gamma`` so confidently
    classified (easy) pixels contribute little, focusing training on hard
    examples near the decision boundary. Operates on raw logits.

    Args:
        alpha: Balancing factor for the positive class (1 − alpha for negative).
        gamma: Focusing exponent — higher values down-weight easy examples more.
    """

    def __init__(self, alpha: float = 0.25, gamma: float = 2.0) -> None:
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Compute focal loss.

        Args:
            logits: Raw model output of shape ``[B, 1, H, W]``.
            targets: Binary ground-truth masks of shape ``[B, 1, H, W]``
                with values in {0, 1}.

        Returns:
            Scalar loss tensor on the same device as the inputs.
        """
        # Unreduced per-pixel BCE, computed stably from logits.
        ce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
        prob = torch.sigmoid(logits)
        # pt: probability the model assigns to each pixel's true class.
        pt = prob * targets + (1.0 - prob) * (1.0 - targets)
        # at: class-balance factor — alpha on positives, 1-alpha on negatives.
        at = self.alpha * targets + (1.0 - self.alpha) * (1.0 - targets)
        modulator = (1.0 - pt).pow(self.gamma)
        return (at * modulator * ce).mean()
def get_loss(config: Dict[str, Any]) -> nn.Module:
    """Factory function — instantiate a loss module from the project config.

    Reads ``config["loss"]["name"]`` to select the loss type and extracts
    the matching sub-key for constructor arguments. Missing keys fall back
    to each loss module's defaults.

    Args:
        config: Full project config dict (as loaded from ``config.yaml``).

    Returns:
        An ``nn.Module`` loss function ready for ``loss(logits, targets)``.

    Raises:
        ValueError: If the requested loss name is not recognised.
    """
    loss_cfg = config.get("loss", {})
    name = loss_cfg.get("name", "bce_dice")
    if name == "bce_dice":
        params = loss_cfg.get("bce_dice", {})
        return BCEDiceLoss(
            bce_weight=params.get("bce_weight", 0.5),
            dice_weight=params.get("dice_weight", 0.5),
            # Fix: ``smooth`` was previously never forwarded, so a value
            # set in config was silently ignored. Default matches the
            # BCEDiceLoss constructor default.
            smooth=params.get("smooth", 1.0),
        )
    elif name == "focal":
        params = loss_cfg.get("focal", {})
        return FocalLoss(
            alpha=params.get("alpha", 0.25),
            gamma=params.get("gamma", 2.0),
        )
    else:
        raise ValueError(
            f"Unknown loss '{name}'. Choose from: bce_dice, focal"
        )