# facedet/models/detector.py
# cledouxluma's picture
# Upload models/detector.py with huggingface_hub
# 6953619 verified
"""
SCRFD Full Detector — Backbone + Neck + Head + Loss + Post-processing.
This is the main model class that ties together all components and provides:
1. Training forward: returns losses dict
2. Inference forward: returns detections (boxes, scores, landmarks)
3. ONNX-exportable inference path
Model configurations (WiderFace Hard val / GFLOPs / FPS @VGA on V100):
- SCRFD-34GF: 85.2% / 34 GF / ~80 FPS (flagship quality)
- SCRFD-10GF: 83.1% / 10 GF / ~140 FPS (balanced)
- SCRFD-2.5GF: 77.9% / 2.5 GF / ~400 FPS (real-time)
- SCRFD-0.5GF: 68.5% / 0.5 GF / ~1000 FPS (mobile/edge)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Dict, Optional
import math
from .backbone import SCRFDBackbone, build_backbone
from .neck import PAFPN, build_neck
from .head import SCRFDHead, build_head
from .anchor import AnchorGenerator, ATSSAssigner
from .losses import GFocalLoss, DIoULoss, FocalLoss, LandmarkLoss
class SCRFD(nn.Module):
"""
Sample and Computation Redistribution Face Detector.
Complete pipeline: backbone β†’ PAFPN β†’ shared head β†’ anchors β†’ losses/NMS
"""
def __init__(self,
backbone: SCRFDBackbone,
neck: PAFPN,
head: SCRFDHead,
anchor_generator: AnchorGenerator,
assigner: ATSSAssigner,
strides: List[int] = [8, 16, 32],
score_threshold: float = 0.3,
nms_threshold: float = 0.4,
max_detections: int = 750,
use_gfl: bool = True,
cls_weight: float = 1.0,
reg_weight: float = 2.0,
lmk_weight: float = 0.1):
super().__init__()
self.backbone = backbone
self.neck = neck
self.head = head
self.anchor_gen = anchor_generator
self.assigner = assigner
self.strides = strides
self.score_threshold = score_threshold
self.nms_threshold = nms_threshold
self.max_detections = max_detections
self.use_gfl = use_gfl
# Loss functions
self.cls_loss_fn = GFocalLoss(beta=2.0) if use_gfl else FocalLoss()
self.reg_loss_fn = DIoULoss()
self.lmk_loss_fn = LandmarkLoss() if head.use_landmarks else None
# Loss weights
self.cls_weight = cls_weight
self.reg_weight = reg_weight
self.lmk_weight = lmk_weight
def forward(self, images: torch.Tensor,
targets: Optional[List[Dict]] = None) -> Dict:
"""
Args:
images: [B, 3, H, W] batch of images (normalized)
targets: List of dicts with keys:
'boxes': [M, 4] face boxes (x1, y1, x2, y2)
'labels': [M] labels (all 1)
'landmarks': [M, 10] optional landmarks
When None, runs inference.
Returns:
Training: dict of losses
Inference: list of dicts with 'boxes', 'scores', 'landmarks'
"""
# Feature extraction
features = self.backbone(images)
features = self.neck(features)
head_out = self.head(features)
# Generate anchors
feat_sizes = [(f.shape[2], f.shape[3]) for f in features]
anchors_per_level = self.anchor_gen.grid_anchors(feat_sizes, images.device)
num_anchors_per_level = [a.shape[0] for a in anchors_per_level]
if targets is not None:
return self._compute_loss(head_out, anchors_per_level,
num_anchors_per_level, targets, images.shape)
else:
return self._inference(head_out, anchors_per_level, images.shape)
def _compute_loss(self, head_out: Dict, anchors_per_level: List[torch.Tensor],
num_per_level: List[int], targets: List[Dict],
img_shape: Tuple) -> Dict:
"""Compute training losses."""
device = anchors_per_level[0].device
batch_size = len(targets)
# Flatten predictions across levels
all_cls = []
all_reg = []
all_lmk = []
for i in range(len(self.strides)):
B, _, H, W = head_out['cls_scores'][i].shape
cls = head_out['cls_scores'][i].permute(0, 2, 3, 1).reshape(B, -1, 1)
reg = head_out['bbox_preds'][i].permute(0, 2, 3, 1).reshape(B, -1, 4)
all_cls.append(cls)
all_reg.append(reg)
if self.head.use_landmarks and 'lmk_preds' in head_out:
lmk = head_out['lmk_preds'][i].permute(0, 2, 3, 1).reshape(B, -1, 10)
all_lmk.append(lmk)
all_cls = torch.cat(all_cls, dim=1) # [B, N, 1]
all_reg = torch.cat(all_reg, dim=1) # [B, N, 4]
all_anchors = torch.cat(anchors_per_level, dim=0) # [N, 4]
has_lmk = len(all_lmk) > 0
if has_lmk:
all_lmk = torch.cat(all_lmk, dim=1)
total_cls_loss = torch.tensor(0.0, device=device)
total_reg_loss = torch.tensor(0.0, device=device)
total_lmk_loss = torch.tensor(0.0, device=device)
num_pos = 0
for b in range(batch_size):
gt_boxes = targets[b]['boxes']
gt_labels = targets[b].get('labels',
torch.ones(gt_boxes.shape[0], dtype=torch.long, device=device))
# ATSS matching
assigned_labels, assigned_gt_inds = self.assigner.assign(
all_anchors, gt_boxes, gt_labels, num_per_level
)
pos_mask = assigned_labels > 0
num_pos += pos_mask.sum().item()
# Classification loss (all anchors)
if self.use_gfl:
# GFL: positive target = IoU, negative target = 0
cls_targets = torch.zeros(all_anchors.shape[0], device=device)
if pos_mask.any():
pos_anchors = all_anchors[pos_mask]
pos_gt = gt_boxes[assigned_gt_inds[pos_mask]]
pos_ious = self._compute_iou_single(pos_anchors, pos_gt)
cls_targets[pos_mask] = pos_ious
total_cls_loss += self.cls_loss_fn(
all_cls[b].squeeze(-1), cls_targets
)
else:
total_cls_loss += self.cls_loss_fn(
all_cls[b].squeeze(-1), (assigned_labels > 0).float()
)
# Box regression loss (positive anchors only)
if pos_mask.any():
pos_reg = all_reg[b][pos_mask]
pos_anchors = all_anchors[pos_mask]
pos_gt = gt_boxes[assigned_gt_inds[pos_mask]]
# Decode predictions to absolute boxes
pred_boxes = self._decode_boxes(pos_anchors, pos_reg)
total_reg_loss += self.reg_loss_fn(pred_boxes, pos_gt)
# Landmark loss
if self.head.use_landmarks and 'landmarks' in targets[b] and has_lmk:
gt_lmk = targets[b]['landmarks']
pos_lmk_pred = all_lmk[b][pos_mask]
pos_lmk_gt = gt_lmk[assigned_gt_inds[pos_mask]]
# Decode landmarks relative to anchors
pred_lmk = self._decode_landmarks(pos_anchors, pos_lmk_pred)
total_lmk_loss += self.lmk_loss_fn(pred_lmk, pos_lmk_gt)
num_pos = max(num_pos, 1)
losses = {
'cls_loss': self.cls_weight * total_cls_loss / batch_size,
'reg_loss': self.reg_weight * total_reg_loss / batch_size,
}
if self.head.use_landmarks:
losses['lmk_loss'] = self.lmk_weight * total_lmk_loss / batch_size
losses['total_loss'] = sum(losses.values())
losses['num_pos'] = torch.tensor(num_pos, dtype=torch.float, device=device)
return losses
def _inference(self, head_out: Dict, anchors_per_level: List[torch.Tensor],
img_shape: Tuple) -> List[Dict]:
"""Run inference with NMS."""
batch_size = head_out['cls_scores'][0].shape[0]
device = head_out['cls_scores'][0].device
results = []
for b in range(batch_size):
all_boxes = []
all_scores = []
all_lmk = []
for i in range(len(self.strides)):
cls = head_out['cls_scores'][i][b].permute(1, 2, 0).reshape(-1, 1).sigmoid()
reg = head_out['bbox_preds'][i][b].permute(1, 2, 0).reshape(-1, 4)
anchors = anchors_per_level[i]
# Filter by score threshold
scores = cls.squeeze(-1)
keep = scores > self.score_threshold
if keep.sum() == 0:
continue
scores = scores[keep]
reg = reg[keep]
anc = anchors[keep]
# Decode boxes
boxes = self._decode_boxes(anc, reg)
# Clamp to image boundaries
boxes[:, 0].clamp_(min=0)
boxes[:, 1].clamp_(min=0)
boxes[:, 2].clamp_(max=img_shape[3])
boxes[:, 3].clamp_(max=img_shape[2])
all_boxes.append(boxes)
all_scores.append(scores)
if self.head.use_landmarks and 'lmk_preds' in head_out:
lmk = head_out['lmk_preds'][i][b].permute(1, 2, 0).reshape(-1, 10)[keep]
lmk_decoded = self._decode_landmarks(anc, lmk)
all_lmk.append(lmk_decoded)
if not all_boxes:
results.append({
'boxes': torch.empty(0, 4, device=device),
'scores': torch.empty(0, device=device),
})
continue
all_boxes = torch.cat(all_boxes, dim=0)
all_scores = torch.cat(all_scores, dim=0)
# NMS
keep = self._nms(all_boxes, all_scores, self.nms_threshold)
keep = keep[:self.max_detections]
result = {
'boxes': all_boxes[keep],
'scores': all_scores[keep],
}
if all_lmk:
all_lmk = torch.cat(all_lmk, dim=0)
result['landmarks'] = all_lmk[keep]
results.append(result)
return results
def _decode_boxes(self, anchors: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
"""Decode box predictions relative to anchors (distance-based)."""
anchor_cx = (anchors[:, 0] + anchors[:, 2]) / 2
anchor_cy = (anchors[:, 1] + anchors[:, 3]) / 2
anchor_w = anchors[:, 2] - anchors[:, 0]
anchor_h = anchors[:, 3] - anchors[:, 1]
x1 = anchor_cx - pred[:, 0] * anchor_w
y1 = anchor_cy - pred[:, 1] * anchor_h
x2 = anchor_cx + pred[:, 2] * anchor_w
y2 = anchor_cy + pred[:, 3] * anchor_h
return torch.stack([x1, y1, x2, y2], dim=1)
def _decode_landmarks(self, anchors: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
"""Decode landmark predictions relative to anchors."""
anchor_cx = (anchors[:, 0] + anchors[:, 2]) / 2
anchor_cy = (anchors[:, 1] + anchors[:, 3]) / 2
anchor_w = anchors[:, 2] - anchors[:, 0]
anchor_h = anchors[:, 3] - anchors[:, 1]
decoded = pred.clone()
for i in range(5):
decoded[:, i*2] = anchor_cx + pred[:, i*2] * anchor_w
decoded[:, i*2+1] = anchor_cy + pred[:, i*2+1] * anchor_h
return decoded
@staticmethod
def _compute_iou_single(boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor:
"""Compute elementwise IoU between paired boxes. [N,4] Γ— [N,4] β†’ [N]"""
inter_x1 = torch.max(boxes1[:, 0], boxes2[:, 0])
inter_y1 = torch.max(boxes1[:, 1], boxes2[:, 1])
inter_x2 = torch.min(boxes1[:, 2], boxes2[:, 2])
inter_y2 = torch.min(boxes1[:, 3], boxes2[:, 3])
inter = (inter_x2 - inter_x1).clamp(min=0) * (inter_y2 - inter_y1).clamp(min=0)
area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
union = area1 + area2 - inter
return inter / (union + 1e-6)
@staticmethod
def _nms(boxes: torch.Tensor, scores: torch.Tensor,
threshold: float) -> torch.Tensor:
"""Non-Maximum Suppression. Returns kept indices."""
if boxes.shape[0] == 0:
return torch.empty(0, dtype=torch.long, device=boxes.device)
# Use torchvision NMS if available, else pure PyTorch
try:
from torchvision.ops import nms
return nms(boxes, scores, threshold)
except ImportError:
pass
# Pure PyTorch NMS fallback
x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
areas = (x2 - x1) * (y2 - y1)
order = scores.argsort(descending=True)
keep = []
while order.numel() > 0:
i = order[0].item()
keep.append(i)
if order.numel() == 1:
break
xx1 = torch.max(x1[i], x1[order[1:]])
yy1 = torch.max(y1[i], y1[order[1:]])
xx2 = torch.min(x2[i], x2[order[1:]])
yy2 = torch.min(y2[i], y2[order[1:]])
inter = (xx2 - xx1).clamp(min=0) * (yy2 - yy1).clamp(min=0)
iou = inter / (areas[i] + areas[order[1:]] - inter + 1e-6)
mask = iou <= threshold
order = order[1:][mask]
return torch.tensor(keep, dtype=torch.long, device=boxes.device)
# ──────────────────────── Model Builder ────────────────────────
# Per-variant build settings consumed by build_detector:
#   backbone   - name passed to build_backbone
#   neck_out   - PAFPN output channels (also the head's input channels)
#   head_feat  - feature channels inside the head
#   head_convs - number of stacked convolutions in the head
MODEL_CONFIGS = {
    'scrfd_34g': dict(backbone='scrfd_34g', neck_out=64, head_feat=64, head_convs=3),
    'scrfd_10g': dict(backbone='scrfd_10g', neck_out=56, head_feat=56, head_convs=2),
    'scrfd_2.5g': dict(backbone='scrfd_2.5g', neck_out=40, head_feat=40, head_convs=2),
    'scrfd_0.5g': dict(backbone='scrfd_0.5g', neck_out=16, head_feat=16, head_convs=2),
}
def build_detector(name: str, use_landmarks: bool = False,
                   score_threshold: float = 0.3,
                   nms_threshold: float = 0.4,
                   **kwargs) -> SCRFD:
    """
    Assemble a full SCRFD detector from a named configuration.

    Args:
        name: Key of MODEL_CONFIGS ('scrfd_34g', 'scrfd_10g', 'scrfd_2.5g', 'scrfd_0.5g')
        use_landmarks: Forwarded to the head to enable 5-point landmark prediction
        score_threshold: Detection confidence threshold
        nms_threshold: NMS IoU threshold
        **kwargs: Extra keyword arguments forwarded unchanged to the SCRFD constructor

    Returns:
        SCRFD instance wired with backbone, PAFPN neck, head, anchor generator
        and ATSS assigner (topk=9)

    Raises:
        ValueError: If ``name`` is not a key of MODEL_CONFIGS
    """
    settings = MODEL_CONFIGS.get(name)
    if settings is None:
        raise ValueError(f"Unknown model: {name}. Options: {list(MODEL_CONFIGS.keys())}")

    # Sub-modules are created in backbone -> neck -> head order so that any
    # RNG-dependent weight initialisation is consumed in a fixed sequence.
    neck_channels = settings['neck_out']
    feature_extractor = build_backbone(settings['backbone'])
    pyramid = PAFPN(feature_extractor.out_channels, out_channels=neck_channels)
    det_head = SCRFDHead(
        in_channels=neck_channels,
        feat_channels=settings['head_feat'],
        stacked_convs=settings['head_convs'],
        use_landmarks=use_landmarks,
    )

    return SCRFD(
        backbone=feature_extractor,
        neck=pyramid,
        head=det_head,
        anchor_generator=AnchorGenerator(),
        assigner=ATSSAssigner(topk=9),
        score_threshold=score_threshold,
        nms_threshold=nms_threshold,
        **kwargs,
    )