# facedet/models/neck.py
# Uploaded via huggingface_hub (commit a73c9ef, verified).
"""
PAFPN (Path Aggregation Feature Pyramid Network) for SCRFD.
Architecture: Top-down FPN + bottom-up path aggregation.
- Input: C3 (stride 8), C4 (stride 16), C5 (stride 32) from backbone
- Output: P3, P4, P5 at same strides with fused multi-scale features
- All output channels unified to `out_channels`
Key design (from SCRFD paper):
- Lightweight PAFPN with configurable channel width
- Group Normalization (stable with small batch sizes, per TinaFace finding)
- NAS-searched channel width varies by model tier
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple
class ConvGNReLU(nn.Module):
    """Convolution followed by GroupNorm and an optional ReLU activation."""

    def __init__(self, in_ch: int, out_ch: int, kernel: int = 3,
                 stride: int = 1, padding: int = 1, groups: int = 1,
                 num_gn_groups: int = 16, use_relu: bool = True):
        super().__init__()
        # GroupNorm needs its group count to divide the channel count:
        # start at min(num_gn_groups, out_ch) and walk down to the nearest divisor.
        gn = min(num_gn_groups, out_ch)
        while out_ch % gn:
            gn -= 1
        self.conv = nn.Conv2d(in_ch, out_ch, kernel, stride, padding,
                              groups=groups, bias=False)
        self.gn = nn.GroupNorm(gn, out_ch)
        self.relu = nn.ReLU(inplace=True) if use_relu else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.conv(x)
        out = self.gn(out)
        return self.relu(out)
class PAFPN(nn.Module):
    """
    Path Aggregation Feature Pyramid Network.
    Flow:
    1. Lateral connections: 1x1 conv to unify channels
    2. Top-down: upsample + add (FPN)
    3. Bottom-up: downsample + add (PAN)
    4. Output convs: 3x3 conv for anti-aliasing
    """

    def __init__(self, in_channels: List[int], out_channels: int = 64,
                 num_extra_convs: int = 0, use_gn: bool = True):
        super().__init__()
        self.num_levels = len(in_channels)
        self.out_channels = out_channels

        def block(in_ch: int, kernel: int, stride: int, padding: int) -> nn.Module:
            # Shared factory for every conv block in the neck: GroupNorm keeps
            # small-batch training stable; the BatchNorm branch mirrors it.
            if use_gn:
                return ConvGNReLU(in_ch, out_channels, kernel, stride, padding)
            return nn.Sequential(
                nn.Conv2d(in_ch, out_channels, kernel, stride, padding, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            )

        # 1x1 laterals unify the backbone channel widths.
        self.lateral_convs = nn.ModuleList(
            block(ch, 1, 1, 0) for ch in in_channels)
        # 3x3 smoothing convs applied after top-down fusion (anti-aliasing).
        self.td_convs = nn.ModuleList(
            block(out_channels, 3, 1, 1) for _ in range(self.num_levels))
        # Stride-2 convs carrying features down the bottom-up path.
        self.bu_convs = nn.ModuleList(
            block(out_channels, 3, 2, 1) for _ in range(self.num_levels - 1))
        # Final 3x3 convs producing the fused outputs.
        self.bu_out_convs = nn.ModuleList(
            block(out_channels, 3, 1, 1) for _ in range(self.num_levels))
        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal init for convs; unit weight / zero bias for norms."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, inputs: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
        """
        Args:
            inputs: (C3, C4, C5) feature maps from backbone
        Returns:
            (P3, P4, P5) fused feature maps
        """
        assert len(inputs) == self.num_levels
        # 1. Unify channel widths.
        feats = [conv(x) for conv, x in zip(self.lateral_convs, inputs)]
        # 2. FPN top-down: upsample the coarser level and add into the finer one.
        for lvl in range(self.num_levels - 1, 0, -1):
            feats[lvl - 1] = feats[lvl - 1] + F.interpolate(
                feats[lvl], size=feats[lvl - 1].shape[2:], mode='nearest')
        smoothed = [conv(f) for conv, f in zip(self.td_convs, feats)]
        # 3. PAN bottom-up: downsample the finer level and add into the coarser one.
        agg = [smoothed[0]]
        for lvl, down in enumerate(self.bu_convs):
            agg.append(smoothed[lvl + 1] + down(agg[-1]))
        # 4. Output convs.
        return tuple(conv(f) for conv, f in zip(self.bu_out_convs, agg))
# ──────────────────────── Configuration presets ────────────────────────
NECK_CONFIGS = {
    'scrfd_34g': dict(out_channels=64),
    'scrfd_10g': dict(out_channels=56),
    'scrfd_2.5g': dict(out_channels=40),
    'scrfd_0.5g': dict(out_channels=16),
}


def build_neck(name: str, in_channels: List[int], **kwargs) -> PAFPN:
    """Build a PAFPN neck by model name.

    Args:
        name: Preset key in NECK_CONFIGS (e.g. 'scrfd_10g'); unknown names
            fall back to the PAFPN constructor defaults.
        in_channels: Backbone feature channel widths, e.g. for (C3, C4, C5).
        **kwargs: Overrides merged on top of the preset.

    Returns:
        A configured PAFPN instance.
    """
    # Copy the preset before merging overrides. The previous code called
    # cfg.update(kwargs) on the dict stored in NECK_CONFIGS itself, so a
    # one-off override silently leaked into every later build of that tier.
    cfg = dict(NECK_CONFIGS.get(name, {}))
    cfg.update(kwargs)
    return PAFPN(in_channels, **cfg)