"""
PAFPN (Path Aggregation Feature Pyramid Network) for SCRFD.

Architecture: Top-down FPN + bottom-up path aggregation.
  - Input: C3 (stride 8), C4 (stride 16), C5 (stride 32) from backbone
  - Output: P3, P4, P5 at same strides with fused multi-scale features
  - All output channels unified to `out_channels`

Key design (from SCRFD paper):
  - Lightweight PAFPN with configurable channel width
  - Group Normalization (stable with small batch sizes, per TinaFace finding);
    BatchNorm is available via ``use_gn=False``
  - NAS-searched channel width varies by model tier
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple


class ConvGNReLU(nn.Module):
    """Conv2d (no bias) -> GroupNorm -> optional ReLU.

    Args:
        in_ch: Input channel count.
        out_ch: Output channel count.
        kernel: Convolution kernel size.
        stride: Convolution stride.
        padding: Convolution padding.
        groups: Convolution groups.
        num_gn_groups: Requested number of GroupNorm groups. GroupNorm needs
            ``out_ch % num_groups == 0``, so the largest divisor of ``out_ch``
            that does not exceed this value is used instead.
        use_relu: Apply ReLU after normalization; otherwise identity.

    Raises:
        ValueError: If ``num_gn_groups`` or ``out_ch`` is less than 1.
    """

    def __init__(self, in_ch: int, out_ch: int, kernel: int = 3,
                 stride: int = 1, padding: int = 1, groups: int = 1,
                 num_gn_groups: int = 16, use_relu: bool = True):
        super().__init__()
        if num_gn_groups < 1 or out_ch < 1:
            raise ValueError(
                f"num_gn_groups and out_ch must be >= 1, "
                f"got {num_gn_groups} and {out_ch}")
        # Walk down to the largest divisor of out_ch <= num_gn_groups
        # (terminates at 1 at worst).
        gn_groups = min(num_gn_groups, out_ch)
        while out_ch % gn_groups != 0:
            gn_groups -= 1
        self.conv = nn.Conv2d(in_ch, out_ch, kernel, stride, padding,
                              groups=groups, bias=False)
        self.gn = nn.GroupNorm(gn_groups, out_ch)
        self.relu = nn.ReLU(inplace=True) if use_relu else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.relu(self.gn(self.conv(x)))


def _conv_norm_act(in_ch: int, out_ch: int, kernel: int, stride: int,
                   padding: int, use_gn: bool) -> nn.Module:
    """Build Conv -> Norm -> ReLU using GroupNorm or BatchNorm."""
    if use_gn:
        return ConvGNReLU(in_ch, out_ch, kernel, stride, padding)
    # The Sequential indices (0/1/2) are part of the state_dict keys;
    # keep this layout stable so saved BatchNorm weights keep loading.
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel, stride, padding, bias=False),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
    )


class PAFPN(nn.Module):
    """
    Path Aggregation Feature Pyramid Network.

    Flow:
        1. Lateral connections: 1×1 conv to unify channels
        2. Top-down: upsample + add (FPN)
        3. Bottom-up: downsample + add (PAN)
        4. Output convs: 3×3 conv for anti-aliasing

    Args:
        in_channels: Channel count of each input level, ordered from the
            highest-resolution level (e.g. C3) to the lowest (e.g. C5).
        out_channels: Channel count of every output level.
        num_extra_convs: Accepted for config compatibility; currently unused.
        use_gn: Use GroupNorm (``ConvGNReLU``) when True, BatchNorm otherwise.

    Raises:
        ValueError: If ``in_channels`` is empty.
    """

    def __init__(self, in_channels: List[int], out_channels: int = 64,
                 num_extra_convs: int = 0, use_gn: bool = True):
        super().__init__()
        if not in_channels:
            raise ValueError("PAFPN needs at least one input level")
        self.num_levels = len(in_channels)
        self.out_channels = out_channels

        # Module creation order (lateral, td, bu, bu_out) is kept fixed so
        # parameter registration order and RNG consumption stay stable.

        # Lateral connections (1×1 conv to unify channels).
        self.lateral_convs = nn.ModuleList(
            _conv_norm_act(in_ch, out_channels, 1, 1, 0, use_gn)
            for in_ch in in_channels
        )
        # Top-down output convs (anti-aliasing after upsample+add).
        self.td_convs = nn.ModuleList(
            _conv_norm_act(out_channels, out_channels, 3, 1, 1, use_gn)
            for _ in range(self.num_levels)
        )
        # Bottom-up downsample convs (stride 2), one per level transition.
        self.bu_convs = nn.ModuleList(
            _conv_norm_act(out_channels, out_channels, 3, 2, 1, use_gn)
            for _ in range(self.num_levels - 1)
        )
        # Bottom-up output convs.
        self.bu_out_convs = nn.ModuleList(
            _conv_norm_act(out_channels, out_channels, 3, 1, 1, use_gn)
            for _ in range(self.num_levels)
        )

        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal conv weights; norm layers start as identity."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out',
                                        nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, inputs: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
        """
        Args:
            inputs: (C3, C4, C5) feature maps from backbone, one 4-D
                (N, C, H, W) tensor per level, highest resolution first,
                with channel counts matching ``in_channels``.

        Returns:
            (P3, P4, P5) fused feature maps, each with ``out_channels``
            channels and the spatial size of the corresponding input.

        Raises:
            ValueError: If the number of inputs differs from the number of
                levels the neck was built with.
        """
        if len(inputs) != self.num_levels:
            raise ValueError(
                f"expected {self.num_levels} input feature maps, "
                f"got {len(inputs)}")

        # 1. Lateral connections.
        laterals = [conv(x) for conv, x in zip(self.lateral_convs, inputs)]

        # 2. Top-down pathway (FPN). Resizing to the exact target size (not a
        #    fixed 2x scale) tolerates odd-sized pyramids.
        for i in range(self.num_levels - 1, 0, -1):
            up = F.interpolate(laterals[i], size=laterals[i - 1].shape[2:],
                               mode='nearest')
            laterals[i - 1] = laterals[i - 1] + up
        td_outs = [conv(lat) for conv, lat in zip(self.td_convs, laterals)]

        # 3. Bottom-up pathway (PAN).
        bu_outs = [td_outs[0]]
        for down_conv, target in zip(self.bu_convs, td_outs[1:]):
            down = down_conv(bu_outs[-1])
            # A stride-2 conv rounds up (e.g. 15 -> 8); if the backbone rounded
            # down (15 -> 7) the add would fail, so match the target size.
            if down.shape[2:] != target.shape[2:]:
                down = F.interpolate(down, size=target.shape[2:], mode='nearest')
            bu_outs.append(target + down)

        # 4. Output convs.
        return tuple(conv(feat) for conv, feat in zip(self.bu_out_convs, bu_outs))


# ──────────────────────── Configuration presets ────────────────────────

NECK_CONFIGS = {
    'scrfd_34g': dict(out_channels=64),
    'scrfd_10g': dict(out_channels=56),
    'scrfd_2.5g': dict(out_channels=40),
    'scrfd_0.5g': dict(out_channels=16),
}


def build_neck(name: str, in_channels: List[int], **kwargs) -> PAFPN:
    """Build a PAFPN neck from a named preset.

    Args:
        name: Key into ``NECK_CONFIGS``. Unknown names fall back to the
            ``PAFPN`` defaults (no preset overrides).
        in_channels: Per-level input channel counts passed to ``PAFPN``.
        **kwargs: Extra ``PAFPN`` arguments; these override the preset.

    Returns:
        The constructed ``PAFPN``.
    """
    # Merge into a fresh dict: updating the preset in place would leak this
    # call's overrides into NECK_CONFIGS for every later caller.
    cfg = {**NECK_CONFIGS.get(name, {}), **kwargs}
    return PAFPN(in_channels, **cfg)