cledouxluma committed
Commit afda97e · verified · 1 parent: 030da96

Upload engine/video_detector.py with huggingface_hub

Files changed (1):
engine/video_detector.py +372 -0
engine/video_detector.py ADDED
@@ -0,0 +1,372 @@
"""
Video Face Detector — End-to-end video inference with tracking and smoothing.

Combines:
1. SCRFD detector (per-frame face detection)
2. ByteTrack tracker (cross-frame identity association)
3. Temporal smoother (jitter reduction)
4. Optional keyframe strategy (run full detection every N frames,
   track-only on intermediate frames for speed)

Supports:
- Live webcam streams
- Video files (MP4, AVI, etc.)
- RTSP/RTMP streams
- Image directory sequences
- ONNX Runtime for deployment
"""

import time
from typing import Optional, Callable, List, Dict, Union
from dataclasses import dataclass

import numpy as np
import cv2
import torch

from .tracker import ByteTracker, Track
from .temporal import TemporalSmoother


@dataclass
class FaceDetection:
    """Single face detection result."""
    track_id: int
    bbox: np.ndarray  # [x1, y1, x2, y2]
    score: float
    landmarks: Optional[np.ndarray] = None  # [10] = 5 x (x, y)
    is_confirmed: bool = True


class VideoFaceDetector:
    """
    Production video face detection pipeline.

    Usage:
        def on_frame(frame, faces, frame_idx):
            for face in faces:
                print(f"Track {face.track_id}: bbox={face.bbox}, score={face.score:.2f}")

        detector = VideoFaceDetector(model_path='scrfd_34g.pth', model_name='scrfd_34g')
        detector.process_video('input.mp4', callback=on_frame)

    Args:
        model: SCRFD model instance (or None to use ONNX)
        model_path: Path to checkpoint (.pth) or ONNX model (.onnx)
        model_name: Model variant name for building from scratch
        device: 'cuda' or 'cpu'
        score_threshold: Min detection confidence
        nms_threshold: NMS IoU threshold
        input_size: Model input resolution
        use_tracking: Enable ByteTrack temporal tracking
        use_smoothing: Enable EMA temporal smoothing
        keyframe_interval: Run full detection every N frames (0 = every frame)
    """

    def __init__(self,
                 model=None,
                 model_path: Optional[str] = None,
                 model_name: str = 'scrfd_34g',
                 device: str = 'cuda',
                 score_threshold: float = 0.3,
                 nms_threshold: float = 0.4,
                 input_size: int = 640,
                 use_tracking: bool = True,
                 use_smoothing: bool = True,
                 keyframe_interval: int = 0):

        self.device = device
        self.input_size = input_size
        self.score_threshold = score_threshold
        self.use_tracking = use_tracking
        self.use_smoothing = use_smoothing
        self.keyframe_interval = keyframe_interval
        self.mean = np.array([104.0, 117.0, 123.0], dtype=np.float32)

        # Load model
        self.onnx_session = None
        self.model = None
        if model is not None:
            self.model = model
        elif model_path and model_path.endswith('.onnx'):
            self._load_onnx(model_path)
        else:
            from models.detector import build_detector
            self.model = build_detector(
                model_name,
                score_threshold=score_threshold,
                nms_threshold=nms_threshold,
            )
            if model_path:
                checkpoint = torch.load(model_path, map_location='cpu')
                state_dict = checkpoint.get('model_state_dict', checkpoint)
                self.model.load_state_dict(state_dict, strict=False)

        # ONNX Runtime manages its own device placement; only move a PyTorch model
        if self.model is not None:
            self.model.to(device)
            self.model.eval()

        # Initialize tracker and smoother
        self.tracker = ByteTracker() if use_tracking else None
        self.smoother = TemporalSmoother() if use_smoothing else None

        self._frame_count = 0
        self._last_detections = []

    def _load_onnx(self, model_path: str):
        """Load ONNX model for deployment inference."""
        try:
            import onnxruntime as ort
            providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
            if self.device == 'cpu':
                providers = ['CPUExecutionProvider']
            self.onnx_session = ort.InferenceSession(model_path, providers=providers)
        except ImportError:
            raise ImportError("onnxruntime is required for ONNX inference: pip install onnxruntime-gpu")

    @torch.no_grad()
    def detect_frame(self, frame: np.ndarray) -> List[FaceDetection]:
        """
        Detect faces in a single frame.

        Args:
            frame: BGR image (OpenCV format); grayscale frames are
                expanded to 3 channels

        Returns:
            List of FaceDetection objects
        """
        self._frame_count += 1

        # Keyframe strategy: skip detection on non-keyframes
        # (interval <= 1 means full detection on every frame)
        if (self.keyframe_interval > 1 and
                self._frame_count % self.keyframe_interval != 1):
            # Use tracker prediction only
            if self.tracker:
                tracks = self.tracker.update(
                    np.empty((0, 4)), np.empty(0), None
                )
                return self._tracks_to_detections(tracks)
            return self._last_detections

        # Preprocess: OpenCV frames are BGR
        if frame.ndim == 2:
            rgb = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
        else:
            rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        h_orig, w_orig = rgb.shape[:2]
        img, scale, pad = self._preprocess(rgb)

        # Run detection
        if self.onnx_session:
            boxes, scores, landmarks = self._infer_onnx(img)
        else:
            boxes, scores, landmarks = self._infer_pytorch(img)

        # Rescale to original image coordinates
        if len(boxes) > 0:
            boxes[:, [0, 2]] = (boxes[:, [0, 2]] - pad[0]) / scale
            boxes[:, [1, 3]] = (boxes[:, [1, 3]] - pad[1]) / scale
            boxes[:, [0, 2]] = np.clip(boxes[:, [0, 2]], 0, w_orig)
            boxes[:, [1, 3]] = np.clip(boxes[:, [1, 3]], 0, h_orig)

        if landmarks is not None and len(landmarks) > 0:
            for i in range(5):
                landmarks[:, i * 2] = (landmarks[:, i * 2] - pad[0]) / scale
                landmarks[:, i * 2 + 1] = (landmarks[:, i * 2 + 1] - pad[1]) / scale

        # Tracking
        if self.use_tracking and self.tracker:
            tracks = self.tracker.update(boxes, scores, landmarks)
            detections = self._tracks_to_detections(tracks)
        else:
            detections = [
                FaceDetection(
                    track_id=i,
                    bbox=boxes[i],
                    score=float(scores[i]),
                    landmarks=landmarks[i] if landmarks is not None else None,
                )
                for i in range(len(boxes))
            ]

        # Temporal smoothing
        if self.use_smoothing and self.smoother:
            active_ids = set()
            for det in detections:
                det.bbox, det.score = self.smoother.smooth(
                    det.track_id, det.bbox, det.score
                )
                active_ids.add(det.track_id)
            self.smoother.cleanup(active_ids)

        self._last_detections = detections
        return detections

    def process_video(self, source: Union[str, int],
                      callback: Optional[Callable] = None,
                      max_frames: int = -1,
                      output_path: Optional[str] = None,
                      show: bool = False) -> Dict:
        """
        Process a video file or stream.

        Args:
            source: Video file path, webcam index (0), or RTSP URL
            callback: Optional per-frame callback(frame, detections, frame_idx)
            max_frames: Max frames to process (-1 for all)
            output_path: Save annotated video to this path
            show: Display annotated frames in a window

        Returns:
            dict with 'total_frames', 'avg_fps', 'avg_faces_per_frame', 'total_time'
        """
        cap = cv2.VideoCapture(source)
        if not cap.isOpened():
            raise IOError(f"Cannot open video source: {source}")

        fps = cap.get(cv2.CAP_PROP_FPS) or 30
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        writer = None
        if output_path:
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

        total_frames = 0
        total_faces = 0
        total_time = 0.0

        try:
            while True:
                ret, frame = cap.read()
                if not ret or (max_frames > 0 and total_frames >= max_frames):
                    break

                t0 = time.time()
                detections = self.detect_frame(frame)
                dt = time.time() - t0

                total_frames += 1
                total_faces += len(detections)
                total_time += dt

                if callback:
                    callback(frame, detections, total_frames)

                # Draw detections
                annotated = self._draw_detections(frame, detections)

                if writer:
                    writer.write(annotated)

                if show:
                    cv2.imshow('FaceDet', annotated)
                    if cv2.waitKey(1) & 0xFF == ord('q'):
                        break

        finally:
            cap.release()
            if writer:
                writer.release()
            if show:
                cv2.destroyAllWindows()

        avg_fps = total_frames / max(total_time, 1e-6)
        avg_faces = total_faces / max(total_frames, 1)

        stats = {
            'total_frames': total_frames,
            'avg_fps': avg_fps,
            'avg_faces_per_frame': avg_faces,
            'total_time': total_time,
        }
        print(f"[VideoFaceDetector] {total_frames} frames, "
              f"{avg_fps:.1f} FPS, {avg_faces:.1f} faces/frame")
        return stats

    def _preprocess(self, image: np.ndarray):
        """Resize + pad + normalize for model input."""
        h, w = image.shape[:2]
        scale = self.input_size / max(h, w)
        new_h, new_w = int(h * scale), int(w * scale)
        resized = cv2.resize(image, (new_w, new_h))

        # Pad to input_size
        padded = np.zeros((self.input_size, self.input_size, 3), dtype=np.float32)
        padded[:new_h, :new_w] = resized

        # Normalize (mean subtraction)
        padded = padded - self.mean

        # HWC → CHW
        padded = padded.transpose(2, 0, 1)

        pad = (0, 0)  # (pad_x, pad_y) = 0 since we place the image at the top-left
        return padded, scale, pad

    def _infer_pytorch(self, img: np.ndarray):
        """Run PyTorch inference."""
        tensor = torch.from_numpy(img).unsqueeze(0).float().to(self.device)
        results = self.model(tensor, targets=None)
        r = results[0]
        boxes = r['boxes'].cpu().numpy()
        scores = r['scores'].cpu().numpy()
        landmarks = r.get('landmarks', None)
        if landmarks is not None:
            landmarks = landmarks.cpu().numpy()
        return boxes, scores, landmarks

    def _infer_onnx(self, img: np.ndarray):
        """Run ONNX inference."""
        inputs = {self.onnx_session.get_inputs()[0].name: img[np.newaxis].astype(np.float32)}
        outputs = self.onnx_session.run(None, inputs)
        # ONNX output format depends on export — handle common patterns
        if len(outputs) >= 2:
            boxes = outputs[0]
            scores = outputs[1]
            landmarks = outputs[2] if len(outputs) > 2 else None
            return boxes, scores, landmarks
        return np.empty((0, 4)), np.empty(0), None

    def _tracks_to_detections(self, tracks: list) -> List[FaceDetection]:
        """Convert Track objects to FaceDetection objects."""
        return [
            FaceDetection(
                track_id=t.track_id,
                bbox=t.bbox,
                score=t.score,
                is_confirmed=t.is_confirmed,
                landmarks=t.landmarks,
            )
            for t in tracks
        ]

    @staticmethod
    def _draw_detections(frame: np.ndarray, detections: List[FaceDetection]) -> np.ndarray:
        """Draw bounding boxes and track IDs on the frame."""
        annotated = frame.copy()
        for det in detections:
            x1, y1, x2, y2 = det.bbox.astype(int)
            color = (0, 255, 0) if det.is_confirmed else (0, 255, 255)
            cv2.rectangle(annotated, (x1, y1), (x2, y2), color, 2)
            label = f"ID:{det.track_id} {det.score:.2f}"
            cv2.putText(annotated, label, (x1, y1 - 5),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)

            # Draw landmarks
            if det.landmarks is not None and len(det.landmarks) >= 10:
                for i in range(5):
                    x = int(det.landmarks[i * 2])
                    y = int(det.landmarks[i * 2 + 1])
                    if x > 0 and y > 0:
                        cv2.circle(annotated, (x, y), 2, (0, 0, 255), -1)

        return annotated

    def reset(self):
        """Reset tracker and smoother state (for a new video)."""
        if self.tracker:
            self.tracker.reset()
        if self.smoother:
            self.smoother.states.clear()
        self._frame_count = 0
        self._last_detections = []
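
For reference, a minimal driver sketch for the file above. It is not part of the commit; it assumes the `engine` package is importable from the repository root, and the checkpoint path `weights/scrfd_34g.pth` is a hypothetical placeholder:

# Sketch only: 'weights/scrfd_34g.pth' is a placeholder path, and the
# engine package layout is assumed from this commit's file location.
from engine.video_detector import VideoFaceDetector

def on_frame(frame, faces, frame_idx):
    # Per-frame callback: raw frame, List[FaceDetection], 1-based frame index
    for face in faces:
        print(f"frame {frame_idx} track {face.track_id}: "
              f"bbox={face.bbox.astype(int)}, score={face.score:.2f}")

detector = VideoFaceDetector(
    model_path='weights/scrfd_34g.pth',  # hypothetical checkpoint path
    model_name='scrfd_34g',
    device='cuda',
    keyframe_interval=0,  # full detection on every frame
)
stats = detector.process_video('input.mp4',
                               output_path='annotated.mp4',
                               callback=on_frame)
print(stats['avg_fps'], stats['avg_faces_per_frame'])

Passing `0` instead of `'input.mp4'` reads from the default webcam, and an RTSP URL works the same way, since `process_video` hands the source straight to `cv2.VideoCapture`.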
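
The module docstring also lists image directory sequences among the supported inputs; one way to drive those is to call `detect_frame` directly on frames read with `cv2.imread`. A sketch, assuming a folder of ordered JPEG frames and the same hypothetical checkpoint path as above:

# Sketch only: 'frames/' is a placeholder directory of ordered frames.
# reset() clears tracker/smoother state so track IDs from a previous
# sequence do not carry over.
import glob
import cv2

detector = VideoFaceDetector(model_path='weights/scrfd_34g.pth')  # hypothetical path
detector.reset()
for path in sorted(glob.glob('frames/*.jpg')):
    frame = cv2.imread(path)  # BGR, as detect_frame expects
    if frame is None:
        continue
    faces = detector.detect_frame(frame)
    print(path, [(f.track_id, round(f.score, 2)) for f in faces])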