"""
SCRFD Detection Head — shared-weight, multi-task, scale-aware.

Design from SCRFD paper:
- Weight sharing across pyramid levels (parameter-efficient)
- GroupNorm for batch-size independence
- Separate cls and reg branches (GFL-style)
- Optional landmark branch (RetinaFace-style 5-point)

Output per anchor:
- Classification: 1 score (face quality score via GFL)
- Box regression: 4 values (distance from anchor center to box edges)
- Landmarks (optional): 10 values (5 x,y offsets from anchor center)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Optional
import math


def _gn_groups(channels: int) -> int:
    """Largest group count <= 16 that evenly divides `channels`.

    GroupNorm requires num_channels % num_groups == 0; walking down from 16
    always terminates because 1 divides everything.
    """
    groups = min(16, channels)
    while channels % groups != 0:
        groups -= 1
    return groups


class SCRFDHead(nn.Module):
    """
    Shared detection head applied to each FPN level.

    Args:
        in_channels: Input channels from neck
        num_classes: Number of classes (1 for face detection)
        num_anchors: Anchors per spatial location per level
        feat_channels: Hidden channel width in head convolutions
        stacked_convs: Number of stacked 3×3 convs in each branch
        use_gn: Use GroupNorm (vs BatchNorm)
        use_landmarks: Enable 5-point landmark regression branch
    """

    def __init__(self,
                 in_channels: int = 64,
                 num_classes: int = 1,
                 num_anchors: int = 2,
                 feat_channels: int = 64,
                 stacked_convs: int = 2,
                 use_gn: bool = True,
                 use_landmarks: bool = False):
        super().__init__()
        self.num_classes = num_classes
        self.num_anchors = num_anchors
        self.use_landmarks = use_landmarks

        # The three branches share an identical conv-norm-act stack shape;
        # only the final projection differs (classes / 4 box dists / 10 lmk).
        self.cls_convs = self._make_branch(in_channels, feat_channels,
                                           stacked_convs, use_gn)
        self.cls_out = nn.Conv2d(feat_channels, num_anchors * num_classes,
                                 3, 1, 1)

        self.reg_convs = self._make_branch(in_channels, feat_channels,
                                           stacked_convs, use_gn)
        self.reg_out = nn.Conv2d(feat_channels, num_anchors * 4, 3, 1, 1)

        if use_landmarks:
            self.lmk_convs = self._make_branch(in_channels, feat_channels,
                                               stacked_convs, use_gn)
            self.lmk_out = nn.Conv2d(feat_channels, num_anchors * 10, 3, 1, 1)

        self._init_weights()

    @staticmethod
    def _make_branch(in_channels: int,
                     feat_channels: int,
                     stacked_convs: int,
                     use_gn: bool) -> nn.Sequential:
        """Build a stack of `stacked_convs` × (3×3 conv → norm → ReLU).

        Convs carry no bias since a norm layer immediately follows.
        """
        layers = []
        for i in range(stacked_convs):
            ch_in = in_channels if i == 0 else feat_channels
            layers.append(nn.Conv2d(ch_in, feat_channels, 3, 1, 1, bias=False))
            if use_gn:
                layers.append(nn.GroupNorm(_gn_groups(feat_channels),
                                           feat_channels))
            else:
                layers.append(nn.BatchNorm2d(feat_channels))
            layers.append(nn.ReLU(inplace=True))
        return nn.Sequential(*layers)

    def _init_weights(self):
        """Init convs with N(0, 0.01), norms with (1, 0), plus focal-loss
        bias on the classification output."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Initialize cls bias for focal loss (prevents initial instability):
        # sigmoid(bias) ≈ prior probability 0.01 of "face" at step 0.
        prior_prob = 0.01
        bias_init = -math.log((1 - prior_prob) / prior_prob)
        nn.init.constant_(self.cls_out.bias, bias_init)

    def forward_single(self, x: torch.Tensor) -> dict:
        """Forward pass for a single FPN level.

        Args:
            x: Feature map [B, in_channels, H, W].

        Returns:
            dict with 'cls_score' [B, A*C, H, W], 'bbox_pred' [B, A*4, H, W],
            and (if enabled) 'lmk_pred' [B, A*10, H, W].
        """
        cls_feat = self.cls_convs(x)
        cls_score = self.cls_out(cls_feat)   # [B, A*C, H, W]

        reg_feat = self.reg_convs(x)
        bbox_pred = self.reg_out(reg_feat)   # [B, A*4, H, W]

        result = {'cls_score': cls_score, 'bbox_pred': bbox_pred}

        if self.use_landmarks:
            lmk_feat = self.lmk_convs(x)
            lmk_pred = self.lmk_out(lmk_feat)  # [B, A*10, H, W]
            result['lmk_pred'] = lmk_pred

        return result

    def forward(self, features: Tuple[torch.Tensor, ...]) -> dict:
        """
        Forward on all FPN levels.

        Args:
            features: (P3, P4, P5) from neck

        Returns:
            dict with keys 'cls_scores', 'bbox_preds', optionally 'lmk_preds'
            Each value is a list of tensors, one per level.
        """
        cls_scores = []
        bbox_preds = []
        lmk_preds = []

        for feat in features:
            out = self.forward_single(feat)
            cls_scores.append(out['cls_score'])
            bbox_preds.append(out['bbox_pred'])
            if self.use_landmarks:
                lmk_preds.append(out['lmk_pred'])

        result = {'cls_scores': cls_scores, 'bbox_preds': bbox_preds}
        if self.use_landmarks:
            result['lmk_preds'] = lmk_preds
        return result


# ──────────────────────── Configuration presets ────────────────────────

HEAD_CONFIGS = {
    'scrfd_34g': dict(feat_channels=64, stacked_convs=3),
    'scrfd_10g': dict(feat_channels=56, stacked_convs=2),
    'scrfd_2.5g': dict(feat_channels=40, stacked_convs=2),
    'scrfd_0.5g': dict(feat_channels=16, stacked_convs=2),
}


def build_head(name: str, in_channels: int, **kwargs) -> SCRFDHead:
    """Build detection head by model name.

    Unknown names fall back to SCRFDHead defaults; explicit kwargs override
    the preset.
    """
    # Copy the preset before merging kwargs — `dict.get` returns the shared
    # preset object, and updating it in place would permanently pollute
    # HEAD_CONFIGS for every subsequent call.
    cfg = dict(HEAD_CONFIGS.get(name, {}))
    cfg.update(kwargs)
    return SCRFDHead(in_channels=in_channels, **cfg)