cledouxluma committed
Commit afaa2cf · verified · 1 Parent(s): 2cc83d8

Upload models/head.py with huggingface_hub

Files changed (1)
  1. models/head.py +176 -0
models/head.py ADDED
@@ -0,0 +1,176 @@
+"""
+SCRFD Detection Head: shared-weight, multi-task, scale-aware.
+
+Design from SCRFD paper:
+- Weight sharing across pyramid levels (parameter-efficient)
+- GroupNorm for batch-size independence
+- Separate cls and reg branches (GFL-style)
+- Optional landmark branch (RetinaFace-style 5-point)
+
+Output per anchor:
+- Classification: 1 score (face quality score via GFL)
+- Box regression: 4 values (distance from anchor center to box edges)
+- Landmarks (optional): 10 values (5 x,y offsets from anchor center)
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import List, Tuple, Optional
+import math
+
+
+class SCRFDHead(nn.Module):
+    """
+    Shared detection head applied to each FPN level.
+
+    Args:
+        in_channels: Input channels from neck
+        num_classes: Number of classes (1 for face detection)
+        num_anchors: Anchors per spatial location per level
+        feat_channels: Hidden channel width in head convolutions
+        stacked_convs: Number of stacked 3×3 convs in each branch
+        use_gn: Use GroupNorm (vs BatchNorm)
+        use_landmarks: Enable 5-point landmark regression branch
+    """
+
+    def __init__(self,
+                 in_channels: int = 64,
+                 num_classes: int = 1,
+                 num_anchors: int = 2,
+                 feat_channels: int = 64,
+                 stacked_convs: int = 2,
+                 use_gn: bool = True,
+                 use_landmarks: bool = False):
+        super().__init__()
+        self.num_classes = num_classes
+        self.num_anchors = num_anchors
+        self.use_landmarks = use_landmarks
+
+        # Classification branch
+        cls_convs = []
+        for i in range(stacked_convs):
+            ch_in = in_channels if i == 0 else feat_channels
+            cls_convs.append(nn.Conv2d(ch_in, feat_channels, 3, 1, 1, bias=False))
+            if use_gn:
+                gn_groups = min(16, feat_channels)
+                while feat_channels % gn_groups != 0:
+                    gn_groups -= 1
+                cls_convs.append(nn.GroupNorm(gn_groups, feat_channels))
+            else:
+                cls_convs.append(nn.BatchNorm2d(feat_channels))
+            cls_convs.append(nn.ReLU(inplace=True))
+        self.cls_convs = nn.Sequential(*cls_convs)
+        self.cls_out = nn.Conv2d(feat_channels, num_anchors * num_classes, 3, 1, 1)
+
+        # Box regression branch
+        reg_convs = []
+        for i in range(stacked_convs):
+            ch_in = in_channels if i == 0 else feat_channels
+            reg_convs.append(nn.Conv2d(ch_in, feat_channels, 3, 1, 1, bias=False))
+            if use_gn:
+                gn_groups = min(16, feat_channels)
+                while feat_channels % gn_groups != 0:
+                    gn_groups -= 1
+                reg_convs.append(nn.GroupNorm(gn_groups, feat_channels))
+            else:
+                reg_convs.append(nn.BatchNorm2d(feat_channels))
+            reg_convs.append(nn.ReLU(inplace=True))
+        self.reg_convs = nn.Sequential(*reg_convs)
+        self.reg_out = nn.Conv2d(feat_channels, num_anchors * 4, 3, 1, 1)
+
+        # Landmark branch (optional)
+        if use_landmarks:
+            lmk_convs = []
+            for i in range(stacked_convs):
+                ch_in = in_channels if i == 0 else feat_channels
+                lmk_convs.append(nn.Conv2d(ch_in, feat_channels, 3, 1, 1, bias=False))
+                if use_gn:
+                    gn_groups = min(16, feat_channels)
+                    while feat_channels % gn_groups != 0:
+                        gn_groups -= 1
+                    lmk_convs.append(nn.GroupNorm(gn_groups, feat_channels))
+                else:
+                    lmk_convs.append(nn.BatchNorm2d(feat_channels))
+                lmk_convs.append(nn.ReLU(inplace=True))
+            self.lmk_convs = nn.Sequential(*lmk_convs)
+            self.lmk_out = nn.Conv2d(feat_channels, num_anchors * 10, 3, 1, 1)
+
+        self._init_weights()
+
+    def _init_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.normal_(m.weight, std=0.01)
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+        # Initialize cls bias for focal loss (prevents initial instability)
+        # Prior probability = 0.01
+        prior_prob = 0.01
+        bias_init = -math.log((1 - prior_prob) / prior_prob)
+        nn.init.constant_(self.cls_out.bias, bias_init)
+
+    def forward_single(self, x: torch.Tensor) -> dict:
+        """Forward pass for a single FPN level."""
+        cls_feat = self.cls_convs(x)
+        cls_score = self.cls_out(cls_feat)  # [B, A*C, H, W]
+
+        reg_feat = self.reg_convs(x)
+        bbox_pred = self.reg_out(reg_feat)  # [B, A*4, H, W]
+
+        result = {'cls_score': cls_score, 'bbox_pred': bbox_pred}
+
+        if self.use_landmarks:
+            lmk_feat = self.lmk_convs(x)
+            lmk_pred = self.lmk_out(lmk_feat)  # [B, A*10, H, W]
+            result['lmk_pred'] = lmk_pred
+
+        return result
+
+    def forward(self, features: Tuple[torch.Tensor, ...]) -> dict:
+        """
+        Forward on all FPN levels.
+
+        Args:
+            features: (P3, P4, P5) from neck
+
+        Returns:
+            dict with keys 'cls_scores', 'bbox_preds', optionally 'lmk_preds'.
+            Each value is a list of tensors, one per level.
+        """
+        cls_scores = []
+        bbox_preds = []
+        lmk_preds = []
+
+        for feat in features:
+            out = self.forward_single(feat)
+            cls_scores.append(out['cls_score'])
+            bbox_preds.append(out['bbox_pred'])
+            if self.use_landmarks:
+                lmk_preds.append(out['lmk_pred'])
+
+        result = {'cls_scores': cls_scores, 'bbox_preds': bbox_preds}
+        if self.use_landmarks:
+            result['lmk_preds'] = lmk_preds
+        return result
+
+
+# ──────────────────────── Configuration presets ────────────────────────
+
+HEAD_CONFIGS = {
+    'scrfd_34g': dict(feat_channels=64, stacked_convs=3),
+    'scrfd_10g': dict(feat_channels=56, stacked_convs=2),
+    'scrfd_2.5g': dict(feat_channels=40, stacked_convs=2),
+    'scrfd_0.5g': dict(feat_channels=16, stacked_convs=2),
+}
+
+
+def build_head(name: str, in_channels: int, **kwargs) -> SCRFDHead:
+    """Build detection head by model name."""
+    cfg = dict(HEAD_CONFIGS.get(name, {}))  # copy so kwargs do not mutate the shared preset
+    cfg.update(kwargs)
+    return SCRFDHead(in_channels=in_channels, **cfg)
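
A minimal usage sketch for the uploaded head (not taken from the repo): it assumes the neck delivers three feature maps at strides 8/16/32 with 64 channels, matching the defaults above; the 640x640 input size and the models.head import path are illustrative.

import torch
from models.head import build_head  # assumes models/ is an importable package

# Dummy FPN features for a 640x640 input: strides 8, 16, 32 with 64 channels each.
feats = (
    torch.randn(1, 64, 80, 80),  # P3
    torch.randn(1, 64, 40, 40),  # P4
    torch.randn(1, 64, 20, 20),  # P5
)

head = build_head('scrfd_0.5g', in_channels=64, use_landmarks=True)
out = head(feats)

# Per level, with the default num_anchors=2 and num_classes=1:
#   cls_scores[i]: [1, 2, H, W]   bbox_preds[i]: [1, 8, H, W]   lmk_preds[i]: [1, 20, H, W]
for cls, box, lmk in zip(out['cls_scores'], out['bbox_preds'], out['lmk_preds']):
    print(tuple(cls.shape), tuple(box.shape), tuple(lmk.shape))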
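The module docstring describes the 4 regression values as distances from the anchor center to the box edges; the decode itself is not part of this file. The helper below is only an assumed distance2bbox-style sketch (the name and the pixel units are assumptions) showing how such predictions map back to x1, y1, x2, y2 boxes.

import torch

def distance2bbox(points: torch.Tensor, distances: torch.Tensor) -> torch.Tensor:
    # points:    [N, 2] anchor centers (x, y), assumed in image pixels
    # distances: [N, 4] predicted (left, top, right, bottom) offsets, assumed in pixels
    x1 = points[:, 0] - distances[:, 0]
    y1 = points[:, 1] - distances[:, 1]
    x2 = points[:, 0] + distances[:, 2]
    y2 = points[:, 1] + distances[:, 3]
    return torch.stack([x1, y1, x2, y2], dim=-1)

In practice a [B, A*4, H, W] bbox_pred from the head would first be permuted and reshaped to [N, 4] and paired with the per-level anchor centers; depending on how training targets are defined, the offsets may also need scaling by the level stride before decoding.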