File size: 15,751 Bytes
6953619
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
"""
SCRFD Full Detector — Backbone + Neck + Head + Loss + Post-processing.

This is the main model class that ties together all components and provides:
1. Training forward: returns losses dict
2. Inference forward: returns detections (boxes, scores, landmarks)
3. ONNX-exportable inference path

Model configurations (WiderFace Hard val / GFLOPs / FPS @VGA on V100):
- SCRFD-34GF:  85.2% / 34 GF / ~80 FPS   (flagship quality)
- SCRFD-10GF:  83.1% / 10 GF / ~140 FPS  (balanced)
- SCRFD-2.5GF: 77.9% / 2.5 GF / ~400 FPS (real-time)
- SCRFD-0.5GF: 68.5% / 0.5 GF / ~1000 FPS (mobile/edge)
"""

import math
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from .backbone import SCRFDBackbone, build_backbone
from .neck import PAFPN, build_neck
from .head import SCRFDHead, build_head
from .anchor import AnchorGenerator, ATSSAssigner
from .losses import GFocalLoss, DIoULoss, FocalLoss, LandmarkLoss


class SCRFD(nn.Module):
    """
    Sample and Computation Redistribution Face Detector.

    Complete pipeline: backbone → PAFPN → shared head → anchors → losses/NMS

    Calling the module with ``targets`` returns a dict of losses; calling it
    without ``targets`` returns one detection dict per image.
    """

    # Channels per anchor for 5-point landmarks (five x/y pairs).
    _LMK_DIM = 10

    def __init__(self,
                 backbone: SCRFDBackbone,
                 neck: PAFPN,
                 head: SCRFDHead,
                 anchor_generator: AnchorGenerator,
                 assigner: ATSSAssigner,
                 strides: Optional[List[int]] = None,
                 score_threshold: float = 0.3,
                 nms_threshold: float = 0.4,
                 max_detections: int = 750,
                 use_gfl: bool = True,
                 cls_weight: float = 1.0,
                 reg_weight: float = 2.0,
                 lmk_weight: float = 0.1):
        """
        Args:
            backbone: Feature extractor applied to the input images.
            neck: Module applied to the backbone output.
            head: Prediction head. Must expose ``use_landmarks`` and return a
                dict with per-level 'cls_scores', 'bbox_preds' and, optionally,
                'lmk_preds'.
            anchor_generator: Object providing ``grid_anchors(feat_sizes, device)``.
            assigner: Object providing
                ``assign(anchors, gt_boxes, gt_labels, num_per_level)``.
            strides: One entry per pyramid level; within this class only its
                length is used (number of levels read from the head output).
                Defaults to ``[8, 16, 32]``.
            score_threshold: Minimum sigmoid score kept before NMS.
            nms_threshold: IoU threshold passed to NMS.
            max_detections: Maximum detections kept per image after NMS.
            use_gfl: Use ``GFocalLoss`` with IoU soft targets; otherwise
                ``FocalLoss`` with binary targets.
            cls_weight: Multiplier for the classification loss.
            reg_weight: Multiplier for the box regression loss.
            lmk_weight: Multiplier for the landmark loss.
        """
        super().__init__()
        self.backbone = backbone
        self.neck = neck
        self.head = head
        self.anchor_gen = anchor_generator
        self.assigner = assigner
        # Copy so the instance never aliases a caller-owned (or default) list.
        self.strides = [8, 16, 32] if strides is None else list(strides)
        self.score_threshold = score_threshold
        self.nms_threshold = nms_threshold
        self.max_detections = max_detections
        self.use_gfl = use_gfl

        # Loss functions
        self.cls_loss_fn = GFocalLoss(beta=2.0) if use_gfl else FocalLoss()
        self.reg_loss_fn = DIoULoss()
        self.lmk_loss_fn = LandmarkLoss() if head.use_landmarks else None

        # Loss weights
        self.cls_weight = cls_weight
        self.reg_weight = reg_weight
        self.lmk_weight = lmk_weight

    def forward(self, images: torch.Tensor,
                targets: Optional[List[Dict]] = None) -> Union[Dict, List[Dict]]:
        """
        Args:
            images: [B, 3, H, W] batch of images (normalized)
            targets: List of dicts with keys:
                'boxes': [M, 4] face boxes (x1, y1, x2, y2)
                'labels': [M] labels (all 1)
                'landmarks': [M, 10] optional landmarks
                When None, runs inference.

        Returns:
            Training: dict of losses
            Inference: list of dicts with 'boxes', 'scores', 'landmarks'
        """
        # Feature extraction
        features = self.neck(self.backbone(images))
        head_out = self.head(features)

        # Generate anchors
        feat_sizes = [(f.shape[2], f.shape[3]) for f in features]
        anchors_per_level = self.anchor_gen.grid_anchors(feat_sizes, images.device)

        # The presence of targets (not self.training) selects the code path.
        if targets is None:
            return self._inference(head_out, anchors_per_level, images.shape)

        num_anchors_per_level = [a.shape[0] for a in anchors_per_level]
        return self._compute_loss(head_out, anchors_per_level,
                                  num_anchors_per_level, targets, images.shape)

    def _compute_loss(self, head_out: Dict, anchors_per_level: List[torch.Tensor],
                      num_per_level: List[int], targets: List[Dict],
                      img_shape: Tuple) -> Dict:
        """
        Compute training losses.

        Args:
            head_out: Head output dict ('cls_scores', 'bbox_preds', optional
                'lmk_preds'), each a list of [B, C, H, W] tensors per level.
            anchors_per_level: Per-level [N_i, 4] anchor tensors.
            num_per_level: Anchor count of every level, forwarded to the assigner.
            targets: One dict per image ('boxes', optional 'labels'/'landmarks').
            img_shape: Shape of the input batch (currently unused).

        Returns:
            Dict with 'cls_loss', 'reg_loss', optional 'lmk_loss', their sum
            'total_loss', and 'num_pos' (positive-anchor count, at least 1).
        """
        device = anchors_per_level[0].device
        batch_size = len(targets)
        use_lmk = self.head.use_landmarks and 'lmk_preds' in head_out

        # Flatten predictions across levels
        cls_levels, reg_levels, lmk_levels = [], [], []
        for i in range(len(self.strides)):
            cls_map = head_out['cls_scores'][i]
            B = cls_map.shape[0]
            cls_levels.append(cls_map.permute(0, 2, 3, 1).reshape(B, -1, 1))
            reg_levels.append(
                head_out['bbox_preds'][i].permute(0, 2, 3, 1).reshape(B, -1, 4))
            if use_lmk:
                lmk_levels.append(
                    head_out['lmk_preds'][i].permute(0, 2, 3, 1)
                    .reshape(B, -1, self._LMK_DIM))

        all_cls = torch.cat(cls_levels, dim=1)  # [B, N, 1]
        all_reg = torch.cat(reg_levels, dim=1)  # [B, N, 4]
        all_anchors = torch.cat(anchors_per_level, dim=0)  # [N, 4]

        has_lmk = bool(lmk_levels)
        all_lmk = torch.cat(lmk_levels, dim=1) if has_lmk else None

        total_cls_loss = torch.tensor(0.0, device=device)
        total_reg_loss = torch.tensor(0.0, device=device)
        total_lmk_loss = torch.tensor(0.0, device=device)
        num_pos = 0

        for b in range(batch_size):
            target = targets[b]
            gt_boxes = target['boxes']
            gt_labels = target.get('labels')
            if gt_labels is None:
                # Single-class detector: every ground-truth box is a face.
                gt_labels = torch.ones(gt_boxes.shape[0], dtype=torch.long,
                                       device=device)

            # ATSS matching
            assigned_labels, assigned_gt_inds = self.assigner.assign(
                all_anchors, gt_boxes, gt_labels, num_per_level
            )

            pos_mask = assigned_labels > 0
            n_pos = int(pos_mask.sum().item())
            num_pos += n_pos
            has_pos = n_pos > 0

            if has_pos:
                # Shared by the classification target and both regression losses.
                pos_anchors = all_anchors[pos_mask]
                pos_gt = gt_boxes[assigned_gt_inds[pos_mask]]

            # Classification loss (all anchors)
            if self.use_gfl:
                # GFL: positive target = IoU, negative target = 0
                # NOTE(review): the IoU here is anchor-vs-GT; the GFL quality
                # target is usually IoU of the *predicted* box vs GT - confirm.
                cls_targets = torch.zeros(all_anchors.shape[0], device=device)
                if has_pos:
                    pos_ious = self._compute_iou_single(pos_anchors, pos_gt)
                    # Masked assignment needs matching dtypes.
                    cls_targets[pos_mask] = pos_ious.to(cls_targets.dtype)
            else:
                cls_targets = pos_mask.float()
            total_cls_loss = total_cls_loss + self.cls_loss_fn(
                all_cls[b].squeeze(-1), cls_targets
            )

            if not has_pos:
                continue

            # Box regression loss (positive anchors only):
            # decode predictions to absolute boxes first.
            pred_boxes = self._decode_boxes(pos_anchors, all_reg[b][pos_mask])
            total_reg_loss = total_reg_loss + self.reg_loss_fn(pred_boxes, pos_gt)

            # Landmark loss
            if has_lmk and 'landmarks' in target:
                pos_lmk_gt = target['landmarks'][assigned_gt_inds[pos_mask]]
                # Decode landmarks relative to anchors
                pred_lmk = self._decode_landmarks(pos_anchors, all_lmk[b][pos_mask])
                total_lmk_loss = total_lmk_loss + self.lmk_loss_fn(pred_lmk, pos_lmk_gt)

        # NOTE(review): num_pos is only reported, not used for normalisation;
        # losses are averaged over images. Confirm this is intended.
        num_pos = max(num_pos, 1)
        norm = max(batch_size, 1)  # avoid NaN losses for an empty target list
        losses = {
            'cls_loss': self.cls_weight * total_cls_loss / norm,
            'reg_loss': self.reg_weight * total_reg_loss / norm,
        }
        if self.head.use_landmarks:
            losses['lmk_loss'] = self.lmk_weight * total_lmk_loss / norm

        losses['total_loss'] = sum(losses.values())
        losses['num_pos'] = torch.tensor(num_pos, dtype=torch.float, device=device)
        return losses

    def _inference(self, head_out: Dict, anchors_per_level: List[torch.Tensor],
                   img_shape: Tuple) -> List[Dict]:
        """
        Run inference with NMS.

        Args:
            head_out: Head output dict ('cls_scores', 'bbox_preds', optional
                'lmk_preds'), each a list of [B, C, H, W] tensors per level.
            anchors_per_level: Per-level [N_i, 4] anchor tensors.
            img_shape: Input batch shape; indices 2 and 3 give the image height
                and width used to clip boxes.

        Returns:
            One dict per image with 'boxes' [K, 4] and 'scores' [K], plus
            'landmarks' [K, 10] when landmark prediction is enabled. Images
            with no detection get zero-length tensors under the same keys.
        """
        cls_maps = head_out['cls_scores']
        batch_size = cls_maps[0].shape[0]
        device = cls_maps[0].device
        img_h, img_w = img_shape[2], img_shape[3]
        with_lmk = self.head.use_landmarks and 'lmk_preds' in head_out

        results = []
        for b in range(batch_size):
            boxes_per_level = []
            scores_per_level = []
            lmk_per_level = []

            for i in range(len(self.strides)):
                scores = cls_maps[i][b].permute(1, 2, 0).reshape(-1).sigmoid()

                # Filter by score threshold
                keep = scores > self.score_threshold
                if not keep.any():
                    continue

                reg = head_out['bbox_preds'][i][b].permute(1, 2, 0).reshape(-1, 4)[keep]
                anc = anchors_per_level[i][keep]

                # Decode boxes
                boxes = self._decode_boxes(anc, reg)

                # Clamp both corners to the image so a box predicted outside
                # the frame collapses to zero size instead of being inverted.
                boxes[:, 0::2].clamp_(min=0, max=img_w)
                boxes[:, 1::2].clamp_(min=0, max=img_h)

                boxes_per_level.append(boxes)
                scores_per_level.append(scores[keep])

                if with_lmk:
                    lmk = (head_out['lmk_preds'][i][b].permute(1, 2, 0)
                           .reshape(-1, self._LMK_DIM)[keep])
                    lmk_per_level.append(self._decode_landmarks(anc, lmk))

            if not boxes_per_level:
                empty = {
                    'boxes': torch.empty(0, 4, device=device),
                    'scores': torch.empty(0, device=device),
                }
                if with_lmk:
                    # Keep the key set identical to the non-empty case.
                    empty['landmarks'] = torch.empty(0, self._LMK_DIM, device=device)
                results.append(empty)
                continue

            all_boxes = torch.cat(boxes_per_level, dim=0)
            all_scores = torch.cat(scores_per_level, dim=0)

            # NMS
            keep = self._nms(all_boxes, all_scores, self.nms_threshold)
            keep = keep[:self.max_detections]

            result = {
                'boxes': all_boxes[keep],
                'scores': all_scores[keep],
            }
            if lmk_per_level:
                result['landmarks'] = torch.cat(lmk_per_level, dim=0)[keep]
            results.append(result)

        return results

    def _decode_boxes(self, anchors: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
        """
        Decode box predictions relative to anchors (distance-based).

        Each prediction holds the left/top/right/bottom distances from the
        anchor centre, expressed in units of the anchor width/height.

        Args:
            anchors: [N, 4] anchors as (x1, y1, x2, y2).
            pred: [N, 4] predicted distances.

        Returns:
            [N, 4] boxes as (x1, y1, x2, y2).
        """
        centers = (anchors[:, :2] + anchors[:, 2:4]) / 2
        sizes = anchors[:, 2:4] - anchors[:, :2]

        top_left = centers - pred[:, :2] * sizes
        bottom_right = centers + pred[:, 2:4] * sizes
        return torch.cat([top_left, bottom_right], dim=1)

    def _decode_landmarks(self, anchors: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
        """
        Decode landmark predictions relative to anchors.

        Every (x, y) pair is an offset from the anchor centre in units of the
        anchor width/height.

        Args:
            anchors: [N, 4] anchors as (x1, y1, x2, y2).
            pred: [N, 2K] interleaved x/y offsets (K = 5 for this detector).

        Returns:
            [N, 2K] absolute landmark coordinates; ``pred`` is not modified.
        """
        centers = (anchors[:, :2] + anchors[:, 2:4]) / 2  # [N, 2]
        sizes = anchors[:, 2:4] - anchors[:, :2]          # [N, 2]

        # Explicit sizes (no -1) so the reshape is also valid when N == 0.
        num, num_points = pred.shape[0], pred.shape[1] // 2
        points = pred.reshape(num, num_points, 2)
        decoded = centers.unsqueeze(1) + points * sizes.unsqueeze(1)
        return decoded.reshape(num, num_points * 2).to(pred.dtype)

    @staticmethod
    def _compute_iou_single(boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor:
        """Compute elementwise IoU between paired boxes. [N,4] × [N,4] → [N]"""
        inter_w = (torch.min(boxes1[:, 2], boxes2[:, 2])
                   - torch.max(boxes1[:, 0], boxes2[:, 0])).clamp(min=0)
        inter_h = (torch.min(boxes1[:, 3], boxes2[:, 3])
                   - torch.max(boxes1[:, 1], boxes2[:, 1])).clamp(min=0)
        inter = inter_w * inter_h

        area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
        area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
        union = area1 + area2 - inter
        # Epsilon guards against 0/0 for degenerate (zero-area) pairs.
        return inter / (union + 1e-6)

    @staticmethod
    def _nms(boxes: torch.Tensor, scores: torch.Tensor,
             threshold: float) -> torch.Tensor:
        """
        Non-Maximum Suppression. Returns kept indices.

        Args:
            boxes: [N, 4] boxes as (x1, y1, x2, y2).
            scores: [N] confidence scores.
            threshold: Boxes whose IoU with a higher-scoring kept box exceeds
                this value are discarded.

        Returns:
            Indices of kept boxes (long tensor), highest score first.
        """
        if boxes.shape[0] == 0:
            return torch.empty(0, dtype=torch.long, device=boxes.device)

        # Use torchvision NMS if available, else pure PyTorch
        try:
            from torchvision.ops import nms
        except ImportError:
            nms = None
        if nms is not None:
            return nms(boxes, scores, threshold)

        # Pure PyTorch NMS fallback
        x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
        areas = (x2 - x1) * (y2 - y1)
        order = scores.argsort(descending=True)
        keep = []

        while order.numel() > 0:
            top = order[0].item()
            keep.append(top)
            if order.numel() == 1:
                break

            rest = order[1:]
            inter_w = (torch.min(x2[top], x2[rest]) - torch.max(x1[top], x1[rest])).clamp(min=0)
            inter_h = (torch.min(y2[top], y2[rest]) - torch.max(y1[top], y1[rest])).clamp(min=0)
            inter = inter_w * inter_h
            iou = inter / (areas[top] + areas[rest] - inter + 1e-6)
            # Survivors are the candidates that do not overlap the kept box too much.
            order = rest[iou <= threshold]

        return torch.tensor(keep, dtype=torch.long, device=boxes.device)


# ──────────────────────── Model Builder ────────────────────────

# Registry of supported variants, keyed by model name. Each row lists
# (name, channel width, stacked head convs); the backbone key equals the
# model name and the neck output and head feature widths share one value.
MODEL_CONFIGS = {
    model_name: {
        'backbone': model_name,
        'neck_out': width,
        'head_feat': width,
        'head_convs': num_convs,
    }
    for model_name, width, num_convs in (
        ('scrfd_34g', 64, 3),
        ('scrfd_10g', 56, 2),
        ('scrfd_2.5g', 40, 2),
        ('scrfd_0.5g', 16, 2),
    )
}


def build_detector(name: str, use_landmarks: bool = False,
                   score_threshold: float = 0.3,
                   nms_threshold: float = 0.4,
                   **kwargs) -> SCRFD:
    """
    Assemble an SCRFD detector from a named entry of ``MODEL_CONFIGS``.

    Args:
        name: Model name ('scrfd_34g', 'scrfd_10g', 'scrfd_2.5g', 'scrfd_0.5g')
        use_landmarks: Enable 5-point landmark prediction
        score_threshold: Detection confidence threshold
        nms_threshold: NMS IoU threshold
        **kwargs: Forwarded unchanged to the ``SCRFD`` constructor.

    Returns:
        Complete SCRFD detector ready for training or inference

    Raises:
        ValueError: If ``name`` is not a key of ``MODEL_CONFIGS``.
    """
    config = MODEL_CONFIGS.get(name)
    if config is None:
        raise ValueError(f"Unknown model: {name}. Options: {list(MODEL_CONFIGS.keys())}")

    neck_channels = config['neck_out']

    # Components are created in backbone -> neck -> head order, the same
    # order the original used.
    feature_extractor = build_backbone(config['backbone'])
    feature_pyramid = PAFPN(feature_extractor.out_channels,
                            out_channels=neck_channels)
    detection_head = SCRFDHead(
        in_channels=neck_channels,
        feat_channels=config['head_feat'],
        stacked_convs=config['head_convs'],
        use_landmarks=use_landmarks,
    )

    return SCRFD(
        backbone=feature_extractor,
        neck=feature_pyramid,
        head=detection_head,
        anchor_generator=AnchorGenerator(),
        assigner=ATSSAssigner(topk=9),
        score_threshold=score_threshold,
        nms_threshold=nms_threshold,
        **kwargs,
    )