"""
SCRFD Full Detector - Backbone + Neck + Head + Loss + Post-processing.

This is the main model class that ties together all components and provides:
1. Training forward: returns losses dict
2. Inference forward: returns detections (boxes, scores, landmarks)
3. ONNX-exportable inference path

Model configurations (WiderFace Hard val / GFLOPs / FPS @VGA on V100):
- SCRFD-34GF: 85.2% / 34 GF / ~80 FPS (flagship quality)
- SCRFD-10GF: 83.1% / 10 GF / ~140 FPS (balanced)
- SCRFD-2.5GF: 77.9% / 2.5 GF / ~400 FPS (real-time)
- SCRFD-0.5GF: 68.5% / 0.5 GF / ~1000 FPS (mobile/edge)
"""
|
|
import math
from typing import List, Tuple, Dict, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from .backbone import SCRFDBackbone, build_backbone
from .neck import PAFPN, build_neck
from .head import SCRFDHead, build_head
from .anchor import AnchorGenerator, ATSSAssigner
from .losses import GFocalLoss, DIoULoss, FocalLoss, LandmarkLoss
|
|
|
|
class SCRFD(nn.Module):
    """
    Sample and Computation Redistribution Face Detector.

    Complete pipeline: backbone -> PAFPN -> shared head -> anchors -> losses/NMS

    ``forward`` returns a dict of losses when ``targets`` is given and a
    per-image list of detection dicts otherwise.
    """

    # Each landmark prediction holds 5 points as interleaved (x, y) pairs.
    _NUM_LMK_COORDS = 10

    def __init__(self,
                 backbone: "SCRFDBackbone",
                 neck: "PAFPN",
                 head: "SCRFDHead",
                 anchor_generator: "AnchorGenerator",
                 assigner: "ATSSAssigner",
                 strides: Optional[List[int]] = None,
                 score_threshold: float = 0.3,
                 nms_threshold: float = 0.4,
                 max_detections: int = 750,
                 use_gfl: bool = True,
                 cls_weight: float = 1.0,
                 reg_weight: float = 2.0,
                 lmk_weight: float = 0.1):
        """
        Args:
            backbone: Feature extractor applied to the input images.
            neck: Module applied to the backbone output.
            head: Prediction head; must expose ``use_landmarks``.
            anchor_generator: Object providing ``grid_anchors(feat_sizes, device)``.
            assigner: Object providing ``assign(anchors, gt_boxes, gt_labels,
                num_per_level)``.
            strides: One entry per pyramid level; only its length is used in
                this class. Defaults to ``[8, 16, 32]``.
            score_threshold: Minimum sigmoid score kept at inference.
            nms_threshold: IoU threshold handed to NMS.
            max_detections: Maximum detections kept per image after NMS.
            use_gfl: Use ``GFocalLoss`` with IoU targets instead of ``FocalLoss``.
            cls_weight: Multiplier for the classification loss.
            reg_weight: Multiplier for the box regression loss.
            lmk_weight: Multiplier for the landmark loss.
        """
        super().__init__()
        self.backbone = backbone
        self.neck = neck
        self.head = head
        self.anchor_gen = anchor_generator
        self.assigner = assigner
        # Copy so that the detector never shares a list with the caller
        # (and never shares a default list between instances).
        self.strides = list(strides) if strides is not None else [8, 16, 32]
        self.score_threshold = score_threshold
        self.nms_threshold = nms_threshold
        self.max_detections = max_detections
        self.use_gfl = use_gfl

        # Loss functions
        self.cls_loss_fn = GFocalLoss(beta=2.0) if use_gfl else FocalLoss()
        self.reg_loss_fn = DIoULoss()
        self.lmk_loss_fn = LandmarkLoss() if head.use_landmarks else None

        # Loss weights
        self.cls_weight = cls_weight
        self.reg_weight = reg_weight
        self.lmk_weight = lmk_weight

    def forward(self, images: torch.Tensor,
                targets: Optional[List[Dict]] = None) -> Union[Dict, List[Dict]]:
        """
        Args:
            images: [B, 3, H, W] batch of images (normalized)
            targets: List of dicts with keys:
                'boxes': [M, 4] face boxes (x1, y1, x2, y2)
                'labels': [M] labels (all 1)
                'landmarks': [M, 10] optional landmarks
                When None, runs inference.

        Returns:
            Training: dict of losses ('cls_loss', 'reg_loss', optionally
                'lmk_loss', plus 'total_loss' and 'num_pos')
            Inference: list of dicts with 'boxes', 'scores', and 'landmarks'
                when landmark predictions are available
        """
        features = self.backbone(images)
        features = self.neck(features)
        head_out = self.head(features)

        # Anchors are generated from the spatial size of every feature level.
        feat_sizes = [(f.shape[2], f.shape[3]) for f in features]
        anchors_per_level = self.anchor_gen.grid_anchors(feat_sizes, images.device)

        if targets is None:
            return self._inference(head_out, anchors_per_level, images.shape)

        # Per-level anchor counts are only needed by the assigner.
        num_anchors_per_level = [a.shape[0] for a in anchors_per_level]
        return self._compute_loss(head_out, anchors_per_level,
                                  num_anchors_per_level, targets, images.shape)

    def _compute_loss(self, head_out: Dict, anchors_per_level: List[torch.Tensor],
                      num_per_level: List[int], targets: List[Dict],
                      img_shape: Tuple) -> Dict:
        """
        Compute training losses.

        Every loss is summed over the images in ``targets`` and divided by
        ``len(targets)``. The positive-anchor count is returned as
        ``'num_pos'`` (floored at 1) and is not used for normalisation here.

        Args:
            head_out: Head output with per-level lists under 'cls_scores',
                'bbox_preds' and optionally 'lmk_preds'.
            anchors_per_level: One [A_i, 4] anchor tensor per level.
            num_per_level: Anchor count of each level, passed to the assigner.
            targets: Per-image ground truth dicts (see ``forward``).
            img_shape: Shape of the input batch (unused here).

        Returns:
            Dict with weighted 'cls_loss', 'reg_loss', optional 'lmk_loss',
            their sum 'total_loss', and 'num_pos'.
        """
        device = anchors_per_level[0].device
        batch_size = len(targets)
        with_lmk = self.head.use_landmarks and 'lmk_preds' in head_out

        # Flatten every level from [B, C, H, W] to [B, H*W*A, C'] and concatenate.
        cls_levels, reg_levels, lmk_levels = [], [], []
        for i in range(len(self.strides)):
            cls_map = head_out['cls_scores'][i]
            num_imgs = cls_map.shape[0]
            cls_levels.append(cls_map.permute(0, 2, 3, 1).reshape(num_imgs, -1, 1))
            reg_levels.append(
                head_out['bbox_preds'][i].permute(0, 2, 3, 1).reshape(num_imgs, -1, 4))
            if with_lmk:
                lmk_levels.append(
                    head_out['lmk_preds'][i].permute(0, 2, 3, 1)
                    .reshape(num_imgs, -1, self._NUM_LMK_COORDS))

        all_cls = torch.cat(cls_levels, dim=1)
        all_reg = torch.cat(reg_levels, dim=1)
        all_anchors = torch.cat(anchors_per_level, dim=0)
        all_lmk = torch.cat(lmk_levels, dim=1) if lmk_levels else None

        total_cls_loss = torch.tensor(0.0, device=device)
        total_reg_loss = torch.tensor(0.0, device=device)
        total_lmk_loss = torch.tensor(0.0, device=device)
        num_pos = 0

        for b in range(batch_size):
            gt_boxes = targets[b]['boxes']
            gt_labels = targets[b].get(
                'labels',
                torch.ones(gt_boxes.shape[0], dtype=torch.long, device=device))

            assigned_labels, assigned_gt_inds = self.assigner.assign(
                all_anchors, gt_boxes, gt_labels, num_per_level
            )

            pos_mask = assigned_labels > 0
            # Single host sync per image; reused for all "any positives?" checks.
            img_pos = int(pos_mask.sum().item())
            num_pos += img_pos
            has_pos = img_pos > 0

            if has_pos:
                pos_anchors = all_anchors[pos_mask]
                pos_gt = gt_boxes[assigned_gt_inds[pos_mask]]

            # Classification: negatives always contribute, so this runs even
            # when the image has no positive anchors.
            cls_pred = all_cls[b].squeeze(-1)
            if self.use_gfl:
                # Soft targets: IoU between each positive anchor and its GT box.
                cls_targets = torch.zeros(all_anchors.shape[0], device=device)
                if has_pos:
                    pos_ious = self._compute_iou_single(pos_anchors, pos_gt)
                    cls_targets[pos_mask] = pos_ious.to(cls_targets.dtype)
            else:
                cls_targets = pos_mask.float()
            total_cls_loss = total_cls_loss + self.cls_loss_fn(cls_pred, cls_targets)

            if not has_pos:
                continue

            # Box regression on positives only.
            pred_boxes = self._decode_boxes(pos_anchors, all_reg[b][pos_mask])
            total_reg_loss = total_reg_loss + self.reg_loss_fn(pred_boxes, pos_gt)

            # Landmark regression on positives only, when GT landmarks exist.
            if all_lmk is not None and 'landmarks' in targets[b]:
                pos_lmk_gt = targets[b]['landmarks'][assigned_gt_inds[pos_mask]]
                pred_lmk = self._decode_landmarks(pos_anchors, all_lmk[b][pos_mask])
                total_lmk_loss = total_lmk_loss + self.lmk_loss_fn(pred_lmk, pos_lmk_gt)

        num_pos = max(num_pos, 1)
        losses = {
            'cls_loss': self.cls_weight * total_cls_loss / batch_size,
            'reg_loss': self.reg_weight * total_reg_loss / batch_size,
        }
        if self.head.use_landmarks:
            losses['lmk_loss'] = self.lmk_weight * total_lmk_loss / batch_size

        # 'num_pos' is added after the sum so it does not leak into the total.
        losses['total_loss'] = sum(losses.values())
        losses['num_pos'] = torch.tensor(num_pos, dtype=torch.float, device=device)
        return losses

    def _inference(self, head_out: Dict, anchors_per_level: List[torch.Tensor],
                   img_shape: Tuple) -> List[Dict]:
        """
        Run inference with NMS.

        Args:
            head_out: Head output with per-level lists under 'cls_scores',
                'bbox_preds' and optionally 'lmk_preds'.
            anchors_per_level: One [A_i, 4] anchor tensor per level.
            img_shape: Shape of the input batch; indices 2 and 3 are used as
                image height and width for clipping.

        Returns:
            One dict per image with 'boxes' [K, 4] and 'scores' [K]. When the
            head predicts landmarks, 'landmarks' [K, 10] is always present,
            including for images without detections (K == 0).
        """
        cls_maps = head_out['cls_scores']
        batch_size = cls_maps[0].shape[0]
        device = cls_maps[0].device
        with_lmk = self.head.use_landmarks and 'lmk_preds' in head_out
        img_h, img_w = img_shape[2], img_shape[3]

        results = []
        for b in range(batch_size):
            boxes_per_level = []
            scores_per_level = []
            lmk_per_level = []

            for i in range(len(self.strides)):
                scores = cls_maps[i][b].permute(1, 2, 0).reshape(-1).sigmoid()

                # Drop low-confidence anchors before decoding.
                keep = scores > self.score_threshold
                if not keep.any():
                    continue

                reg = head_out['bbox_preds'][i][b].permute(1, 2, 0).reshape(-1, 4)[keep]
                anc = anchors_per_level[i][keep]
                boxes = self._decode_boxes(anc, reg)

                # Clip every coordinate to the image on both sides, so a box
                # lying outside the image cannot keep x1 > W or x2 < 0.
                boxes[:, 0::2].clamp_(min=0, max=img_w)
                boxes[:, 1::2].clamp_(min=0, max=img_h)

                boxes_per_level.append(boxes)
                scores_per_level.append(scores[keep])

                if with_lmk:
                    lmk = head_out['lmk_preds'][i][b].permute(1, 2, 0).reshape(
                        -1, self._NUM_LMK_COORDS)[keep]
                    lmk_per_level.append(self._decode_landmarks(anc, lmk))

            if not boxes_per_level:
                empty = {
                    'boxes': torch.empty(0, 4, device=device),
                    'scores': torch.empty(0, device=device),
                }
                if with_lmk:
                    # Keep the result schema identical to the non-empty case.
                    empty['landmarks'] = torch.empty(
                        0, self._NUM_LMK_COORDS, device=device)
                results.append(empty)
                continue

            all_boxes = torch.cat(boxes_per_level, dim=0)
            all_scores = torch.cat(scores_per_level, dim=0)

            keep = self._nms(all_boxes, all_scores, self.nms_threshold)
            keep = keep[:self.max_detections]

            result = {
                'boxes': all_boxes[keep],
                'scores': all_scores[keep],
            }
            if lmk_per_level:
                result['landmarks'] = torch.cat(lmk_per_level, dim=0)[keep]
            results.append(result)

        return results

    def _decode_boxes(self, anchors: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
        """
        Decode box predictions relative to anchors (distance-based).

        ``pred`` holds (left, top, right, bottom) distances from the anchor
        centre, expressed in units of the anchor width / height.

        Args:
            anchors: [N, 4] anchors as (x1, y1, x2, y2).
            pred: [N, 4] predicted distances.

        Returns:
            [N, 4] boxes as (x1, y1, x2, y2).
        """
        cx = (anchors[:, 0] + anchors[:, 2]) / 2
        cy = (anchors[:, 1] + anchors[:, 3]) / 2
        w = anchors[:, 2] - anchors[:, 0]
        h = anchors[:, 3] - anchors[:, 1]

        return torch.stack([
            cx - pred[:, 0] * w,
            cy - pred[:, 1] * h,
            cx + pred[:, 2] * w,
            cy + pred[:, 3] * h,
        ], dim=1)

    def _decode_landmarks(self, anchors: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
        """
        Decode landmark predictions relative to anchors.

        Each (x, y) pair in ``pred`` is an offset from the anchor centre in
        units of the anchor width / height.

        Args:
            anchors: [N, 4] anchors as (x1, y1, x2, y2).
            pred: [N, 2*K] interleaved (x, y) offsets.

        Returns:
            [N, 2*K] absolute landmark coordinates.
        """
        cx = (anchors[:, 0] + anchors[:, 2]) / 2
        cy = (anchors[:, 1] + anchors[:, 3]) / 2
        w = anchors[:, 2] - anchors[:, 0]
        h = anchors[:, 3] - anchors[:, 1]

        # Broadcast one centre / size per anchor over all of its points.
        centers = torch.stack([cx, cy], dim=1).unsqueeze(1)   # [N, 1, 2]
        sizes = torch.stack([w, h], dim=1).unsqueeze(1)       # [N, 1, 2]
        points = pred.reshape(pred.shape[0], -1, 2)           # [N, K, 2]
        return (centers + points * sizes).reshape(pred.shape)

    @staticmethod
    def _compute_iou_single(boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor:
        """Compute elementwise IoU between paired boxes. [N,4] x [N,4] -> [N]"""
        inter_x1 = torch.max(boxes1[:, 0], boxes2[:, 0])
        inter_y1 = torch.max(boxes1[:, 1], boxes2[:, 1])
        inter_x2 = torch.min(boxes1[:, 2], boxes2[:, 2])
        inter_y2 = torch.min(boxes1[:, 3], boxes2[:, 3])
        inter = (inter_x2 - inter_x1).clamp(min=0) * (inter_y2 - inter_y1).clamp(min=0)

        area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
        area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
        union = area1 + area2 - inter
        # Epsilon guards against division by zero for degenerate boxes.
        return inter / (union + 1e-6)

    @staticmethod
    def _nms(boxes: torch.Tensor, scores: torch.Tensor,
             threshold: float) -> torch.Tensor:
        """
        Non-Maximum Suppression. Returns kept indices.

        Uses ``torchvision.ops.nms`` when torchvision can be imported and a
        pure-PyTorch greedy loop otherwise.

        Args:
            boxes: [N, 4] boxes as (x1, y1, x2, y2).
            scores: [N] scores.
            threshold: Boxes whose IoU with a kept box exceeds this are dropped.

        Returns:
            Long tensor of kept indices, highest score first.
        """
        if boxes.shape[0] == 0:
            return torch.empty(0, dtype=torch.long, device=boxes.device)

        # Only the import is guarded, so an error raised by nms() itself is
        # never mistaken for "torchvision is missing".
        try:
            from torchvision.ops import nms as tv_nms
        except ImportError:
            tv_nms = None
        if tv_nms is not None:
            return tv_nms(boxes, scores, threshold)

        # Fallback: greedy NMS in pure PyTorch.
        x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
        areas = (x2 - x1) * (y2 - y1)
        order = scores.argsort(descending=True)
        keep = []

        while order.numel() > 0:
            i = order[0].item()
            keep.append(i)
            if order.numel() == 1:
                break

            rest = order[1:]
            xx1 = torch.max(x1[i], x1[rest])
            yy1 = torch.max(y1[i], y1[rest])
            xx2 = torch.min(x2[i], x2[rest])
            yy2 = torch.min(y2[i], y2[rest])
            inter = (xx2 - xx1).clamp(min=0) * (yy2 - yy1).clamp(min=0)
            iou = inter / (areas[i] + areas[rest] - inter + 1e-6)
            order = rest[iou <= threshold]

        return torch.tensor(keep, dtype=torch.long, device=boxes.device)
|
|
|
|
| |
|
|
# Per-variant build settings consumed by build_detector():
#   'backbone'   - name passed to build_backbone()
#   'neck_out'   - PAFPN out_channels, also used as the head's in_channels
#   'head_feat'  - SCRFDHead feat_channels
#   'head_convs' - SCRFDHead stacked_convs
MODEL_CONFIGS: Dict[str, dict] = {
    'scrfd_34g': {
        'backbone': 'scrfd_34g',
        'neck_out': 64,
        'head_feat': 64,
        'head_convs': 3,
    },
    'scrfd_10g': {
        'backbone': 'scrfd_10g',
        'neck_out': 56,
        'head_feat': 56,
        'head_convs': 2,
    },
    'scrfd_2.5g': {
        'backbone': 'scrfd_2.5g',
        'neck_out': 40,
        'head_feat': 40,
        'head_convs': 2,
    },
    'scrfd_0.5g': {
        'backbone': 'scrfd_0.5g',
        'neck_out': 16,
        'head_feat': 16,
        'head_convs': 2,
    },
}
|
|
|
|
def build_detector(name: str, use_landmarks: bool = False,
                   score_threshold: float = 0.3,
                   nms_threshold: float = 0.4,
                   **kwargs) -> SCRFD:
    """
    Build a complete SCRFD detector by name.

    Args:
        name: Model name ('scrfd_34g', 'scrfd_10g', 'scrfd_2.5g', 'scrfd_0.5g')
        use_landmarks: Enable 5-point landmark prediction
        score_threshold: Detection confidence threshold
        nms_threshold: NMS IoU threshold
        **kwargs: Extra keyword arguments forwarded unchanged to ``SCRFD``
            (e.g. ``strides``, ``max_detections``, ``use_gfl``, loss weights)

    Returns:
        Complete SCRFD detector ready for training or inference

    Raises:
        ValueError: If ``name`` is not a key of ``MODEL_CONFIGS``.
    """
    if name not in MODEL_CONFIGS:
        raise ValueError(f"Unknown model: {name}. Options: {list(MODEL_CONFIGS.keys())}")

    cfg = MODEL_CONFIGS[name]

    backbone = build_backbone(cfg['backbone'])
    # The neck's output width is also the head's input width.
    neck = PAFPN(backbone.out_channels, out_channels=cfg['neck_out'])
    head = SCRFDHead(
        in_channels=cfg['neck_out'],
        feat_channels=cfg['head_feat'],
        stacked_convs=cfg['head_convs'],
        use_landmarks=use_landmarks,
    )

    return SCRFD(
        backbone=backbone,
        neck=neck,
        head=head,
        anchor_generator=AnchorGenerator(),
        assigner=ATSSAssigner(topk=9),
        score_threshold=score_threshold,
        nms_threshold=nms_threshold,
        **kwargs,
    )
|
|