""" Loss functions for SCRFD face detection. SCRFD uses: 1. Generalized Focal Loss (GFL/QFL) for classification — jointly represents classification score and localization quality in a single prediction. 2. DIoU Loss for bounding box regression — better gradient signal for non-overlapping boxes and directly minimizes distance between box centers. References: - GFL: "Generalized Focal Loss" (Li et al., 2020) - DIoU: "Distance-IoU Loss" (Zheng et al., 2020) """ import torch import torch.nn as nn import torch.nn.functional as F from typing import Optional class GFocalLoss(nn.Module): """ Quality Focal Loss (QFL) — Generalized Focal Loss for classification. Instead of binary {0,1} targets, QFL uses continuous quality scores [0, 1] where the target is the IoU between predicted and GT boxes. This jointly trains classification confidence and localization quality. Loss = -|y - σ|^β * ((1-y)log(1-σ) + y*log(σ)) where y ∈ [0,1] is quality target, σ is predicted score, β is focusing param. """ def __init__(self, beta: float = 2.0, reduction: str = 'mean'): super().__init__() self.beta = beta self.reduction = reduction def forward(self, pred: torch.Tensor, target: torch.Tensor, weight: Optional[torch.Tensor] = None) -> torch.Tensor: """ Args: pred: [N] predicted scores (logits) target: [N] quality targets in [0, 1] weight: [N] optional sample weights """ pred_sigmoid = pred.sigmoid() scale_factor = (pred_sigmoid - target).abs().pow(self.beta) # Binary cross-entropy with continuous targets bce = F.binary_cross_entropy_with_logits(pred, target, reduction='none') loss = scale_factor * bce if weight is not None: loss = loss * weight if self.reduction == 'mean': return loss.sum() / max(weight.sum() if weight is not None else target.gt(0).sum(), 1) elif self.reduction == 'sum': return loss.sum() return loss class FocalLoss(nn.Module): """ Standard Focal Loss for binary classification. FL(p) = -α * (1-p)^γ * log(p) for positive = -(1-α) * p^γ * log(1-p) for negative Used as fallback when QFL is not appropriate. """ def __init__(self, alpha: float = 0.25, gamma: float = 2.0, reduction: str = 'mean'): super().__init__() self.alpha = alpha self.gamma = gamma self.reduction = reduction def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: pred_sigmoid = pred.sigmoid() target = target.float() # Focal weights pt = pred_sigmoid * target + (1 - pred_sigmoid) * (1 - target) focal_weight = (1 - pt).pow(self.gamma) alpha_weight = self.alpha * target + (1 - self.alpha) * (1 - target) bce = F.binary_cross_entropy_with_logits(pred, target, reduction='none') loss = alpha_weight * focal_weight * bce if self.reduction == 'mean': return loss.mean() elif self.reduction == 'sum': return loss.sum() return loss class DIoULoss(nn.Module): """ Distance-IoU Loss for bounding box regression. DIoU = IoU - (ρ²(b, b_gt) / c²) where ρ is Euclidean distance between box centers and c is diagonal length of the smallest enclosing box. This provides better gradients for non-overlapping boxes (common with tiny faces) and directly optimizes center alignment. Loss = 1 - DIoU ∈ [0, 2] """ def __init__(self, reduction: str = 'mean'): super().__init__() self.reduction = reduction def forward(self, pred: torch.Tensor, target: torch.Tensor, weight: Optional[torch.Tensor] = None) -> torch.Tensor: """ Args: pred: [N, 4] predicted boxes (x1, y1, x2, y2) target: [N, 4] target boxes (x1, y1, x2, y2) weight: [N] optional per-box weights """ # Intersection inter_x1 = torch.max(pred[:, 0], target[:, 0]) inter_y1 = torch.max(pred[:, 1], target[:, 1]) inter_x2 = torch.min(pred[:, 2], target[:, 2]) inter_y2 = torch.min(pred[:, 3], target[:, 3]) inter = (inter_x2 - inter_x1).clamp(min=0) * (inter_y2 - inter_y1).clamp(min=0) # Union area_pred = (pred[:, 2] - pred[:, 0]) * (pred[:, 3] - pred[:, 1]) area_target = (target[:, 2] - target[:, 0]) * (target[:, 3] - target[:, 1]) union = area_pred + area_target - inter iou = inter / (union + 1e-6) # Center distance pred_cx = (pred[:, 0] + pred[:, 2]) / 2 pred_cy = (pred[:, 1] + pred[:, 3]) / 2 target_cx = (target[:, 0] + target[:, 2]) / 2 target_cy = (target[:, 1] + target[:, 3]) / 2 center_dist_sq = (pred_cx - target_cx).pow(2) + (pred_cy - target_cy).pow(2) # Smallest enclosing box diagonal enclose_x1 = torch.min(pred[:, 0], target[:, 0]) enclose_y1 = torch.min(pred[:, 1], target[:, 1]) enclose_x2 = torch.max(pred[:, 2], target[:, 2]) enclose_y2 = torch.max(pred[:, 3], target[:, 3]) enclose_diag_sq = (enclose_x2 - enclose_x1).pow(2) + (enclose_y2 - enclose_y1).pow(2) diou = iou - center_dist_sq / (enclose_diag_sq + 1e-6) loss = 1 - diou if weight is not None: loss = loss * weight if self.reduction == 'mean': return loss.sum() / max(weight.sum() if weight is not None else loss.shape[0], 1) elif self.reduction == 'sum': return loss.sum() return loss class LandmarkLoss(nn.Module): """ Smooth L1 loss for facial landmark regression (optional multi-task head). Used when landmark annotations are available (e.g., RetinaFace 5-point landmarks on WIDER FACE). Auxiliary landmark supervision improves detection AP by ~1% (RetinaFace paper finding). """ def __init__(self, beta: float = 1.0, reduction: str = 'mean'): super().__init__() self.beta = beta self.reduction = reduction def forward(self, pred: torch.Tensor, target: torch.Tensor, weight: Optional[torch.Tensor] = None) -> torch.Tensor: """ Args: pred: [N, 10] predicted landmarks (5 points × 2 coords) target: [N, 10] target landmarks weight: [N] optional mask for visible landmarks """ loss = F.smooth_l1_loss(pred, target, beta=self.beta, reduction='none') loss = loss.sum(dim=1) # Sum over 10 coords per face if weight is not None: loss = loss * weight if self.reduction == 'mean': return loss.sum() / max(weight.sum() if weight is not None else loss.shape[0], 1) return loss.sum()