Upload models/backbone.py with huggingface_hub
models/backbone.py (ADDED, +179 -0)
@@ -0,0 +1,179 @@
"""
SCRFD backbones: NAS-searched ResNet-style networks with computation redistribution.

Key insight from the paper: standard classification backbones over-invest compute in
C5 features (stride 32), which contribute little to tiny-face detection. SCRFD
redistributes compute toward the earlier stages (C2/C3) to improve stride-8 feature
quality.

Configurations (from paper Table 3):
- SCRFD-34GF: stages=[3,12,28,3], widths=[56,88,248,304], groups=[1,1,1,1]
- SCRFD-10GF: stages=[3,10,16,3], widths=[36,64,144,224], groups=[1,1,1,1]
- SCRFD-2.5GF: stages=[2,4,4,3], widths=[24,48,96,160], groups=[1,1,1,1]
- SCRFD-0.5GF: stages=[2,2,4,2], widths=[16,32,64,128], groups=[1,1,1,1]
"""

import torch
import torch.nn as nn
from typing import List, Optional, Tuple
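
# Typical use, as a quick sketch (see build_backbone at the bottom of this file):
#   net = build_backbone('scrfd_2.5g')
#   c3, c4, c5 = net(torch.randn(1, 3, 640, 640))  # feature maps at strides 8, 16, 32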


class ConvBNReLU(nn.Module):
    """Conv + BatchNorm + ReLU building block."""

    def __init__(self, in_ch: int, out_ch: int, kernel: int = 3,
                 stride: int = 1, padding: int = 1, groups: int = 1,
                 use_relu: bool = True):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel, stride, padding,
                              groups=groups, bias=False)
        self.bn = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU(inplace=True) if use_relu else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.relu(self.bn(self.conv(x)))


class BasicBlock(nn.Module):
    """ResNet BasicBlock with optional group convolution."""
    expansion = 1

    def __init__(self, in_ch: int, out_ch: int, stride: int = 1,
                 groups: int = 1, downsample: Optional[nn.Module] = None):
        super().__init__()
        self.conv1 = ConvBNReLU(in_ch, out_ch, 3, stride, 1, groups)
        self.conv2 = ConvBNReLU(out_ch, out_ch, 3, 1, 1, groups, use_relu=False)
        self.downsample = downsample
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = x
        out = self.conv1(x)
        out = self.conv2(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        return self.relu(out)


class BottleneckBlock(nn.Module):
    """ResNet Bottleneck with optional group convolution."""
    expansion = 4

    def __init__(self, in_ch: int, out_ch: int, stride: int = 1,
                 groups: int = 1, downsample: Optional[nn.Module] = None):
        super().__init__()
        mid_ch = out_ch  # bottleneck width
        self.conv1 = ConvBNReLU(in_ch, mid_ch, 1, 1, 0)
        self.conv2 = ConvBNReLU(mid_ch, mid_ch, 3, stride, 1, groups)
        self.conv3 = ConvBNReLU(mid_ch, out_ch * self.expansion, 1, 1, 0, use_relu=False)
        self.downsample = downsample
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = x
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        return self.relu(out)


class SCRFDBackbone(nn.Module):
    """
    SCRFD backbone with NAS-searched stage depths and widths.

    SCRFD uses BasicBlock (expansion=1): the searched widths already account
    for channel capacity, so no bottleneck expansion is needed.

    Returns feature maps at strides [8, 16, 32] (C3, C4, C5).
    """

    def __init__(self, stages: List[int], widths: List[int],
                 groups: Optional[List[int]] = None, in_channels: int = 3,
                 block_type: str = 'basic'):
        super().__init__()
        assert len(stages) == 4 and len(widths) == 4

        if groups is None:
            groups = [1, 1, 1, 1]

        Block = BasicBlock if block_type == 'basic' else BottleneckBlock
        exp = Block.expansion  # each stage actually emits out_ch * expansion channels

        # Stem: stride-2 conv + stride-2 maxpool -> effective stride 4
        self.stem = nn.Sequential(
            ConvBNReLU(in_channels, widths[0], 3, 2, 1),
            ConvBNReLU(widths[0], widths[0], 3, 1, 1),
            nn.MaxPool2d(3, 2, 1),
        )

        # Stage 1: stride 1 (output stride = 4)
        self.layer1 = self._make_layer(Block, widths[0], widths[0], stages[0],
                                       stride=1, groups=groups[0])
        # Stage 2: stride 2 (output stride = 8) -> C3
        # Note: in_ch must include the previous stage's expansion, otherwise
        # bottleneck configs (e.g. 'resnet50') fail with a channel mismatch.
        self.layer2 = self._make_layer(Block, widths[0] * exp, widths[1], stages[1],
                                       stride=2, groups=groups[1])
        # Stage 3: stride 2 (output stride = 16) -> C4
        self.layer3 = self._make_layer(Block, widths[1] * exp, widths[2], stages[2],
                                       stride=2, groups=groups[2])
        # Stage 4: stride 2 (output stride = 32) -> C5
        self.layer4 = self._make_layer(Block, widths[2] * exp, widths[3], stages[3],
                                       stride=2, groups=groups[3])

        # Per-level output channels for C3/C4/C5; fold in the block expansion
        # so bottleneck configs report their true widths.
        self.out_channels = [widths[1] * exp, widths[2] * exp, widths[3] * exp]
        self._init_weights()

    def _make_layer(self, block, in_ch: int, out_ch: int, num_blocks: int,
                    stride: int = 1, groups: int = 1) -> nn.Sequential:
        downsample = None
        if stride != 1 or in_ch != out_ch * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(in_ch, out_ch * block.expansion, 1, stride, bias=False),
                nn.BatchNorm2d(out_ch * block.expansion),
            )

        layers = [block(in_ch, out_ch, stride, groups, downsample)]
        in_ch = out_ch * block.expansion
        for _ in range(1, num_blocks):
            layers.append(block(in_ch, out_ch, 1, groups))
        return nn.Sequential(*layers)

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        x = self.stem(x)
        c2 = self.layer1(x)   # stride 4
        c3 = self.layer2(c2)  # stride 8
        c4 = self.layer3(c3)  # stride 16
        c5 = self.layer4(c4)  # stride 32
        return c3, c4, c5
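
# Worked example of the contract above: a 640x640 input through the 'scrfd_2.5g'
# preset (widths [24, 48, 96, 160], basic blocks) yields
#   c3: (N, 48, 80, 80), c4: (N, 96, 40, 40), c5: (N, 160, 20, 20),
# i.e. out_channels == [48, 96, 160] at strides [8, 16, 32].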

# ──────────────────────── Configuration presets ────────────────────────

BACKBONE_CONFIGS = {
    'scrfd_34g': dict(stages=[3, 12, 28, 3], widths=[56, 88, 248, 304]),
    'scrfd_10g': dict(stages=[3, 10, 16, 3], widths=[36, 64, 144, 224]),
    'scrfd_2.5g': dict(stages=[2, 4, 4, 3], widths=[24, 48, 96, 160]),
    'scrfd_0.5g': dict(stages=[2, 2, 4, 2], widths=[16, 32, 64, 128]),
    # ResNet variants for comparison
    'resnet50': dict(stages=[3, 4, 6, 3], widths=[64, 128, 256, 512], block_type='bottleneck'),
    'resnet18': dict(stages=[2, 2, 2, 2], widths=[64, 128, 256, 512]),
}

def build_backbone(name: str, **kwargs) -> SCRFDBackbone:
    """Build a backbone by name."""
    if name not in BACKBONE_CONFIGS:
        raise ValueError(f"Unknown backbone: {name}. Options: {list(BACKBONE_CONFIGS.keys())}")
    cfg = {**BACKBONE_CONFIGS[name], **kwargs}
    return SCRFDBackbone(**cfg)
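

# Minimal smoke test (an illustrative sketch using only the public API defined
# above): builds the 2.5 GF preset and asserts the shape contract documented on
# SCRFDBackbone -- three feature maps at strides 8/16/32 with out_channels widths.
if __name__ == "__main__":
    net = build_backbone('scrfd_2.5g').eval()
    with torch.no_grad():
        c3, c4, c5 = net(torch.randn(1, 3, 640, 640))
    for feat, ch, stride in zip((c3, c4, c5), net.out_channels, (8, 16, 32)):
        assert feat.shape[1] == ch and feat.shape[2] == 640 // stride
    print("OK:", [tuple(f.shape) for f in (c3, c4, c5)])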