cledouxluma commited on
Commit
6953619
Β·
verified Β·
1 Parent(s): afaa2cf

Upload models/detector.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. models/detector.py +419 -0
models/detector.py ADDED
@@ -0,0 +1,419 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SCRFD Full Detector β€” Backbone + Neck + Head + Loss + Post-processing.
3
+
4
+ This is the main model class that ties together all components and provides:
5
+ 1. Training forward: returns losses dict
6
+ 2. Inference forward: returns detections (boxes, scores, landmarks)
7
+ 3. ONNX-exportable inference path
8
+
9
+ Model configurations (WiderFace Hard val / GFLOPs / FPS @VGA on V100):
10
+ - SCRFD-34GF: 85.2% / 34 GF / ~80 FPS (flagship quality)
11
+ - SCRFD-10GF: 83.1% / 10 GF / ~140 FPS (balanced)
12
+ - SCRFD-2.5GF: 77.9% / 2.5 GF / ~400 FPS (real-time)
13
+ - SCRFD-0.5GF: 68.5% / 0.5 GF / ~1000 FPS (mobile/edge)
14
+ """
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+ from typing import List, Tuple, Dict, Optional
20
+ import math
21
+
22
+ from .backbone import SCRFDBackbone, build_backbone
23
+ from .neck import PAFPN, build_neck
24
+ from .head import SCRFDHead, build_head
25
+ from .anchor import AnchorGenerator, ATSSAssigner
26
+ from .losses import GFocalLoss, DIoULoss, FocalLoss, LandmarkLoss
27
+
28
+
29
+ class SCRFD(nn.Module):
30
+ """
31
+ Sample and Computation Redistribution Face Detector.
32
+
33
+ Complete pipeline: backbone β†’ PAFPN β†’ shared head β†’ anchors β†’ losses/NMS
34
+ """
35
+
36
+ def __init__(self,
37
+ backbone: SCRFDBackbone,
38
+ neck: PAFPN,
39
+ head: SCRFDHead,
40
+ anchor_generator: AnchorGenerator,
41
+ assigner: ATSSAssigner,
42
+ strides: List[int] = [8, 16, 32],
43
+ score_threshold: float = 0.3,
44
+ nms_threshold: float = 0.4,
45
+ max_detections: int = 750,
46
+ use_gfl: bool = True,
47
+ cls_weight: float = 1.0,
48
+ reg_weight: float = 2.0,
49
+ lmk_weight: float = 0.1):
50
+ super().__init__()
51
+ self.backbone = backbone
52
+ self.neck = neck
53
+ self.head = head
54
+ self.anchor_gen = anchor_generator
55
+ self.assigner = assigner
56
+ self.strides = strides
57
+ self.score_threshold = score_threshold
58
+ self.nms_threshold = nms_threshold
59
+ self.max_detections = max_detections
60
+ self.use_gfl = use_gfl
61
+
62
+ # Loss functions
63
+ self.cls_loss_fn = GFocalLoss(beta=2.0) if use_gfl else FocalLoss()
64
+ self.reg_loss_fn = DIoULoss()
65
+ self.lmk_loss_fn = LandmarkLoss() if head.use_landmarks else None
66
+
67
+ # Loss weights
68
+ self.cls_weight = cls_weight
69
+ self.reg_weight = reg_weight
70
+ self.lmk_weight = lmk_weight
71
+
72
+ def forward(self, images: torch.Tensor,
73
+ targets: Optional[List[Dict]] = None) -> Dict:
74
+ """
75
+ Args:
76
+ images: [B, 3, H, W] batch of images (normalized)
77
+ targets: List of dicts with keys:
78
+ 'boxes': [M, 4] face boxes (x1, y1, x2, y2)
79
+ 'labels': [M] labels (all 1)
80
+ 'landmarks': [M, 10] optional landmarks
81
+ When None, runs inference.
82
+
83
+ Returns:
84
+ Training: dict of losses
85
+ Inference: list of dicts with 'boxes', 'scores', 'landmarks'
86
+ """
87
+ # Feature extraction
88
+ features = self.backbone(images)
89
+ features = self.neck(features)
90
+ head_out = self.head(features)
91
+
92
+ # Generate anchors
93
+ feat_sizes = [(f.shape[2], f.shape[3]) for f in features]
94
+ anchors_per_level = self.anchor_gen.grid_anchors(feat_sizes, images.device)
95
+ num_anchors_per_level = [a.shape[0] for a in anchors_per_level]
96
+
97
+ if targets is not None:
98
+ return self._compute_loss(head_out, anchors_per_level,
99
+ num_anchors_per_level, targets, images.shape)
100
+ else:
101
+ return self._inference(head_out, anchors_per_level, images.shape)
102
+
103
+ def _compute_loss(self, head_out: Dict, anchors_per_level: List[torch.Tensor],
104
+ num_per_level: List[int], targets: List[Dict],
105
+ img_shape: Tuple) -> Dict:
106
+ """Compute training losses."""
107
+ device = anchors_per_level[0].device
108
+ batch_size = len(targets)
109
+
110
+ # Flatten predictions across levels
111
+ all_cls = []
112
+ all_reg = []
113
+ all_lmk = []
114
+ for i in range(len(self.strides)):
115
+ B, _, H, W = head_out['cls_scores'][i].shape
116
+ cls = head_out['cls_scores'][i].permute(0, 2, 3, 1).reshape(B, -1, 1)
117
+ reg = head_out['bbox_preds'][i].permute(0, 2, 3, 1).reshape(B, -1, 4)
118
+ all_cls.append(cls)
119
+ all_reg.append(reg)
120
+ if self.head.use_landmarks and 'lmk_preds' in head_out:
121
+ lmk = head_out['lmk_preds'][i].permute(0, 2, 3, 1).reshape(B, -1, 10)
122
+ all_lmk.append(lmk)
123
+
124
+ all_cls = torch.cat(all_cls, dim=1) # [B, N, 1]
125
+ all_reg = torch.cat(all_reg, dim=1) # [B, N, 4]
126
+ all_anchors = torch.cat(anchors_per_level, dim=0) # [N, 4]
127
+
128
+ has_lmk = len(all_lmk) > 0
129
+ if has_lmk:
130
+ all_lmk = torch.cat(all_lmk, dim=1)
131
+
132
+ total_cls_loss = torch.tensor(0.0, device=device)
133
+ total_reg_loss = torch.tensor(0.0, device=device)
134
+ total_lmk_loss = torch.tensor(0.0, device=device)
135
+ num_pos = 0
136
+
137
+ for b in range(batch_size):
138
+ gt_boxes = targets[b]['boxes']
139
+ gt_labels = targets[b].get('labels',
140
+ torch.ones(gt_boxes.shape[0], dtype=torch.long, device=device))
141
+
142
+ # ATSS matching
143
+ assigned_labels, assigned_gt_inds = self.assigner.assign(
144
+ all_anchors, gt_boxes, gt_labels, num_per_level
145
+ )
146
+
147
+ pos_mask = assigned_labels > 0
148
+ num_pos += pos_mask.sum().item()
149
+
150
+ # Classification loss (all anchors)
151
+ if self.use_gfl:
152
+ # GFL: positive target = IoU, negative target = 0
153
+ cls_targets = torch.zeros(all_anchors.shape[0], device=device)
154
+ if pos_mask.any():
155
+ pos_anchors = all_anchors[pos_mask]
156
+ pos_gt = gt_boxes[assigned_gt_inds[pos_mask]]
157
+ pos_ious = self._compute_iou_single(pos_anchors, pos_gt)
158
+ cls_targets[pos_mask] = pos_ious
159
+ total_cls_loss += self.cls_loss_fn(
160
+ all_cls[b].squeeze(-1), cls_targets
161
+ )
162
+ else:
163
+ total_cls_loss += self.cls_loss_fn(
164
+ all_cls[b].squeeze(-1), (assigned_labels > 0).float()
165
+ )
166
+
167
+ # Box regression loss (positive anchors only)
168
+ if pos_mask.any():
169
+ pos_reg = all_reg[b][pos_mask]
170
+ pos_anchors = all_anchors[pos_mask]
171
+ pos_gt = gt_boxes[assigned_gt_inds[pos_mask]]
172
+
173
+ # Decode predictions to absolute boxes
174
+ pred_boxes = self._decode_boxes(pos_anchors, pos_reg)
175
+ total_reg_loss += self.reg_loss_fn(pred_boxes, pos_gt)
176
+
177
+ # Landmark loss
178
+ if self.head.use_landmarks and 'landmarks' in targets[b] and has_lmk:
179
+ gt_lmk = targets[b]['landmarks']
180
+ pos_lmk_pred = all_lmk[b][pos_mask]
181
+ pos_lmk_gt = gt_lmk[assigned_gt_inds[pos_mask]]
182
+ # Decode landmarks relative to anchors
183
+ pred_lmk = self._decode_landmarks(pos_anchors, pos_lmk_pred)
184
+ total_lmk_loss += self.lmk_loss_fn(pred_lmk, pos_lmk_gt)
185
+
186
+ num_pos = max(num_pos, 1)
187
+ losses = {
188
+ 'cls_loss': self.cls_weight * total_cls_loss / batch_size,
189
+ 'reg_loss': self.reg_weight * total_reg_loss / batch_size,
190
+ }
191
+ if self.head.use_landmarks:
192
+ losses['lmk_loss'] = self.lmk_weight * total_lmk_loss / batch_size
193
+
194
+ losses['total_loss'] = sum(losses.values())
195
+ losses['num_pos'] = torch.tensor(num_pos, dtype=torch.float, device=device)
196
+ return losses
197
+
198
+ def _inference(self, head_out: Dict, anchors_per_level: List[torch.Tensor],
199
+ img_shape: Tuple) -> List[Dict]:
200
+ """Run inference with NMS."""
201
+ batch_size = head_out['cls_scores'][0].shape[0]
202
+ device = head_out['cls_scores'][0].device
203
+
204
+ results = []
205
+ for b in range(batch_size):
206
+ all_boxes = []
207
+ all_scores = []
208
+ all_lmk = []
209
+
210
+ for i in range(len(self.strides)):
211
+ cls = head_out['cls_scores'][i][b].permute(1, 2, 0).reshape(-1, 1).sigmoid()
212
+ reg = head_out['bbox_preds'][i][b].permute(1, 2, 0).reshape(-1, 4)
213
+ anchors = anchors_per_level[i]
214
+
215
+ # Filter by score threshold
216
+ scores = cls.squeeze(-1)
217
+ keep = scores > self.score_threshold
218
+ if keep.sum() == 0:
219
+ continue
220
+
221
+ scores = scores[keep]
222
+ reg = reg[keep]
223
+ anc = anchors[keep]
224
+
225
+ # Decode boxes
226
+ boxes = self._decode_boxes(anc, reg)
227
+
228
+ # Clamp to image boundaries
229
+ boxes[:, 0].clamp_(min=0)
230
+ boxes[:, 1].clamp_(min=0)
231
+ boxes[:, 2].clamp_(max=img_shape[3])
232
+ boxes[:, 3].clamp_(max=img_shape[2])
233
+
234
+ all_boxes.append(boxes)
235
+ all_scores.append(scores)
236
+
237
+ if self.head.use_landmarks and 'lmk_preds' in head_out:
238
+ lmk = head_out['lmk_preds'][i][b].permute(1, 2, 0).reshape(-1, 10)[keep]
239
+ lmk_decoded = self._decode_landmarks(anc, lmk)
240
+ all_lmk.append(lmk_decoded)
241
+
242
+ if not all_boxes:
243
+ results.append({
244
+ 'boxes': torch.empty(0, 4, device=device),
245
+ 'scores': torch.empty(0, device=device),
246
+ })
247
+ continue
248
+
249
+ all_boxes = torch.cat(all_boxes, dim=0)
250
+ all_scores = torch.cat(all_scores, dim=0)
251
+
252
+ # NMS
253
+ keep = self._nms(all_boxes, all_scores, self.nms_threshold)
254
+ keep = keep[:self.max_detections]
255
+
256
+ result = {
257
+ 'boxes': all_boxes[keep],
258
+ 'scores': all_scores[keep],
259
+ }
260
+ if all_lmk:
261
+ all_lmk = torch.cat(all_lmk, dim=0)
262
+ result['landmarks'] = all_lmk[keep]
263
+ results.append(result)
264
+
265
+ return results
266
+
267
+ def _decode_boxes(self, anchors: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
268
+ """Decode box predictions relative to anchors (distance-based)."""
269
+ anchor_cx = (anchors[:, 0] + anchors[:, 2]) / 2
270
+ anchor_cy = (anchors[:, 1] + anchors[:, 3]) / 2
271
+ anchor_w = anchors[:, 2] - anchors[:, 0]
272
+ anchor_h = anchors[:, 3] - anchors[:, 1]
273
+
274
+ x1 = anchor_cx - pred[:, 0] * anchor_w
275
+ y1 = anchor_cy - pred[:, 1] * anchor_h
276
+ x2 = anchor_cx + pred[:, 2] * anchor_w
277
+ y2 = anchor_cy + pred[:, 3] * anchor_h
278
+
279
+ return torch.stack([x1, y1, x2, y2], dim=1)
280
+
281
+ def _decode_landmarks(self, anchors: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
282
+ """Decode landmark predictions relative to anchors."""
283
+ anchor_cx = (anchors[:, 0] + anchors[:, 2]) / 2
284
+ anchor_cy = (anchors[:, 1] + anchors[:, 3]) / 2
285
+ anchor_w = anchors[:, 2] - anchors[:, 0]
286
+ anchor_h = anchors[:, 3] - anchors[:, 1]
287
+
288
+ decoded = pred.clone()
289
+ for i in range(5):
290
+ decoded[:, i*2] = anchor_cx + pred[:, i*2] * anchor_w
291
+ decoded[:, i*2+1] = anchor_cy + pred[:, i*2+1] * anchor_h
292
+ return decoded
293
+
294
+ @staticmethod
295
+ def _compute_iou_single(boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor:
296
+ """Compute elementwise IoU between paired boxes. [N,4] Γ— [N,4] β†’ [N]"""
297
+ inter_x1 = torch.max(boxes1[:, 0], boxes2[:, 0])
298
+ inter_y1 = torch.max(boxes1[:, 1], boxes2[:, 1])
299
+ inter_x2 = torch.min(boxes1[:, 2], boxes2[:, 2])
300
+ inter_y2 = torch.min(boxes1[:, 3], boxes2[:, 3])
301
+ inter = (inter_x2 - inter_x1).clamp(min=0) * (inter_y2 - inter_y1).clamp(min=0)
302
+
303
+ area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
304
+ area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
305
+ union = area1 + area2 - inter
306
+ return inter / (union + 1e-6)
307
+
308
+ @staticmethod
309
+ def _nms(boxes: torch.Tensor, scores: torch.Tensor,
310
+ threshold: float) -> torch.Tensor:
311
+ """Non-Maximum Suppression. Returns kept indices."""
312
+ if boxes.shape[0] == 0:
313
+ return torch.empty(0, dtype=torch.long, device=boxes.device)
314
+
315
+ # Use torchvision NMS if available, else pure PyTorch
316
+ try:
317
+ from torchvision.ops import nms
318
+ return nms(boxes, scores, threshold)
319
+ except ImportError:
320
+ pass
321
+
322
+ # Pure PyTorch NMS fallback
323
+ x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
324
+ areas = (x2 - x1) * (y2 - y1)
325
+ order = scores.argsort(descending=True)
326
+ keep = []
327
+
328
+ while order.numel() > 0:
329
+ i = order[0].item()
330
+ keep.append(i)
331
+ if order.numel() == 1:
332
+ break
333
+
334
+ xx1 = torch.max(x1[i], x1[order[1:]])
335
+ yy1 = torch.max(y1[i], y1[order[1:]])
336
+ xx2 = torch.min(x2[i], x2[order[1:]])
337
+ yy2 = torch.min(y2[i], y2[order[1:]])
338
+ inter = (xx2 - xx1).clamp(min=0) * (yy2 - yy1).clamp(min=0)
339
+ iou = inter / (areas[i] + areas[order[1:]] - inter + 1e-6)
340
+ mask = iou <= threshold
341
+ order = order[1:][mask]
342
+
343
+ return torch.tensor(keep, dtype=torch.long, device=boxes.device)
344
+
345
+
346
+ # ──────────────────────── Model Builder ────────────────────────
347
+
348
+ MODEL_CONFIGS = {
349
+ 'scrfd_34g': {
350
+ 'backbone': 'scrfd_34g',
351
+ 'neck_out': 64,
352
+ 'head_feat': 64,
353
+ 'head_convs': 3,
354
+ },
355
+ 'scrfd_10g': {
356
+ 'backbone': 'scrfd_10g',
357
+ 'neck_out': 56,
358
+ 'head_feat': 56,
359
+ 'head_convs': 2,
360
+ },
361
+ 'scrfd_2.5g': {
362
+ 'backbone': 'scrfd_2.5g',
363
+ 'neck_out': 40,
364
+ 'head_feat': 40,
365
+ 'head_convs': 2,
366
+ },
367
+ 'scrfd_0.5g': {
368
+ 'backbone': 'scrfd_0.5g',
369
+ 'neck_out': 16,
370
+ 'head_feat': 16,
371
+ 'head_convs': 2,
372
+ },
373
+ }
374
+
375
+
376
+ def build_detector(name: str, use_landmarks: bool = False,
377
+ score_threshold: float = 0.3,
378
+ nms_threshold: float = 0.4,
379
+ **kwargs) -> SCRFD:
380
+ """
381
+ Build a complete SCRFD detector by name.
382
+
383
+ Args:
384
+ name: Model name ('scrfd_34g', 'scrfd_10g', 'scrfd_2.5g', 'scrfd_0.5g')
385
+ use_landmarks: Enable 5-point landmark prediction
386
+ score_threshold: Detection confidence threshold
387
+ nms_threshold: NMS IoU threshold
388
+
389
+ Returns:
390
+ Complete SCRFD detector ready for training or inference
391
+ """
392
+ if name not in MODEL_CONFIGS:
393
+ raise ValueError(f"Unknown model: {name}. Options: {list(MODEL_CONFIGS.keys())}")
394
+
395
+ cfg = MODEL_CONFIGS[name]
396
+
397
+ backbone = build_backbone(cfg['backbone'])
398
+ neck = PAFPN(backbone.out_channels, out_channels=cfg['neck_out'])
399
+ head = SCRFDHead(
400
+ in_channels=cfg['neck_out'],
401
+ feat_channels=cfg['head_feat'],
402
+ stacked_convs=cfg['head_convs'],
403
+ use_landmarks=use_landmarks,
404
+ )
405
+ anchor_gen = AnchorGenerator()
406
+ assigner = ATSSAssigner(topk=9)
407
+
408
+ model = SCRFD(
409
+ backbone=backbone,
410
+ neck=neck,
411
+ head=head,
412
+ anchor_generator=anchor_gen,
413
+ assigner=assigner,
414
+ score_threshold=score_threshold,
415
+ nms_threshold=nms_threshold,
416
+ **kwargs,
417
+ )
418
+
419
+ return model