""" SCRFD Full Detector — Backbone + Neck + Head + Loss + Post-processing. This is the main model class that ties together all components and provides: 1. Training forward: returns losses dict 2. Inference forward: returns detections (boxes, scores, landmarks) 3. ONNX-exportable inference path Model configurations (WiderFace Hard val / GFLOPs / FPS @VGA on V100): - SCRFD-34GF: 85.2% / 34 GF / ~80 FPS (flagship quality) - SCRFD-10GF: 83.1% / 10 GF / ~140 FPS (balanced) - SCRFD-2.5GF: 77.9% / 2.5 GF / ~400 FPS (real-time) - SCRFD-0.5GF: 68.5% / 0.5 GF / ~1000 FPS (mobile/edge) """ import torch import torch.nn as nn import torch.nn.functional as F from typing import List, Tuple, Dict, Optional import math from .backbone import SCRFDBackbone, build_backbone from .neck import PAFPN, build_neck from .head import SCRFDHead, build_head from .anchor import AnchorGenerator, ATSSAssigner from .losses import GFocalLoss, DIoULoss, FocalLoss, LandmarkLoss class SCRFD(nn.Module): """ Sample and Computation Redistribution Face Detector. Complete pipeline: backbone → PAFPN → shared head → anchors → losses/NMS """ def __init__(self, backbone: SCRFDBackbone, neck: PAFPN, head: SCRFDHead, anchor_generator: AnchorGenerator, assigner: ATSSAssigner, strides: List[int] = [8, 16, 32], score_threshold: float = 0.3, nms_threshold: float = 0.4, max_detections: int = 750, use_gfl: bool = True, cls_weight: float = 1.0, reg_weight: float = 2.0, lmk_weight: float = 0.1): super().__init__() self.backbone = backbone self.neck = neck self.head = head self.anchor_gen = anchor_generator self.assigner = assigner self.strides = strides self.score_threshold = score_threshold self.nms_threshold = nms_threshold self.max_detections = max_detections self.use_gfl = use_gfl # Loss functions self.cls_loss_fn = GFocalLoss(beta=2.0) if use_gfl else FocalLoss() self.reg_loss_fn = DIoULoss() self.lmk_loss_fn = LandmarkLoss() if head.use_landmarks else None # Loss weights self.cls_weight = cls_weight self.reg_weight = reg_weight self.lmk_weight = lmk_weight def forward(self, images: torch.Tensor, targets: Optional[List[Dict]] = None) -> Dict: """ Args: images: [B, 3, H, W] batch of images (normalized) targets: List of dicts with keys: 'boxes': [M, 4] face boxes (x1, y1, x2, y2) 'labels': [M] labels (all 1) 'landmarks': [M, 10] optional landmarks When None, runs inference. Returns: Training: dict of losses Inference: list of dicts with 'boxes', 'scores', 'landmarks' """ # Feature extraction features = self.backbone(images) features = self.neck(features) head_out = self.head(features) # Generate anchors feat_sizes = [(f.shape[2], f.shape[3]) for f in features] anchors_per_level = self.anchor_gen.grid_anchors(feat_sizes, images.device) num_anchors_per_level = [a.shape[0] for a in anchors_per_level] if targets is not None: return self._compute_loss(head_out, anchors_per_level, num_anchors_per_level, targets, images.shape) else: return self._inference(head_out, anchors_per_level, images.shape) def _compute_loss(self, head_out: Dict, anchors_per_level: List[torch.Tensor], num_per_level: List[int], targets: List[Dict], img_shape: Tuple) -> Dict: """Compute training losses.""" device = anchors_per_level[0].device batch_size = len(targets) # Flatten predictions across levels all_cls = [] all_reg = [] all_lmk = [] for i in range(len(self.strides)): B, _, H, W = head_out['cls_scores'][i].shape cls = head_out['cls_scores'][i].permute(0, 2, 3, 1).reshape(B, -1, 1) reg = head_out['bbox_preds'][i].permute(0, 2, 3, 1).reshape(B, -1, 4) all_cls.append(cls) all_reg.append(reg) if self.head.use_landmarks and 'lmk_preds' in head_out: lmk = head_out['lmk_preds'][i].permute(0, 2, 3, 1).reshape(B, -1, 10) all_lmk.append(lmk) all_cls = torch.cat(all_cls, dim=1) # [B, N, 1] all_reg = torch.cat(all_reg, dim=1) # [B, N, 4] all_anchors = torch.cat(anchors_per_level, dim=0) # [N, 4] has_lmk = len(all_lmk) > 0 if has_lmk: all_lmk = torch.cat(all_lmk, dim=1) total_cls_loss = torch.tensor(0.0, device=device) total_reg_loss = torch.tensor(0.0, device=device) total_lmk_loss = torch.tensor(0.0, device=device) num_pos = 0 for b in range(batch_size): gt_boxes = targets[b]['boxes'] gt_labels = targets[b].get('labels', torch.ones(gt_boxes.shape[0], dtype=torch.long, device=device)) # ATSS matching assigned_labels, assigned_gt_inds = self.assigner.assign( all_anchors, gt_boxes, gt_labels, num_per_level ) pos_mask = assigned_labels > 0 num_pos += pos_mask.sum().item() # Classification loss (all anchors) if self.use_gfl: # GFL: positive target = IoU, negative target = 0 cls_targets = torch.zeros(all_anchors.shape[0], device=device) if pos_mask.any(): pos_anchors = all_anchors[pos_mask] pos_gt = gt_boxes[assigned_gt_inds[pos_mask]] pos_ious = self._compute_iou_single(pos_anchors, pos_gt) cls_targets[pos_mask] = pos_ious total_cls_loss += self.cls_loss_fn( all_cls[b].squeeze(-1), cls_targets ) else: total_cls_loss += self.cls_loss_fn( all_cls[b].squeeze(-1), (assigned_labels > 0).float() ) # Box regression loss (positive anchors only) if pos_mask.any(): pos_reg = all_reg[b][pos_mask] pos_anchors = all_anchors[pos_mask] pos_gt = gt_boxes[assigned_gt_inds[pos_mask]] # Decode predictions to absolute boxes pred_boxes = self._decode_boxes(pos_anchors, pos_reg) total_reg_loss += self.reg_loss_fn(pred_boxes, pos_gt) # Landmark loss if self.head.use_landmarks and 'landmarks' in targets[b] and has_lmk: gt_lmk = targets[b]['landmarks'] pos_lmk_pred = all_lmk[b][pos_mask] pos_lmk_gt = gt_lmk[assigned_gt_inds[pos_mask]] # Decode landmarks relative to anchors pred_lmk = self._decode_landmarks(pos_anchors, pos_lmk_pred) total_lmk_loss += self.lmk_loss_fn(pred_lmk, pos_lmk_gt) num_pos = max(num_pos, 1) losses = { 'cls_loss': self.cls_weight * total_cls_loss / batch_size, 'reg_loss': self.reg_weight * total_reg_loss / batch_size, } if self.head.use_landmarks: losses['lmk_loss'] = self.lmk_weight * total_lmk_loss / batch_size losses['total_loss'] = sum(losses.values()) losses['num_pos'] = torch.tensor(num_pos, dtype=torch.float, device=device) return losses def _inference(self, head_out: Dict, anchors_per_level: List[torch.Tensor], img_shape: Tuple) -> List[Dict]: """Run inference with NMS.""" batch_size = head_out['cls_scores'][0].shape[0] device = head_out['cls_scores'][0].device results = [] for b in range(batch_size): all_boxes = [] all_scores = [] all_lmk = [] for i in range(len(self.strides)): cls = head_out['cls_scores'][i][b].permute(1, 2, 0).reshape(-1, 1).sigmoid() reg = head_out['bbox_preds'][i][b].permute(1, 2, 0).reshape(-1, 4) anchors = anchors_per_level[i] # Filter by score threshold scores = cls.squeeze(-1) keep = scores > self.score_threshold if keep.sum() == 0: continue scores = scores[keep] reg = reg[keep] anc = anchors[keep] # Decode boxes boxes = self._decode_boxes(anc, reg) # Clamp to image boundaries boxes[:, 0].clamp_(min=0) boxes[:, 1].clamp_(min=0) boxes[:, 2].clamp_(max=img_shape[3]) boxes[:, 3].clamp_(max=img_shape[2]) all_boxes.append(boxes) all_scores.append(scores) if self.head.use_landmarks and 'lmk_preds' in head_out: lmk = head_out['lmk_preds'][i][b].permute(1, 2, 0).reshape(-1, 10)[keep] lmk_decoded = self._decode_landmarks(anc, lmk) all_lmk.append(lmk_decoded) if not all_boxes: results.append({ 'boxes': torch.empty(0, 4, device=device), 'scores': torch.empty(0, device=device), }) continue all_boxes = torch.cat(all_boxes, dim=0) all_scores = torch.cat(all_scores, dim=0) # NMS keep = self._nms(all_boxes, all_scores, self.nms_threshold) keep = keep[:self.max_detections] result = { 'boxes': all_boxes[keep], 'scores': all_scores[keep], } if all_lmk: all_lmk = torch.cat(all_lmk, dim=0) result['landmarks'] = all_lmk[keep] results.append(result) return results def _decode_boxes(self, anchors: torch.Tensor, pred: torch.Tensor) -> torch.Tensor: """Decode box predictions relative to anchors (distance-based).""" anchor_cx = (anchors[:, 0] + anchors[:, 2]) / 2 anchor_cy = (anchors[:, 1] + anchors[:, 3]) / 2 anchor_w = anchors[:, 2] - anchors[:, 0] anchor_h = anchors[:, 3] - anchors[:, 1] x1 = anchor_cx - pred[:, 0] * anchor_w y1 = anchor_cy - pred[:, 1] * anchor_h x2 = anchor_cx + pred[:, 2] * anchor_w y2 = anchor_cy + pred[:, 3] * anchor_h return torch.stack([x1, y1, x2, y2], dim=1) def _decode_landmarks(self, anchors: torch.Tensor, pred: torch.Tensor) -> torch.Tensor: """Decode landmark predictions relative to anchors.""" anchor_cx = (anchors[:, 0] + anchors[:, 2]) / 2 anchor_cy = (anchors[:, 1] + anchors[:, 3]) / 2 anchor_w = anchors[:, 2] - anchors[:, 0] anchor_h = anchors[:, 3] - anchors[:, 1] decoded = pred.clone() for i in range(5): decoded[:, i*2] = anchor_cx + pred[:, i*2] * anchor_w decoded[:, i*2+1] = anchor_cy + pred[:, i*2+1] * anchor_h return decoded @staticmethod def _compute_iou_single(boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor: """Compute elementwise IoU between paired boxes. [N,4] × [N,4] → [N]""" inter_x1 = torch.max(boxes1[:, 0], boxes2[:, 0]) inter_y1 = torch.max(boxes1[:, 1], boxes2[:, 1]) inter_x2 = torch.min(boxes1[:, 2], boxes2[:, 2]) inter_y2 = torch.min(boxes1[:, 3], boxes2[:, 3]) inter = (inter_x2 - inter_x1).clamp(min=0) * (inter_y2 - inter_y1).clamp(min=0) area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1]) area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1]) union = area1 + area2 - inter return inter / (union + 1e-6) @staticmethod def _nms(boxes: torch.Tensor, scores: torch.Tensor, threshold: float) -> torch.Tensor: """Non-Maximum Suppression. Returns kept indices.""" if boxes.shape[0] == 0: return torch.empty(0, dtype=torch.long, device=boxes.device) # Use torchvision NMS if available, else pure PyTorch try: from torchvision.ops import nms return nms(boxes, scores, threshold) except ImportError: pass # Pure PyTorch NMS fallback x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3] areas = (x2 - x1) * (y2 - y1) order = scores.argsort(descending=True) keep = [] while order.numel() > 0: i = order[0].item() keep.append(i) if order.numel() == 1: break xx1 = torch.max(x1[i], x1[order[1:]]) yy1 = torch.max(y1[i], y1[order[1:]]) xx2 = torch.min(x2[i], x2[order[1:]]) yy2 = torch.min(y2[i], y2[order[1:]]) inter = (xx2 - xx1).clamp(min=0) * (yy2 - yy1).clamp(min=0) iou = inter / (areas[i] + areas[order[1:]] - inter + 1e-6) mask = iou <= threshold order = order[1:][mask] return torch.tensor(keep, dtype=torch.long, device=boxes.device) # ──────────────────────── Model Builder ──────────────────────── MODEL_CONFIGS = { 'scrfd_34g': { 'backbone': 'scrfd_34g', 'neck_out': 64, 'head_feat': 64, 'head_convs': 3, }, 'scrfd_10g': { 'backbone': 'scrfd_10g', 'neck_out': 56, 'head_feat': 56, 'head_convs': 2, }, 'scrfd_2.5g': { 'backbone': 'scrfd_2.5g', 'neck_out': 40, 'head_feat': 40, 'head_convs': 2, }, 'scrfd_0.5g': { 'backbone': 'scrfd_0.5g', 'neck_out': 16, 'head_feat': 16, 'head_convs': 2, }, } def build_detector(name: str, use_landmarks: bool = False, score_threshold: float = 0.3, nms_threshold: float = 0.4, **kwargs) -> SCRFD: """ Build a complete SCRFD detector by name. Args: name: Model name ('scrfd_34g', 'scrfd_10g', 'scrfd_2.5g', 'scrfd_0.5g') use_landmarks: Enable 5-point landmark prediction score_threshold: Detection confidence threshold nms_threshold: NMS IoU threshold Returns: Complete SCRFD detector ready for training or inference """ if name not in MODEL_CONFIGS: raise ValueError(f"Unknown model: {name}. Options: {list(MODEL_CONFIGS.keys())}") cfg = MODEL_CONFIGS[name] backbone = build_backbone(cfg['backbone']) neck = PAFPN(backbone.out_channels, out_channels=cfg['neck_out']) head = SCRFDHead( in_channels=cfg['neck_out'], feat_channels=cfg['head_feat'], stacked_convs=cfg['head_convs'], use_landmarks=use_landmarks, ) anchor_gen = AnchorGenerator() assigner = ATSSAssigner(topk=9) model = SCRFD( backbone=backbone, neck=neck, head=head, anchor_generator=anchor_gen, assigner=assigner, score_threshold=score_threshold, nms_threshold=nms_threshold, **kwargs, ) return model