"""
Loss functions for SCRFD face detection.
SCRFD uses:
1. Generalized Focal Loss (GFL/QFL) for classification — jointly represents
classification score and localization quality in a single prediction.
2. DIoU Loss for bounding box regression — better gradient signal for
non-overlapping boxes and directly minimizes distance between box centers.
References:
- GFL: "Generalized Focal Loss" (Li et al., 2020)
- DIoU: "Distance-IoU Loss" (Zheng et al., 2020)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
class GFocalLoss(nn.Module):
    """
    Quality Focal Loss (QFL) from Generalized Focal Loss.

    Replaces hard {0, 1} classification targets with continuous quality
    scores in [0, 1] (typically the IoU between predicted and GT boxes),
    so a single head jointly learns classification confidence and
    localization quality.

        Loss = -|y - σ|^β * ((1 - y) * log(1 - σ) + y * log(σ))

    where y is the quality target, σ = sigmoid(logit), and β is the
    focusing parameter.
    """

    def __init__(self, beta: float = 2.0, reduction: str = 'mean'):
        super().__init__()
        self.beta = beta
        self.reduction = reduction

    def forward(self, pred: torch.Tensor, target: torch.Tensor,
                weight: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            pred: [N] predicted scores (logits)
            target: [N] quality targets in [0, 1]
            weight: [N] optional sample weights

        Returns:
            Scalar loss for 'mean'/'sum', per-sample [N] losses otherwise.
        """
        prob = torch.sigmoid(pred)
        # Modulating factor: down-weights samples whose predicted score
        # already matches the quality target.
        modulator = torch.abs(prob - target) ** self.beta
        per_sample = modulator * F.binary_cross_entropy_with_logits(
            pred, target, reduction='none')
        if weight is not None:
            per_sample = per_sample * weight
        if self.reduction == 'sum':
            return per_sample.sum()
        if self.reduction == 'mean':
            # Average over total weight (or positive count), floored at 1
            # so an all-negative batch cannot divide by zero.
            if weight is not None:
                denom = max(weight.sum(), 1)
            else:
                denom = max(target.gt(0).sum(), 1)
            return per_sample.sum() / denom
        return per_sample
class FocalLoss(nn.Module):
    """
    Standard binary Focal Loss.

        FL(p) = -α * (1-p)^γ * log(p)       for positives
              = -(1-α) * p^γ * log(1-p)     for negatives

    Fallback classification loss for when QFL is not appropriate.
    """

    def __init__(self, alpha: float = 0.25, gamma: float = 2.0,
                 reduction: str = 'mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        labels = target.float()
        prob = torch.sigmoid(pred)
        # Probability the model assigns to the true class.
        p_true = prob * labels + (1.0 - prob) * (1.0 - labels)
        # Down-weight easy examples; balance positives vs negatives.
        modulator = torch.pow(1.0 - p_true, self.gamma)
        alpha_t = self.alpha * labels + (1.0 - self.alpha) * (1.0 - labels)
        raw = F.binary_cross_entropy_with_logits(pred, labels, reduction='none')
        per_sample = alpha_t * modulator * raw
        if self.reduction == 'sum':
            return per_sample.sum()
        if self.reduction == 'mean':
            return per_sample.mean()
        return per_sample
class DIoULoss(nn.Module):
    """
    Distance-IoU Loss for bounding box regression.

        DIoU = IoU - ρ²(b, b_gt) / c²

    where ρ is the Euclidean distance between box centers and c the
    diagonal of the smallest enclosing box. Unlike plain IoU loss this
    still yields useful gradients for non-overlapping boxes (common with
    tiny faces) and directly pulls box centers together.

    Loss = 1 - DIoU, in [0, 2].
    """

    def __init__(self, reduction: str = 'mean'):
        super().__init__()
        self.reduction = reduction

    def forward(self, pred: torch.Tensor, target: torch.Tensor,
                weight: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            pred: [N, 4] predicted boxes (x1, y1, x2, y2)
            target: [N, 4] target boxes (x1, y1, x2, y2)
            weight: [N] optional per-box weights
        """
        px1, py1, px2, py2 = pred.unbind(dim=1)
        tx1, ty1, tx2, ty2 = target.unbind(dim=1)

        # Overlap area (clamped so disjoint boxes contribute zero).
        iw = (torch.minimum(px2, tx2) - torch.maximum(px1, tx1)).clamp(min=0)
        ih = (torch.minimum(py2, ty2) - torch.maximum(py1, ty1)).clamp(min=0)
        overlap = iw * ih

        union = ((px2 - px1) * (py2 - py1)
                 + (tx2 - tx1) * (ty2 - ty1)
                 - overlap)
        iou = overlap / (union + 1e-6)

        # Squared distance between box centers.
        dcx = (px1 + px2) / 2 - (tx1 + tx2) / 2
        dcy = (py1 + py2) / 2 - (ty1 + ty2) / 2
        rho_sq = dcx * dcx + dcy * dcy

        # Squared diagonal of the smallest enclosing box.
        span_w = torch.maximum(px2, tx2) - torch.minimum(px1, tx1)
        span_h = torch.maximum(py2, ty2) - torch.minimum(py1, ty1)
        diag_sq = span_w * span_w + span_h * span_h

        per_box = 1 - (iou - rho_sq / (diag_sq + 1e-6))
        if weight is not None:
            per_box = per_box * weight
        if self.reduction == 'sum':
            return per_box.sum()
        if self.reduction == 'mean':
            denom = weight.sum() if weight is not None else per_box.shape[0]
            return per_box.sum() / max(denom, 1)
        return per_box
class LandmarkLoss(nn.Module):
    """
    Smooth L1 loss for facial landmark regression (optional multi-task head).

    Used when landmark annotations are available (e.g., RetinaFace 5-point
    landmarks on WIDER FACE). Auxiliary landmark supervision improves
    detection AP by ~1% (RetinaFace paper finding).
    """

    def __init__(self, beta: float = 1.0, reduction: str = 'mean'):
        super().__init__()
        self.beta = beta  # Smooth L1 quadratic-to-linear transition point
        self.reduction = reduction

    def forward(self, pred: torch.Tensor, target: torch.Tensor,
                weight: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            pred: [N, 10] predicted landmarks (5 points × 2 coords)
            target: [N, 10] target landmarks
            weight: [N] optional mask for visible landmarks

        Returns:
            Scalar loss for 'mean'/'sum', or [N] per-face losses for 'none'.
        """
        loss = F.smooth_l1_loss(pred, target, beta=self.beta, reduction='none')
        loss = loss.sum(dim=1)  # Sum over 10 coords per face
        if weight is not None:
            loss = loss * weight
        if self.reduction == 'mean':
            # Floor the divisor at 1 so an empty/fully-masked batch is safe.
            return loss.sum() / max(weight.sum() if weight is not None else loss.shape[0], 1)
        elif self.reduction == 'sum':
            return loss.sum()
        # Bug fix: 'none' previously fell through to loss.sum(); return
        # per-face losses instead, consistent with the other loss classes.
        return loss