facedet / models /head.py
cledouxluma's picture
Upload models/head.py with huggingface_hub
afaa2cf verified
"""
SCRFD Detection Head — shared-weight, multi-task, scale-aware.
Design from SCRFD paper:
- Weight sharing across pyramid levels (parameter-efficient)
- GroupNorm for batch-size independence
- Separate cls and reg branches (GFL-style)
- Optional landmark branch (RetinaFace-style 5-point)
Output per anchor:
- Classification: 1 score (face quality score via GFL)
- Box regression: 4 values (distance from anchor center to box edges)
- Landmarks (optional): 10 values (5 x,y offsets from anchor center)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Optional
import math
class SCRFDHead(nn.Module):
    """
    Shared detection head applied to each FPN level.

    The same convolution stacks are reused for every feature map passed to
    ``forward``, so the number of parameters does not depend on how many
    pyramid levels are processed.

    Args:
        in_channels: Input channels from neck
        num_classes: Number of classes (1 for face detection)
        num_anchors: Anchors per spatial location per level
        feat_channels: Hidden channel width in head convolutions
        stacked_convs: Number of stacked 3x3 convs in each branch. When this
            is 0 (or negative) the branches are empty and the output convs
            read the neck features directly.
        use_gn: Use GroupNorm (vs BatchNorm)
        use_landmarks: Enable 5-point landmark regression branch
    """

    def __init__(self,
                 in_channels: int = 64,
                 num_classes: int = 1,
                 num_anchors: int = 2,
                 feat_channels: int = 64,
                 stacked_convs: int = 2,
                 use_gn: bool = True,
                 use_landmarks: bool = False):
        super().__init__()
        self.num_classes = num_classes
        self.num_anchors = num_anchors
        self.use_landmarks = use_landmarks

        # An empty branch is an identity mapping, so in that case the output
        # convs must accept in_channels rather than feat_channels.
        out_in_channels = feat_channels if stacked_convs > 0 else in_channels

        # Submodules are created in the order cls -> reg -> lmk, each branch
        # followed by its output conv, which fixes the registration order
        # (and therefore the order in which _init_weights visits them).

        # Classification branch
        self.cls_convs = self._make_branch(
            in_channels, feat_channels, stacked_convs, use_gn)
        self.cls_out = nn.Conv2d(
            out_in_channels, num_anchors * num_classes, 3, 1, 1)

        # Box regression branch
        self.reg_convs = self._make_branch(
            in_channels, feat_channels, stacked_convs, use_gn)
        self.reg_out = nn.Conv2d(out_in_channels, num_anchors * 4, 3, 1, 1)

        # Landmark branch (optional)
        if use_landmarks:
            self.lmk_convs = self._make_branch(
                in_channels, feat_channels, stacked_convs, use_gn)
            self.lmk_out = nn.Conv2d(
                out_in_channels, num_anchors * 10, 3, 1, 1)

        self._init_weights()

    @staticmethod
    def _make_norm(channels: int, use_gn: bool) -> nn.Module:
        """Return the normalisation layer used after each branch conv."""
        if not use_gn:
            return nn.BatchNorm2d(channels)
        # GroupNorm needs a group count that divides the channel count:
        # start at 16 (or fewer) and step down to the nearest divisor.
        groups = min(16, channels)
        while channels % groups != 0:
            groups -= 1
        return nn.GroupNorm(groups, channels)

    @staticmethod
    def _make_branch(in_channels: int,
                     feat_channels: int,
                     stacked_convs: int,
                     use_gn: bool) -> nn.Sequential:
        """Build one conv -> norm -> ReLU stack of ``stacked_convs`` stages."""
        layers = []
        for i in range(stacked_convs):
            # Only the first stage sees the neck's channel count.
            ch_in = in_channels if i == 0 else feat_channels
            # bias=False: the conv is immediately followed by a norm layer.
            layers.append(
                nn.Conv2d(ch_in, feat_channels, 3, 1, 1, bias=False))
            layers.append(SCRFDHead._make_norm(feat_channels, use_gn))
            layers.append(nn.ReLU(inplace=True))
        return nn.Sequential(*layers)

    def _init_weights(self):
        """Initialise conv/norm parameters and the classification bias."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        # Initialize cls bias for focal loss (prevents initial instability).
        # The bias is the logit of the prior, i.e. sigmoid(bias) == 0.01.
        prior_prob = 0.01
        bias_init = -math.log((1 - prior_prob) / prior_prob)
        nn.init.constant_(self.cls_out.bias, bias_init)

    def forward_single(self, x: torch.Tensor) -> dict:
        """
        Forward pass for a single FPN level.

        Args:
            x: Feature map for one level, with ``in_channels`` channels.
        Returns:
            dict with 'cls_score' and 'bbox_pred', plus 'lmk_pred' when the
            landmark branch is enabled.
        """
        result = {
            'cls_score': self.cls_out(self.cls_convs(x)),  # [B, A*C, H, W]
            'bbox_pred': self.reg_out(self.reg_convs(x)),  # [B, A*4, H, W]
        }
        if self.use_landmarks:
            # [B, A*10, H, W]
            result['lmk_pred'] = self.lmk_out(self.lmk_convs(x))
        return result

    def forward(self, features: Tuple[torch.Tensor, ...]) -> dict:
        """
        Forward on all FPN levels.

        Args:
            features: (P3, P4, P5) from neck
        Returns:
            dict with keys 'cls_scores', 'bbox_preds', optionally 'lmk_preds'
            Each value is a list of tensors, one per level.
        """
        per_level = [self.forward_single(feat) for feat in features]
        result = {
            'cls_scores': [out['cls_score'] for out in per_level],
            'bbox_preds': [out['bbox_pred'] for out in per_level],
        }
        if self.use_landmarks:
            result['lmk_preds'] = [out['lmk_pred'] for out in per_level]
        return result
# ──────────────────────── Configuration presets ────────────────────────
# Per-model head hyper-parameters, keyed by model name. Each value holds the
# keyword arguments that differ between variants (hidden width and branch depth).
HEAD_CONFIGS = {
    'scrfd_34g': {'feat_channels': 64, 'stacked_convs': 3},
    'scrfd_10g': {'feat_channels': 56, 'stacked_convs': 2},
    'scrfd_2.5g': {'feat_channels': 40, 'stacked_convs': 2},
    'scrfd_0.5g': {'feat_channels': 16, 'stacked_convs': 2},
}
def build_head(name: str, in_channels: int, **kwargs) -> SCRFDHead:
    """
    Build detection head by model name.

    Args:
        name: Key into HEAD_CONFIGS. An unknown name selects no preset, so
            the head is built from SCRFDHead's defaults plus ``kwargs``.
        in_channels: Input channels from neck
        **kwargs: Extra SCRFDHead keyword arguments; these take precedence
            over the preset values.
    Returns:
        A newly constructed SCRFDHead.
    """
    # Copy the preset before merging: updating the dict stored in
    # HEAD_CONFIGS in place would leak this call's overrides into every
    # later build of the same preset.
    cfg = dict(HEAD_CONFIGS.get(name, {}))
    cfg.update(kwargs)
    return SCRFDHead(in_channels=in_channels, **cfg)