# facedet/models/backbone.py
# Uploaded to the Hugging Face Hub via huggingface_hub (commit f19fabc, verified).
"""
SCRFD Backbones β€” NAS-searched ResNet-style with computation redistribution.
Key insight from paper: Standard classification backbones over-invest compute in
C5 features (stride 32), which are useless for tiny face detection. SCRFD
redistributes compute toward earlier stages (C2/C3) for stride-8 feature quality.
Configurations (from paper Table 3):
- SCRFD-34GF: stages=[3,12,28,3], widths=[56,88,248,304], groups=[1,1,1,1]
- SCRFD-10GF: stages=[3,10,16,3], widths=[36,64,144,224], groups=[1,1,1,1]
- SCRFD-2.5GF: stages=[2,4,4,3], widths=[24,48,96,160], groups=[1,1,1,1]
- SCRFD-0.5GF: stages=[2,2,4,2], widths=[16,32,64,128], groups=[1,1,1,1]
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Dict, Optional
import math
class ConvBNReLU(nn.Module):
    """Convolution -> BatchNorm -> (optional) ReLU building block.

    The convolution carries no bias term, since the following BatchNorm
    would cancel it anyway.
    """
    def __init__(self, in_ch: int, out_ch: int, kernel: int = 3,
                 stride: int = 1, padding: int = 1, groups: int = 1,
                 use_relu: bool = True):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel, stride, padding,
                              groups=groups, bias=False)
        self.bn = nn.BatchNorm2d(out_ch)
        # nn.Identity keeps forward() branch-free when the ReLU is disabled.
        self.relu = nn.ReLU(inplace=True) if use_relu else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.conv(x)
        out = self.bn(out)
        return self.relu(out)
class BasicBlock(nn.Module):
    """Two 3x3 convs with a residual shortcut (ResNet BasicBlock).

    The optional `downsample` module projects the identity branch whenever
    the spatial size or channel count changes; group convolution is applied
    to both 3x3 convs.
    """
    expansion = 1

    def __init__(self, in_ch: int, out_ch: int, stride: int = 1,
                 groups: int = 1, downsample: Optional[nn.Module] = None):
        super().__init__()
        self.conv1 = ConvBNReLU(in_ch, out_ch, 3, stride, 1, groups)
        # Second conv defers its ReLU until after the residual addition.
        self.conv2 = ConvBNReLU(out_ch, out_ch, 3, 1, 1, groups, use_relu=False)
        self.downsample = downsample
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = x if self.downsample is None else self.downsample(x)
        out = self.conv2(self.conv1(x))
        out = out + shortcut
        return self.relu(out)
class BottleneckBlock(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual bottleneck (ResNet style).

    The 3x3 conv carries the stride and the group count; the final 1x1
    expands channels by `expansion` and omits its ReLU so the activation
    is applied only after the residual addition.
    """
    expansion = 4

    def __init__(self, in_ch: int, out_ch: int, stride: int = 1,
                 groups: int = 1, downsample: Optional[nn.Module] = None):
        super().__init__()
        width = out_ch  # channel count of the narrow 3x3 stage
        self.conv1 = ConvBNReLU(in_ch, width, 1, 1, 0)
        self.conv2 = ConvBNReLU(width, width, 3, stride, 1, groups)
        self.conv3 = ConvBNReLU(width, out_ch * self.expansion, 1, 1, 0,
                                use_relu=False)
        self.downsample = downsample
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = x if self.downsample is None else self.downsample(x)
        out = self.conv3(self.conv2(self.conv1(x)))
        return self.relu(out + shortcut)
class SCRFDBackbone(nn.Module):
    """
    SCRFD backbone with NAS-searched stage depths and widths.

    For the SCRFD presets we use BasicBlock (expansion=1) since the searched
    widths already account for channel capacity — no need for bottleneck
    expansion. BottleneckBlock (expansion=4) is also supported via
    block_type='bottleneck'; stage input widths and `out_channels` are scaled
    by the block expansion accordingly.

    Returns feature maps at strides [8, 16, 32] (C3, C4, C5).

    Args:
        stages: number of residual blocks in each of the 4 stages.
        widths: base (pre-expansion) channel width of each stage.
        groups: group-conv group count per stage; defaults to [1, 1, 1, 1].
        in_channels: channels of the input image (3 for RGB).
        block_type: 'basic' for BasicBlock, anything else for BottleneckBlock.
    """
    def __init__(self, stages: List[int], widths: List[int],
                 groups: Optional[List[int]] = None, in_channels: int = 3,
                 block_type: str = 'basic'):
        super().__init__()
        assert len(stages) == 4 and len(widths) == 4
        if groups is None:
            groups = [1, 1, 1, 1]
        Block = BasicBlock if block_type == 'basic' else BottleneckBlock
        exp = Block.expansion  # 1 for BasicBlock, 4 for BottleneckBlock
        # Stem: stride-2 conv + stride-2 maxpool -> effective stride 4
        self.stem = nn.Sequential(
            ConvBNReLU(in_channels, widths[0], 3, 2, 1),
            ConvBNReLU(widths[0], widths[0], 3, 1, 1),
            nn.MaxPool2d(3, 2, 1),
        )
        # BUGFIX: each stage outputs widths[i] * exp channels, so the next
        # stage's input width must include the expansion factor. The previous
        # code ignored it, which broke bottleneck configs (e.g. the 'resnet50'
        # preset) at every stage boundary and reported wrong out_channels.
        # With BasicBlock (exp == 1) this is a no-op, so behavior for all
        # SCRFD presets is unchanged.
        # Stage 1: stride 1 (output stride = 4)
        self.layer1 = self._make_layer(Block, widths[0], widths[0], stages[0],
                                       stride=1, groups=groups[0])
        # Stage 2: stride 2 (output stride = 8) -> C3
        self.layer2 = self._make_layer(Block, widths[0] * exp, widths[1],
                                       stages[1], stride=2, groups=groups[1])
        # Stage 3: stride 2 (output stride = 16) -> C4
        self.layer3 = self._make_layer(Block, widths[1] * exp, widths[2],
                                       stages[2], stride=2, groups=groups[2])
        # Stage 4: stride 2 (output stride = 32) -> C5
        self.layer4 = self._make_layer(Block, widths[2] * exp, widths[3],
                                       stages[3], stride=2, groups=groups[3])
        # Channel counts of (C3, C4, C5), as consumed by the detection neck.
        self.out_channels = [widths[1] * exp, widths[2] * exp, widths[3] * exp]
        self._init_weights()

    def _make_layer(self, block, in_ch: int, out_ch: int, num_blocks: int,
                    stride: int = 1, groups: int = 1) -> nn.Sequential:
        """Stack `num_blocks` residual blocks; only the first one strides.

        A 1x1 conv + BN projection is placed on the shortcut whenever the
        spatial resolution or channel count changes.
        """
        downsample = None
        if stride != 1 or in_ch != out_ch * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(in_ch, out_ch * block.expansion, 1, stride, bias=False),
                nn.BatchNorm2d(out_ch * block.expansion),
            )
        layers = [block(in_ch, out_ch, stride, groups, downsample)]
        in_ch = out_ch * block.expansion
        for _ in range(1, num_blocks):
            layers.append(block(in_ch, out_ch, 1, groups))
        return nn.Sequential(*layers)

    def _init_weights(self):
        """Kaiming-init convolutions (fan-out, ReLU); identity-init BN."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return the (C3, C4, C5) feature maps at strides (8, 16, 32)."""
        x = self.stem(x)
        c2 = self.layer1(x)   # stride 4 (computed but not returned)
        c3 = self.layer2(c2)  # stride 8
        c4 = self.layer3(c3)  # stride 16
        c5 = self.layer4(c4)  # stride 32
        return c3, c4, c5
# ──────────────────────── Configuration presets ────────────────────────
# Stage depths/widths for the SCRFD presets come from the paper's Table 3
# (see module docstring); per-stage `groups` is omitted and defaults to 1.
BACKBONE_CONFIGS = {
    'scrfd_34g': dict(stages=[3, 12, 28, 3], widths=[56, 88, 248, 304]),
    'scrfd_10g': dict(stages=[3, 10, 16, 3], widths=[36, 64, 144, 224]),
    'scrfd_2.5g': dict(stages=[2, 4, 4, 3], widths=[24, 48, 96, 160]),
    'scrfd_0.5g': dict(stages=[2, 2, 4, 2], widths=[16, 32, 64, 128]),
    # ResNet variants for comparison
    # (resnet50 uses BottleneckBlock, so actual stage outputs are widths * 4)
    'resnet50': dict(stages=[3, 4, 6, 3], widths=[64, 128, 256, 512], block_type='bottleneck'),
    'resnet18': dict(stages=[2, 2, 2, 2], widths=[64, 128, 256, 512]),
}
def build_backbone(name: str, **kwargs) -> SCRFDBackbone:
    """Instantiate a backbone preset by name.

    Any keyword arguments override the matching keys of the named preset
    (e.g. ``in_channels=1``).

    Raises:
        ValueError: if `name` is not a key of BACKBONE_CONFIGS.
    """
    preset = BACKBONE_CONFIGS.get(name)
    if preset is None:
        raise ValueError(f"Unknown backbone: {name}. Options: {list(BACKBONE_CONFIGS.keys())}")
    merged = dict(preset)
    merged.update(kwargs)
    return SCRFDBackbone(**merged)