"""
PAFPN (Path Aggregation Feature Pyramid Network) for SCRFD.

Architecture: Top-down FPN + bottom-up path aggregation.
- Input: C3 (stride 8), C4 (stride 16), C5 (stride 32) from backbone
- Output: P3, P4, P5 at same strides with fused multi-scale features
- All output channels unified to `out_channels`

Key design (from SCRFD paper):
- Lightweight PAFPN with configurable channel width
- Group Normalization (stable with small batch sizes, per TinaFace finding)
- NAS-searched channel width varies by model tier
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple


class ConvGNReLU(nn.Module):
    """Conv + GroupNorm + ReLU."""

    def __init__(self, in_ch: int, out_ch: int, kernel: int = 3,
                 stride: int = 1, padding: int = 1, groups: int = 1,
                 num_gn_groups: int = 16, use_relu: bool = True):
        super().__init__()
        # Ensure num_gn_groups divides out_ch
        gn_groups = min(num_gn_groups, out_ch)
        while out_ch % gn_groups != 0:
            gn_groups -= 1
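        # e.g. out_ch=56 with num_gn_groups=16 falls back to 14 groups
        # (56 % 14 == 0), and out_ch=40 falls back to 10 groups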

        self.conv = nn.Conv2d(in_ch, out_ch, kernel, stride, padding,
                              groups=groups, bias=False)
        self.gn = nn.GroupNorm(gn_groups, out_ch)
        self.relu = nn.ReLU(inplace=True) if use_relu else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.relu(self.gn(self.conv(x)))


class PAFPN(nn.Module):
    """
    Path Aggregation Feature Pyramid Network.

    Flow:
    1. Lateral connections: 1×1 conv to unify channels
    2. Top-down: upsample + add (FPN)
    3. Bottom-up: downsample + add (PAN)
    4. Output convs: 3×3 conv for anti-aliasing
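
    Example (shape sketch; the backbone channel counts are illustrative):
        >>> neck = PAFPN([128, 256, 512], out_channels=64)
        >>> outs = neck((torch.randn(1, 128, 80, 80),
        ...              torch.randn(1, 256, 40, 40),
        ...              torch.randn(1, 512, 20, 20)))
        >>> [o.shape[1] for o in outs]
        [64, 64, 64]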
    """

    def __init__(self, in_channels: List[int], out_channels: int = 64,
                 num_extra_convs: int = 0, use_gn: bool = True):
        super().__init__()
        # NOTE: num_extra_convs is currently unused; reject nonzero values
        # loudly instead of silently ignoring them.
        if num_extra_convs != 0:
            raise NotImplementedError('num_extra_convs > 0 is not implemented')
        self.num_levels = len(in_channels)
        self.out_channels = out_channels

        # Lateral connections (1×1 conv to unify channels)
        self.lateral_convs = nn.ModuleList()
        for in_ch in in_channels:
            self.lateral_convs.append(
                ConvGNReLU(in_ch, out_channels, 1, 1, 0) if use_gn
                else nn.Sequential(
                    nn.Conv2d(in_ch, out_channels, 1, bias=False),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(inplace=True)
                )
            )

        # Top-down output convs (anti-aliasing after upsample+add)
        self.td_convs = nn.ModuleList()
        for _ in range(self.num_levels):
            self.td_convs.append(
                ConvGNReLU(out_channels, out_channels, 3, 1, 1) if use_gn
                else nn.Sequential(
                    nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(inplace=True)
                )
            )

        # Bottom-up downsample convs (stride-2)
        self.bu_convs = nn.ModuleList()
        for _ in range(self.num_levels - 1):
            self.bu_convs.append(
                ConvGNReLU(out_channels, out_channels, 3, 2, 1) if use_gn
                else nn.Sequential(
                    nn.Conv2d(out_channels, out_channels, 3, 2, 1, bias=False),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(inplace=True)
                )
            )

        # Bottom-up output convs
        self.bu_out_convs = nn.ModuleList()
        for _ in range(self.num_levels):
            self.bu_out_convs.append(
                ConvGNReLU(out_channels, out_channels, 3, 1, 1) if use_gn
                else nn.Sequential(
                    nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(inplace=True)
                )
            )

        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, inputs: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
        """
        Args:
            inputs: (C3, C4, C5) feature maps from backbone
        Returns:
            (P3, P4, P5) fused feature maps
        """
        assert len(inputs) == self.num_levels

        # 1. Lateral connections
        laterals = [self.lateral_convs[i](inputs[i]) for i in range(self.num_levels)]

        # 2. Top-down pathway (FPN)
        for i in range(self.num_levels - 1, 0, -1):
            up = F.interpolate(laterals[i], size=laterals[i-1].shape[2:],
                               mode='nearest')
            laterals[i-1] = laterals[i-1] + up

        td_outs = [self.td_convs[i](laterals[i]) for i in range(self.num_levels)]

        # 3. Bottom-up pathway (PAN)
        bu_outs = [td_outs[0]]
        for i in range(self.num_levels - 1):
            down = self.bu_convs[i](bu_outs[-1])
            bu_outs.append(td_outs[i+1] + down)

        # 4. Output convs
        outputs = tuple(self.bu_out_convs[i](bu_outs[i]) for i in range(self.num_levels))
        return outputs


# ──────────────────────── Configuration presets ────────────────────────

NECK_CONFIGS = {
    'scrfd_34g': dict(out_channels=64),
    'scrfd_10g': dict(out_channels=56),
    'scrfd_2.5g': dict(out_channels=40),
    'scrfd_0.5g': dict(out_channels=16),
}
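# The suffix of each name encodes an approximate compute budget in GFLOPs
# (e.g. 'scrfd_10g' targets roughly 10 GFLOPs); widths are the NAS-searched
# values per tier.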


def build_neck(name: str, in_channels: List[int], **kwargs) -> PAFPN:
    """Build neck by model name."""
    cfg = dict(NECK_CONFIGS.get(name, {}))  # copy so the preset isn't mutated
    cfg.update(kwargs)
    return PAFPN(in_channels, **cfg)
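

# ──────────────────────────── Smoke test ────────────────────────────

# A minimal shape check, assuming a hypothetical backbone emitting
# (C3, C4, C5) with 128/256/512 channels at strides 8/16/32 for a
# 640×640 input; the channel counts are illustrative, not SCRFD's.
if __name__ == '__main__':
    neck = build_neck('scrfd_2.5g', in_channels=[128, 256, 512])
    c3 = torch.randn(1, 128, 80, 80)   # stride 8
    c4 = torch.randn(1, 256, 40, 40)   # stride 16
    c5 = torch.randn(1, 512, 20, 20)   # stride 32
    p3, p4, p5 = neck((c3, c4, c5))
    # Every level is unified to the tier's width (40 for scrfd_2.5g).
    assert p3.shape == (1, 40, 80, 80)
    assert p4.shape == (1, 40, 40, 40)
    assert p5.shape == (1, 40, 20, 20)
    print('PAFPN OK:', [tuple(p.shape) for p in (p3, p4, p5)])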