| """ |
| PAFPN (Path Aggregation Feature Pyramid Network) for SCRFD. |
| |
| Architecture: Top-down FPN + bottom-up path aggregation. |
| - Input: C3 (stride 8), C4 (stride 16), C5 (stride 32) from backbone |
| - Output: P3, P4, P5 at same strides with fused multi-scale features |
| - All output channels unified to `out_channels` |
| |
| Key design (from SCRFD paper): |
| - Lightweight PAFPN with configurable channel width |
| - Group Normalization (stable with small batch sizes, per TinaFace finding) |
| - NAS-searched channel width varies by model tier |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import List, Tuple |
|
|
|
|
class ConvGNReLU(nn.Module):
    """Conv2d (bias-free) -> GroupNorm -> ReLU (or Identity if ``use_relu`` is False).

    Args:
        in_ch: Input channel count.
        out_ch: Output channel count (also the GroupNorm channel count).
        kernel, stride, padding, groups: Forwarded to ``nn.Conv2d``.
        num_gn_groups: Upper bound on the GroupNorm group count. The group
            count actually used is the largest divisor of ``out_ch`` that does
            not exceed this value, because ``nn.GroupNorm`` requires
            ``out_ch`` to be divisible by its number of groups.
        use_relu: Apply ReLU after normalization when True.

    Raises:
        ValueError: If ``out_ch`` or ``num_gn_groups`` is smaller than 1.
    """

    def __init__(self, in_ch: int, out_ch: int, kernel: int = 3,
                 stride: int = 1, padding: int = 1, groups: int = 1,
                 num_gn_groups: int = 16, use_relu: bool = True):
        super().__init__()
        # Without this check a zero value crashes below with an opaque
        # ZeroDivisionError, and a negative one yields a negative group count.
        if out_ch < 1 or num_gn_groups < 1:
            raise ValueError(
                f"out_ch and num_gn_groups must be >= 1, got "
                f"out_ch={out_ch}, num_gn_groups={num_gn_groups}")

        # Largest divisor of out_ch that is <= num_gn_groups. The range always
        # reaches 1, which divides everything, so next() cannot run dry.
        gn_groups = next(g for g in range(min(num_gn_groups, out_ch), 0, -1)
                         if out_ch % g == 0)

        # Bias is omitted because GroupNorm's affine shift makes it redundant.
        self.conv = nn.Conv2d(in_ch, out_ch, kernel, stride, padding,
                              groups=groups, bias=False)
        self.gn = nn.GroupNorm(gn_groups, out_ch)
        self.relu = nn.ReLU(inplace=True) if use_relu else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply conv, GroupNorm, then the (optional) ReLU."""
        return self.relu(self.gn(self.conv(x)))
|
|
|
|
class PAFPN(nn.Module):
    """
    Path Aggregation Feature Pyramid Network.

    Flow:
        1. Lateral connections: 1×1 conv to unify channels
        2. Top-down: upsample + add (FPN), then a 3×3 conv per level
        3. Bottom-up: stride-2 3×3 conv downsample + add (PAN)
        4. Output convs: 3×3 conv per level for anti-aliasing

    Every conv is followed by normalization and ReLU: ``ConvGNReLU``
    (GroupNorm) when ``use_gn`` is True, otherwise Conv2d + BatchNorm2d + ReLU.

    Args:
        in_channels: Channel count of each input level, ordered from the
            highest-resolution level (e.g. C3) to the lowest (e.g. C5).
        out_channels: Channel count of every output level.
        num_extra_convs: Accepted but currently unused by this implementation.
        use_gn: Select GroupNorm (True) or BatchNorm2d (False) blocks.
    """

    def __init__(self, in_channels: List[int], out_channels: int = 64,
                 num_extra_convs: int = 0, use_gn: bool = True):
        super().__init__()
        self.num_levels = len(in_channels)
        self.out_channels = out_channels

        # NOTE: keep this construction order. It determines the state_dict
        # keys and the order in which module creation consumes the RNG.
        # Lateral 1×1 projections to out_channels.
        self.lateral_convs = nn.ModuleList([
            self._conv_norm_act(in_ch, out_channels, 1, 1, 0, use_gn)
            for in_ch in in_channels
        ])
        # 3×3 smoothing after the top-down merge, one per level.
        self.td_convs = nn.ModuleList([
            self._conv_norm_act(out_channels, out_channels, 3, 1, 1, use_gn)
            for _ in range(self.num_levels)
        ])
        # Stride-2 3×3 downsampling for the bottom-up path (one per level gap).
        self.bu_convs = nn.ModuleList([
            self._conv_norm_act(out_channels, out_channels, 3, 2, 1, use_gn)
            for _ in range(self.num_levels - 1)
        ])
        # Final 3×3 output conv per level.
        self.bu_out_convs = nn.ModuleList([
            self._conv_norm_act(out_channels, out_channels, 3, 1, 1, use_gn)
            for _ in range(self.num_levels)
        ])

        self._init_weights()

    @staticmethod
    def _conv_norm_act(in_ch: int, out_ch: int, kernel: int, stride: int,
                       padding: int, use_gn: bool) -> nn.Module:
        """Build one conv -> norm -> ReLU unit (GroupNorm or BatchNorm2d)."""
        if use_gn:
            return ConvGNReLU(in_ch, out_ch, kernel, stride, padding)
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel, stride, padding, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def _init_weights(self):
        """Kaiming-normal (fan_out, relu) conv weights; norm layers to identity."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, inputs: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
        """
        Args:
            inputs: (C3, C4, C5) feature maps from backbone, one per level,
                ordered like ``in_channels``.
        Returns:
            (P3, P4, P5) fused feature maps, each with ``out_channels``
            channels and the spatial size of the matching input level.
        Raises:
            ValueError: If the number of inputs differs from the number of
                levels the neck was built for.
        """
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if len(inputs) != self.num_levels:
            raise ValueError(
                f"PAFPN expects {self.num_levels} input levels, got {len(inputs)}")

        laterals = [conv(x) for conv, x in zip(self.lateral_convs, inputs)]

        # Top-down: coarse-to-fine, accumulating upsampled coarser features.
        for i in range(self.num_levels - 1, 0, -1):
            up = F.interpolate(laterals[i], size=laterals[i - 1].shape[2:],
                               mode='nearest')
            laterals[i - 1] = laterals[i - 1] + up

        td_outs = [conv(lat) for conv, lat in zip(self.td_convs, laterals)]

        # Bottom-up: fine-to-coarse, adding downsampled finer features.
        bu_outs = [td_outs[0]]
        for bu_conv, td in zip(self.bu_convs, td_outs[1:]):
            down = bu_conv(bu_outs[-1])
            # A stride-2 conv yields ceil(H/2). Backbones that floor (e.g. odd
            # input sizes) would otherwise break the add. Mirror the top-down
            # path and match sizes; this is a no-op when shapes already agree.
            if down.shape[2:] != td.shape[2:]:
                down = F.interpolate(down, size=td.shape[2:], mode='nearest')
            bu_outs.append(td + down)

        return tuple(conv(x) for conv, x in zip(self.bu_out_convs, bu_outs))
|
|
|
|
| |
|
|
# Per-tier PAFPN presets. The tiers differ only in the fused output width
# (``out_channels``); all other PAFPN arguments use the constructor defaults.
NECK_CONFIGS = {
    tier: {'out_channels': width}
    for tier, width in (
        ('scrfd_34g', 64),
        ('scrfd_10g', 56),
        ('scrfd_2.5g', 40),
        ('scrfd_0.5g', 16),
    )
}
|
|
|
|
def build_neck(name: str, in_channels: List[int], **kwargs) -> PAFPN:
    """Build a PAFPN neck for the named model tier.

    Args:
        name: Key into ``NECK_CONFIGS``. Unknown names get no preset, so the
            ``PAFPN`` constructor defaults apply (plus any ``kwargs``).
        in_channels: Backbone channel counts, one per pyramid level.
        **kwargs: Additional ``PAFPN`` keyword arguments; these override the
            tier preset.

    Returns:
        A newly constructed ``PAFPN``.
    """
    # Merge into a fresh dict. Updating the dict returned by .get() in place
    # would mutate the shared preset in NECK_CONFIGS, so overrides from one
    # call would leak into every later call for the same tier.
    cfg = {**NECK_CONFIGS.get(name, {}), **kwargs}
    return PAFPN(in_channels, **cfg)
|
|