| """ |
SCRFD Detection Head — shared-weight, multi-task, scale-aware.
| |
| Design from SCRFD paper: |
| - Weight sharing across pyramid levels (parameter-efficient) |
| - GroupNorm for batch-size independence |
| - Separate cls and reg branches (GFL-style) |
| - Optional landmark branch (RetinaFace-style 5-point) |
| |
| Output per anchor: |
| - Classification: 1 score (face quality score via GFL) |
| - Box regression: 4 values (distance from anchor center to box edges) |
| - Landmarks (optional): 10 values (5 x,y offsets from anchor center) |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import List, Tuple, Optional |
| import math |
|
|
|
|
class SCRFDHead(nn.Module):
    """
    Shared detection head applied to each FPN level.

    The same head (and therefore the same weights) is run on every pyramid
    level, following the SCRFD design: weight sharing across levels,
    GroupNorm for batch-size independence, separate cls/reg branches
    (GFL-style), and an optional 5-point landmark branch.

    Args:
        in_channels: Input channels from neck
        num_classes: Number of classes (1 for face detection)
        num_anchors: Anchors per spatial location per level
        feat_channels: Hidden channel width in head convolutions
        stacked_convs: Number of stacked 3x3 convs in each branch
        use_gn: Use GroupNorm (vs BatchNorm)
        use_landmarks: Enable 5-point landmark regression branch
    """

    def __init__(self,
                 in_channels: int = 64,
                 num_classes: int = 1,
                 num_anchors: int = 2,
                 feat_channels: int = 64,
                 stacked_convs: int = 2,
                 use_gn: bool = True,
                 use_landmarks: bool = False):
        super().__init__()
        self.num_classes = num_classes
        self.num_anchors = num_anchors
        self.use_landmarks = use_landmarks

        # Classification branch: conv stack + 1x score per anchor per class.
        self.cls_convs = self._make_branch(in_channels, feat_channels,
                                           stacked_convs, use_gn)
        self.cls_out = nn.Conv2d(feat_channels, num_anchors * num_classes, 3, 1, 1)

        # Box regression branch: 4 distances (anchor center -> box edges).
        self.reg_convs = self._make_branch(in_channels, feat_channels,
                                           stacked_convs, use_gn)
        self.reg_out = nn.Conv2d(feat_channels, num_anchors * 4, 3, 1, 1)

        # Optional landmark branch: 5 (x, y) offsets from the anchor center.
        if use_landmarks:
            self.lmk_convs = self._make_branch(in_channels, feat_channels,
                                               stacked_convs, use_gn)
            self.lmk_out = nn.Conv2d(feat_channels, num_anchors * 10, 3, 1, 1)

        self._init_weights()

    @staticmethod
    def _make_branch(in_channels: int,
                     feat_channels: int,
                     stacked_convs: int,
                     use_gn: bool) -> nn.Sequential:
        """Build one task branch: `stacked_convs` x (conv3x3 -> norm -> ReLU).

        Shared by the cls/reg/lmk branches so the construction logic lives
        in one place instead of being duplicated per branch.
        """
        layers = []
        for i in range(stacked_convs):
            ch_in = in_channels if i == 0 else feat_channels
            # bias=False: the following norm layer has its own affine shift.
            layers.append(nn.Conv2d(ch_in, feat_channels, 3, 1, 1, bias=False))
            if use_gn:
                # Largest group count <= 16 that divides feat_channels;
                # GroupNorm requires num_channels % num_groups == 0.
                gn_groups = min(16, feat_channels)
                while feat_channels % gn_groups != 0:
                    gn_groups -= 1
                layers.append(nn.GroupNorm(gn_groups, feat_channels))
            else:
                layers.append(nn.BatchNorm2d(feat_channels))
            layers.append(nn.ReLU(inplace=True))
        return nn.Sequential(*layers)

    def _init_weights(self):
        """Initialize convs with N(0, 0.01) and norms with identity affine.

        The classification output bias is then set so the initial foreground
        probability is ~0.01 (RetinaNet-style prior), which stabilizes early
        training when backgrounds vastly outnumber faces.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # bias = -log((1 - p) / p) makes sigmoid(bias) == prior_prob.
        prior_prob = 0.01
        bias_init = -math.log((1 - prior_prob) / prior_prob)
        nn.init.constant_(self.cls_out.bias, bias_init)

    def forward_single(self, x: torch.Tensor) -> dict:
        """Forward pass for a single FPN level.

        Args:
            x: Feature map of shape (B, in_channels, H, W).

        Returns:
            dict with 'cls_score' (B, A*num_classes, H, W) and
            'bbox_pred' (B, A*4, H, W); plus 'lmk_pred' (B, A*10, H, W)
            when landmarks are enabled.
        """
        cls_feat = self.cls_convs(x)
        cls_score = self.cls_out(cls_feat)

        reg_feat = self.reg_convs(x)
        bbox_pred = self.reg_out(reg_feat)

        result = {'cls_score': cls_score, 'bbox_pred': bbox_pred}

        if self.use_landmarks:
            lmk_feat = self.lmk_convs(x)
            result['lmk_pred'] = self.lmk_out(lmk_feat)

        return result

    def forward(self, features: Tuple[torch.Tensor, ...]) -> dict:
        """
        Forward on all FPN levels.

        Args:
            features: (P3, P4, P5) from neck

        Returns:
            dict with keys 'cls_scores', 'bbox_preds', optionally 'lmk_preds'
            Each value is a list of tensors, one per level.
        """
        cls_scores = []
        bbox_preds = []
        lmk_preds = []

        for feat in features:
            out = self.forward_single(feat)
            cls_scores.append(out['cls_score'])
            bbox_preds.append(out['bbox_pred'])
            if self.use_landmarks:
                lmk_preds.append(out['lmk_pred'])

        result = {'cls_scores': cls_scores, 'bbox_preds': bbox_preds}
        if self.use_landmarks:
            result['lmk_preds'] = lmk_preds
        return result
|
|
|
|
| |
|
|
# Per-variant head hyperparameters, keyed by SCRFD model name.
#   feat_channels: hidden width of each head branch
#   stacked_convs: number of conv-norm-ReLU layers per branch
HEAD_CONFIGS = {
    'scrfd_34g': {'feat_channels': 64, 'stacked_convs': 3},
    'scrfd_10g': {'feat_channels': 56, 'stacked_convs': 2},
    'scrfd_2.5g': {'feat_channels': 40, 'stacked_convs': 2},
    'scrfd_0.5g': {'feat_channels': 16, 'stacked_convs': 2},
}
|
|
|
|
def build_head(name: str, in_channels: int, **kwargs) -> SCRFDHead:
    """Build a detection head by SCRFD model name.

    Args:
        name: Key into HEAD_CONFIGS (e.g. 'scrfd_10g'). Unknown names fall
            back to SCRFDHead's constructor defaults.
        in_channels: Input channels from the neck.
        **kwargs: Overrides forwarded to SCRFDHead (e.g. use_landmarks=True).

    Returns:
        A configured SCRFDHead instance.
    """
    # Copy before updating: HEAD_CONFIGS.get(name, {}) returns the shared
    # config dict itself, and updating it in place would permanently mutate
    # the module-level HEAD_CONFIGS entry for every later call.
    cfg = dict(HEAD_CONFIGS.get(name, {}))
    cfg.update(kwargs)
    return SCRFDHead(in_channels=in_channels, **cfg)
|
|