File size: 6,946 Bytes
2cc83d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
"""
Loss functions for SCRFD face detection.

SCRFD uses:
1. Generalized Focal Loss (GFL/QFL) for classification — jointly represents
   classification score and localization quality in a single prediction.
2. DIoU Loss for bounding box regression — better gradient signal for
   non-overlapping boxes and directly minimizes distance between box centers.

References:
- GFL: "Generalized Focal Loss" (Li et al., 2020)
- DIoU: "Distance-IoU Loss" (Zheng et al., 2020)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional


class GFocalLoss(nn.Module):
    """
    Quality Focal Loss (QFL) — Generalized Focal Loss for classification.

    Instead of binary {0,1} targets, QFL uses continuous quality scores
    [0, 1] where the target is the IoU between predicted and GT boxes.
    This jointly trains classification confidence and localization quality.

    Loss = -|y - σ|^β * ((1-y)log(1-σ) + y*log(σ))

    where y ∈ [0,1] is quality target, σ is predicted score, β is focusing param.

    Reduction:
        'mean': sum of the (optionally weighted) loss divided by
                ``weight.sum()`` when weights are given, otherwise by the
                number of positive targets (``target > 0``). The divisor is
                clamped to at least 1 so empty / all-negative batches give 0.
        'sum':  plain sum of the (optionally weighted) loss.
        other:  the unreduced element-wise loss.
    """

    def __init__(self, beta: float = 2.0, reduction: str = 'mean'):
        super().__init__()
        self.beta = beta
        self.reduction = reduction

    def forward(self, pred: torch.Tensor, target: torch.Tensor,
                weight: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            pred: [N] predicted scores (logits)
            target: [N] quality targets in [0, 1]
            weight: [N] optional sample weights

        Returns:
            A scalar tensor for 'mean' / 'sum', otherwise the element-wise
            loss with the same shape as ``pred``.
        """
        pred_sigmoid = pred.sigmoid()
        # Focusing term: samples whose score already matches the quality
        # target contribute little; large mismatches dominate the loss.
        scale_factor = (pred_sigmoid - target).abs().pow(self.beta)

        # Binary cross-entropy with continuous (soft) targets
        bce = F.binary_cross_entropy_with_logits(pred, target, reduction='none')
        loss = scale_factor * bce

        if weight is not None:
            loss = loss * weight

        if self.reduction == 'mean':
            normalizer = weight.sum() if weight is not None else target.gt(0).sum()
            # Tensor clamp rather than Python max(): max() evaluates bool() on
            # a tensor comparison, forcing a device->host sync on GPU.
            return loss.sum() / normalizer.clamp(min=1)
        elif self.reduction == 'sum':
            return loss.sum()
        return loss


class FocalLoss(nn.Module):
    """
    Standard Focal Loss for binary classification.

    FL(p) = -α * (1-p)^γ * log(p)    for positive
          = -(1-α) * p^γ * log(1-p)  for negative

    Used as fallback when QFL is not appropriate.

    Reduction:
        'mean': average over all elements (returns 0 for empty input
                instead of NaN).
        'sum':  plain sum.
        other:  the unreduced element-wise loss.
    """

    def __init__(self, alpha: float = 0.25, gamma: float = 2.0,
                 reduction: str = 'mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            pred: logits, any shape
            target: binary (or soft) labels with the same shape as ``pred``;
                cast to float before use

        Returns:
            A scalar tensor for 'mean' / 'sum', otherwise the element-wise loss.
        """
        pred_sigmoid = pred.sigmoid()
        target = target.float()

        # Focal weights: pt is the probability assigned to the true class,
        # so (1 - pt)^γ down-weights easy, well-classified examples.
        pt = pred_sigmoid * target + (1 - pred_sigmoid) * (1 - target)
        focal_weight = (1 - pt).pow(self.gamma)
        alpha_weight = self.alpha * target + (1 - self.alpha) * (1 - target)

        bce = F.binary_cross_entropy_with_logits(pred, target, reduction='none')
        loss = alpha_weight * focal_weight * bce

        if self.reduction == 'mean':
            # mean() of an empty tensor is NaN; return the (zero) sum instead
            # so an empty batch does not poison training, matching the other
            # losses in this module.
            if loss.numel() == 0:
                return loss.sum()
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        return loss


class DIoULoss(nn.Module):
    """
    Distance-IoU Loss for bounding box regression.

    DIoU = IoU - (ρ²(b, b_gt) / c²)

    where ρ is Euclidean distance between box centers and c is diagonal
    length of the smallest enclosing box. This provides better gradients
    for non-overlapping boxes (common with tiny faces) and directly
    optimizes center alignment.

    Loss = 1 - DIoU ∈ [0, 2]

    Reduction:
        'mean': sum of the (optionally weighted) loss divided by
                ``weight.sum()`` when weights are given, otherwise by the
                number of boxes; the divisor is clamped to at least 1.
        'sum':  plain sum.
        other:  the unreduced per-box loss.
    """

    def __init__(self, reduction: str = 'mean'):
        super().__init__()
        self.reduction = reduction

    def forward(self, pred: torch.Tensor, target: torch.Tensor,
                weight: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            pred: [..., 4] predicted boxes (x1, y1, x2, y2); the common case
                is [N, 4], but any leading batch dimensions are accepted
            target: [..., 4] target boxes (x1, y1, x2, y2), same shape as pred
            weight: optional per-box weights, shape [...] (e.g. [N])

        Returns:
            A scalar tensor for 'mean' / 'sum', otherwise the per-box loss
            of shape [...].
        """
        # Use `...` indexing so extra leading dims (e.g. [B, N, 4]) work too.
        px1, py1, px2, py2 = pred[..., 0], pred[..., 1], pred[..., 2], pred[..., 3]
        tx1, ty1, tx2, ty2 = target[..., 0], target[..., 1], target[..., 2], target[..., 3]

        # Intersection
        inter_x1 = torch.max(px1, tx1)
        inter_y1 = torch.max(py1, ty1)
        inter_x2 = torch.min(px2, tx2)
        inter_y2 = torch.min(py2, ty2)
        inter = (inter_x2 - inter_x1).clamp(min=0) * (inter_y2 - inter_y1).clamp(min=0)

        # Union
        area_pred = (px2 - px1) * (py2 - py1)
        area_target = (tx2 - tx1) * (ty2 - ty1)
        union = area_pred + area_target - inter

        iou = inter / (union + 1e-6)

        # Center distance
        pred_cx = (px1 + px2) / 2
        pred_cy = (py1 + py2) / 2
        target_cx = (tx1 + tx2) / 2
        target_cy = (ty1 + ty2) / 2
        center_dist_sq = (pred_cx - target_cx).pow(2) + (pred_cy - target_cy).pow(2)

        # Smallest enclosing box diagonal
        enclose_x1 = torch.min(px1, tx1)
        enclose_y1 = torch.min(py1, ty1)
        enclose_x2 = torch.max(px2, tx2)
        enclose_y2 = torch.max(py2, ty2)
        enclose_diag_sq = (enclose_x2 - enclose_x1).pow(2) + (enclose_y2 - enclose_y1).pow(2)

        diou = iou - center_dist_sq / (enclose_diag_sq + 1e-6)
        loss = 1 - diou

        if weight is not None:
            loss = loss * weight

        if self.reduction == 'mean':
            if weight is not None:
                # Tensor clamp avoids the GPU->host sync that Python max()
                # on a tensor would force.
                return loss.sum() / weight.sum().clamp(min=1)
            return loss.sum() / max(loss.numel(), 1)
        elif self.reduction == 'sum':
            return loss.sum()
        return loss


class LandmarkLoss(nn.Module):
    """
    Smooth L1 loss for facial landmark regression (optional multi-task head).

    Used when landmark annotations are available (e.g., RetinaFace 5-point
    landmarks on WIDER FACE). Auxiliary landmark supervision improves
    detection AP by ~1% (RetinaFace paper finding).
    """

    def __init__(self, beta: float = 1.0, reduction: str = 'mean'):
        super().__init__()
        self.beta = beta
        self.reduction = reduction

    def forward(self, pred: torch.Tensor, target: torch.Tensor,
                weight: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            pred: [N, 10] predicted landmarks (5 points × 2 coords)
            target: [N, 10] target landmarks
            weight: [N] optional mask for visible landmarks
        """
        loss = F.smooth_l1_loss(pred, target, beta=self.beta, reduction='none')
        loss = loss.sum(dim=1)  # Sum over 10 coords per face

        if weight is not None:
            loss = loss * weight

        if self.reduction == 'mean':
            return loss.sum() / max(weight.sum() if weight is not None else loss.shape[0], 1)
        return loss.sum()