"""
PAFPN (Path Aggregation Feature Pyramid Network) for SCRFD.
Architecture: Top-down FPN + bottom-up path aggregation.
- Input: C3 (stride 8), C4 (stride 16), C5 (stride 32) from backbone
- Output: P3, P4, P5 at same strides with fused multi-scale features
- All output channels unified to `out_channels`
Key design (from SCRFD paper):
- Lightweight PAFPN with configurable channel width
- Group Normalization (stable with small batch sizes, per TinaFace finding)
- NAS-searched channel width varies by model tier
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple
class ConvGNReLU(nn.Module):
    """Convolution -> GroupNorm -> (optional) ReLU building block.

    The requested GroupNorm group count is reduced until it evenly divides
    ``out_ch``, since ``nn.GroupNorm`` requires divisibility.
    """

    def __init__(self, in_ch: int, out_ch: int, kernel: int = 3,
                 stride: int = 1, padding: int = 1, groups: int = 1,
                 num_gn_groups: int = 16, use_relu: bool = True):
        super().__init__()
        # Walk the group count down from the requested value until it
        # divides out_ch (always terminates: 1 divides everything).
        gn_groups = min(num_gn_groups, out_ch)
        while out_ch % gn_groups:
            gn_groups -= 1
        self.conv = nn.Conv2d(in_ch, out_ch, kernel, stride, padding,
                              groups=groups, bias=False)
        self.gn = nn.GroupNorm(gn_groups, out_ch)
        self.relu = nn.ReLU(inplace=True) if use_relu else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.conv(x)
        out = self.gn(out)
        return self.relu(out)
class PAFPN(nn.Module):
    """
    Path Aggregation Feature Pyramid Network.

    Flow:
        1. Lateral connections: 1x1 conv to unify channels
        2. Top-down: upsample + add (FPN)
        3. Bottom-up: downsample + add (PAN)
        4. Output convs: 3x3 conv for anti-aliasing

    Args:
        in_channels: channel counts of the backbone feature maps
            (e.g. C3/C4/C5), lowest stride first.
        out_channels: unified channel width of every output level.
        num_extra_convs: accepted for config compatibility; currently unused.
        use_gn: if True use ConvGNReLU (GroupNorm) blocks, otherwise
            Conv + BatchNorm + ReLU.
    """

    def __init__(self, in_channels: List[int], out_channels: int = 64,
                 num_extra_convs: int = 0, use_gn: bool = True):
        super().__init__()
        self.num_levels = len(in_channels)
        self.out_channels = out_channels
        # Lateral connections (1x1 conv to unify channels)
        self.lateral_convs = nn.ModuleList(
            self._make_conv(in_ch, out_channels, 1, 1, 0, use_gn)
            for in_ch in in_channels
        )
        # Top-down output convs (anti-aliasing after upsample+add)
        self.td_convs = nn.ModuleList(
            self._make_conv(out_channels, out_channels, 3, 1, 1, use_gn)
            for _ in range(self.num_levels)
        )
        # Bottom-up downsample convs (stride-2)
        self.bu_convs = nn.ModuleList(
            self._make_conv(out_channels, out_channels, 3, 2, 1, use_gn)
            for _ in range(self.num_levels - 1)
        )
        # Bottom-up output convs
        self.bu_out_convs = nn.ModuleList(
            self._make_conv(out_channels, out_channels, 3, 1, 1, use_gn)
            for _ in range(self.num_levels)
        )
        self._init_weights()

    @staticmethod
    def _make_conv(in_ch: int, out_ch: int, kernel: int, stride: int,
                   padding: int, use_gn: bool) -> nn.Module:
        """Build one conv block: Conv+GN+ReLU if use_gn, else Conv+BN+ReLU.

        The convolution is always bias-free since a norm layer follows.
        """
        if use_gn:
            return ConvGNReLU(in_ch, out_ch, kernel, stride, padding)
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel, stride, padding, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def _init_weights(self):
        """Kaiming init for convs; unit-scale, zero-shift for norm layers."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, inputs: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
        """
        Args:
            inputs: (C3, C4, C5) feature maps from backbone
        Returns:
            (P3, P4, P5) fused feature maps, all with `out_channels` channels
        """
        assert len(inputs) == self.num_levels
        # 1. Lateral connections
        laterals = [self.lateral_convs[i](inputs[i]) for i in range(self.num_levels)]
        # 2. Top-down pathway (FPN): propagate coarse features to finer levels.
        # `size=` (not scale_factor) tolerates odd spatial dims from the backbone.
        for i in range(self.num_levels - 1, 0, -1):
            up = F.interpolate(laterals[i], size=laterals[i-1].shape[2:],
                               mode='nearest')
            laterals[i-1] = laterals[i-1] + up
        td_outs = [self.td_convs[i](laterals[i]) for i in range(self.num_levels)]
        # 3. Bottom-up pathway (PAN): re-propagate fine features upward.
        bu_outs = [td_outs[0]]
        for i in range(self.num_levels - 1):
            down = self.bu_convs[i](bu_outs[-1])
            bu_outs.append(td_outs[i+1] + down)
        # 4. Output convs
        outputs = tuple(self.bu_out_convs[i](bu_outs[i]) for i in range(self.num_levels))
        return outputs
# ──────────────────────── Configuration presets ────────────────────────
# Per-tier neck channel widths (NAS-searched per the SCRFD paper); keys are
# model names consumed by build_neck().
NECK_CONFIGS = {
    'scrfd_34g': dict(out_channels=64),
    'scrfd_10g': dict(out_channels=56),
    'scrfd_2.5g': dict(out_channels=40),
    'scrfd_0.5g': dict(out_channels=16),
}
def build_neck(name: str, in_channels: List[int], **kwargs) -> PAFPN:
    """Build a PAFPN neck from a named preset, with keyword overrides.

    Args:
        name: preset key in NECK_CONFIGS; unknown names fall back to
            PAFPN's own defaults.
        in_channels: backbone feature-map channel counts.
        **kwargs: per-call overrides (e.g. out_channels=..., use_gn=...).

    Returns:
        A configured PAFPN instance.
    """
    # Copy the preset before merging: updating the dict returned by .get()
    # in place would mutate the shared NECK_CONFIGS entry, leaking one
    # call's overrides into every later call with the same name.
    cfg = dict(NECK_CONFIGS.get(name, {}))
    cfg.update(kwargs)
    return PAFPN(in_channels, **cfg)