cledouxluma commited on
Commit
a73c9ef
Β·
verified Β·
1 Parent(s): 3007f1d

Upload models/neck.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. models/neck.py +162 -0
models/neck.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PAFPN (Path Aggregation Feature Pyramid Network) for SCRFD.
3
+
4
+ Architecture: Top-down FPN + bottom-up path aggregation.
5
+ - Input: C3 (stride 8), C4 (stride 16), C5 (stride 32) from backbone
6
+ - Output: P3, P4, P5 at same strides with fused multi-scale features
7
+ - All output channels unified to `out_channels`
8
+
9
+ Key design (from SCRFD paper):
10
+ - Lightweight PAFPN with configurable channel width
11
+ - Group Normalization (stable with small batch sizes, per TinaFace finding)
12
+ - NAS-searched channel width varies by model tier
13
+ """
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ from typing import List, Tuple
19
+
20
+
21
+ class ConvGNReLU(nn.Module):
22
+ """Conv + GroupNorm + ReLU."""
23
+
24
+ def __init__(self, in_ch: int, out_ch: int, kernel: int = 3,
25
+ stride: int = 1, padding: int = 1, groups: int = 1,
26
+ num_gn_groups: int = 16, use_relu: bool = True):
27
+ super().__init__()
28
+ # Ensure num_gn_groups divides out_ch
29
+ gn_groups = min(num_gn_groups, out_ch)
30
+ while out_ch % gn_groups != 0:
31
+ gn_groups -= 1
32
+
33
+ self.conv = nn.Conv2d(in_ch, out_ch, kernel, stride, padding,
34
+ groups=groups, bias=False)
35
+ self.gn = nn.GroupNorm(gn_groups, out_ch)
36
+ self.relu = nn.ReLU(inplace=True) if use_relu else nn.Identity()
37
+
38
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
39
+ return self.relu(self.gn(self.conv(x)))
40
+
41
+
42
+ class PAFPN(nn.Module):
43
+ """
44
+ Path Aggregation Feature Pyramid Network.
45
+
46
+ Flow:
47
+ 1. Lateral connections: 1Γ—1 conv to unify channels
48
+ 2. Top-down: upsample + add (FPN)
49
+ 3. Bottom-up: downsample + add (PAN)
50
+ 4. Output convs: 3Γ—3 conv for anti-aliasing
51
+ """
52
+
53
+ def __init__(self, in_channels: List[int], out_channels: int = 64,
54
+ num_extra_convs: int = 0, use_gn: bool = True):
55
+ super().__init__()
56
+ self.num_levels = len(in_channels)
57
+ self.out_channels = out_channels
58
+
59
+ # Lateral connections (1Γ—1 conv to unify channels)
60
+ self.lateral_convs = nn.ModuleList()
61
+ for in_ch in in_channels:
62
+ self.lateral_convs.append(
63
+ ConvGNReLU(in_ch, out_channels, 1, 1, 0) if use_gn
64
+ else nn.Sequential(
65
+ nn.Conv2d(in_ch, out_channels, 1, bias=False),
66
+ nn.BatchNorm2d(out_channels),
67
+ nn.ReLU(inplace=True)
68
+ )
69
+ )
70
+
71
+ # Top-down output convs (anti-aliasing after upsample+add)
72
+ self.td_convs = nn.ModuleList()
73
+ for _ in range(self.num_levels):
74
+ self.td_convs.append(
75
+ ConvGNReLU(out_channels, out_channels, 3, 1, 1) if use_gn
76
+ else nn.Sequential(
77
+ nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
78
+ nn.BatchNorm2d(out_channels),
79
+ nn.ReLU(inplace=True)
80
+ )
81
+ )
82
+
83
+ # Bottom-up downsample convs (stride-2)
84
+ self.bu_convs = nn.ModuleList()
85
+ for _ in range(self.num_levels - 1):
86
+ self.bu_convs.append(
87
+ ConvGNReLU(out_channels, out_channels, 3, 2, 1) if use_gn
88
+ else nn.Sequential(
89
+ nn.Conv2d(out_channels, out_channels, 3, 2, 1, bias=False),
90
+ nn.BatchNorm2d(out_channels),
91
+ nn.ReLU(inplace=True)
92
+ )
93
+ )
94
+
95
+ # Bottom-up output convs
96
+ self.bu_out_convs = nn.ModuleList()
97
+ for _ in range(self.num_levels):
98
+ self.bu_out_convs.append(
99
+ ConvGNReLU(out_channels, out_channels, 3, 1, 1) if use_gn
100
+ else nn.Sequential(
101
+ nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
102
+ nn.BatchNorm2d(out_channels),
103
+ nn.ReLU(inplace=True)
104
+ )
105
+ )
106
+
107
+ self._init_weights()
108
+
109
+ def _init_weights(self):
110
+ for m in self.modules():
111
+ if isinstance(m, nn.Conv2d):
112
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
113
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
114
+ nn.init.constant_(m.weight, 1)
115
+ nn.init.constant_(m.bias, 0)
116
+
117
+ def forward(self, inputs: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
118
+ """
119
+ Args:
120
+ inputs: (C3, C4, C5) feature maps from backbone
121
+ Returns:
122
+ (P3, P4, P5) fused feature maps
123
+ """
124
+ assert len(inputs) == self.num_levels
125
+
126
+ # 1. Lateral connections
127
+ laterals = [self.lateral_convs[i](inputs[i]) for i in range(self.num_levels)]
128
+
129
+ # 2. Top-down pathway (FPN)
130
+ for i in range(self.num_levels - 1, 0, -1):
131
+ up = F.interpolate(laterals[i], size=laterals[i-1].shape[2:],
132
+ mode='nearest')
133
+ laterals[i-1] = laterals[i-1] + up
134
+
135
+ td_outs = [self.td_convs[i](laterals[i]) for i in range(self.num_levels)]
136
+
137
+ # 3. Bottom-up pathway (PAN)
138
+ bu_outs = [td_outs[0]]
139
+ for i in range(self.num_levels - 1):
140
+ down = self.bu_convs[i](bu_outs[-1])
141
+ bu_outs.append(td_outs[i+1] + down)
142
+
143
+ # 4. Output convs
144
+ outputs = tuple(self.bu_out_convs[i](bu_outs[i]) for i in range(self.num_levels))
145
+ return outputs
146
+
147
+
148
+ # ──────────────────────── Configuration presets ────────────────────────
149
+
150
+ NECK_CONFIGS = {
151
+ 'scrfd_34g': dict(out_channels=64),
152
+ 'scrfd_10g': dict(out_channels=56),
153
+ 'scrfd_2.5g': dict(out_channels=40),
154
+ 'scrfd_0.5g': dict(out_channels=16),
155
+ }
156
+
157
+
158
+ def build_neck(name: str, in_channels: List[int], **kwargs) -> PAFPN:
159
+ """Build neck by model name."""
160
+ cfg = NECK_CONFIGS.get(name, {})
161
+ cfg.update(kwargs)
162
+ return PAFPN(in_channels, **cfg)