| """ |
| Video Face Detector — End-to-end video inference with tracking and smoothing. |
| |
| Combines: |
| 1. SCRFD detector (per-frame face detection) |
| 2. ByteTrack tracker (cross-frame identity association) |
| 3. Temporal smoother (jitter reduction) |
| 4. Optional keyframe strategy (run full detection every N frames, |
| track-only on intermediate frames for speed) |
| |
| Supports: |
| - Live webcam streams |
| - Video files (MP4, AVI, etc.) |
| - RTSP/RTMP streams |
| - Image sequences via an OpenCV filename pattern (e.g. frames/%06d.jpg) |
| - ONNX runtime for deployment |
| """ |
|
|
| import os |
| import time |
| import numpy as np |
| import cv2 |
| from typing import Optional, Callable, List, Dict, Union |
| from dataclasses import dataclass |
|
|
| import torch |
| import torch.nn.functional as F |
|
|
| from .tracker import ByteTracker, Track |
| from .temporal import TemporalSmoother |
|
|
|
|
@dataclass
class FaceDetection:
    """A single face found in one video frame.

    Produced by ``VideoFaceDetector.detect_frame``, either directly from the
    detector output or from a tracker ``Track`` when tracking is enabled.
    """

    # Tracker identity, or the per-frame detection index when tracking is off.
    track_id: int
    # Box as [x1, y1, x2, y2] in original-frame pixel coordinates.
    bbox: np.ndarray
    # Detection confidence (may be temporally smoothed).
    score: float
    # Optional landmarks, interleaved x/y values (x0, y0, ..., x4, y4).
    landmarks: Optional[np.ndarray] = None
    # False for tentative tracks; drawn in a different colour.
    is_confirmed: bool = True
|
|
|
|
class VideoFaceDetector:
    """
    Production video face detection pipeline.

    Usage:
        detector = VideoFaceDetector(model_path='scrfd_34g.pth', model_name='scrfd_34g')

        # Whole video: per-frame results are delivered through the callback,
        # process_video itself returns aggregate statistics.
        def on_frame(frame, faces, frame_idx):
            for face in faces:
                print(f"Track {face.track_id}: bbox={face.bbox}, score={face.score:.2f}")
        stats = detector.process_video('input.mp4', callback=on_frame)

        # Or frame by frame:
        faces = detector.detect_frame(bgr_frame)

    Args:
        model: SCRFD model instance (or None to build one / use ONNX)
        model_path: Path to checkpoint (.pth) or ONNX model (.onnx)
        model_name: Model variant name for building from scratch
        device: 'cuda' or 'cpu'
        score_threshold: Min detection confidence (passed to build_detector)
        nms_threshold: NMS IoU threshold (passed to build_detector)
        input_size: Side length of the square model input, in pixels
        use_tracking: Enable ByteTrack temporal tracking
        use_smoothing: Enable EMA temporal smoothing
        keyframe_interval: Run full detection every N frames (0 or 1 = every frame)
    """

    def __init__(self,
                 model=None,
                 model_path: Optional[str] = None,
                 model_name: str = 'scrfd_34g',
                 device: str = 'cuda',
                 score_threshold: float = 0.3,
                 nms_threshold: float = 0.4,
                 input_size: int = 640,
                 use_tracking: bool = True,
                 use_smoothing: bool = True,
                 keyframe_interval: int = 0):
        self.device = device
        self.input_size = input_size
        self.score_threshold = score_threshold
        self.use_tracking = use_tracking
        self.use_smoothing = use_smoothing
        self.keyframe_interval = keyframe_interval
        # Per-channel mean subtracted in _preprocess, after BGR->RGB conversion.
        # NOTE(review): 104/117/123 is the classic BGR-ordered mean; confirm the
        # order matches how the model was trained.
        self.mean = np.array([104.0, 117.0, 123.0], dtype=np.float32)

        # Backend selection: explicit model > ONNX file > build + optional checkpoint.
        self.onnx_session = None
        if model is not None:
            self.model = model
        elif model_path and model_path.endswith('.onnx'):
            self._load_onnx(model_path)
            self.model = None
        else:
            from models.detector import build_detector
            self.model = build_detector(
                model_name,
                score_threshold=score_threshold,
                nms_threshold=nms_threshold,
            )
            if model_path:
                # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
                checkpoint = torch.load(model_path, map_location='cpu')
                state_dict = checkpoint.get('model_state_dict', checkpoint)
                incompatible = self.model.load_state_dict(state_dict, strict=False)
                # strict=False silently ignores mismatches; surface them instead.
                missing = getattr(incompatible, 'missing_keys', None) or []
                unexpected = getattr(incompatible, 'unexpected_keys', None) or []
                if missing or unexpected:
                    print(f"[VideoFaceDetector] checkpoint mismatch: "
                          f"{len(missing)} missing, {len(unexpected)} unexpected keys")

        # The ONNX backend has no torch module to move or switch to eval mode.
        if self.model is not None:
            self.model.to(device)
            self.model.eval()

        self.tracker = ByteTracker() if use_tracking else None
        self.smoother = TemporalSmoother() if use_smoothing else None

        self._frame_count = 0        # 1-based index of the last frame seen
        self._last_detections = []   # reused on skipped frames when tracking is off

    def _load_onnx(self, model_path: str):
        """Create an onnxruntime session for model_path.

        Raises:
            ImportError: if onnxruntime is not installed.
        """
        try:
            import onnxruntime as ort
        except ImportError as err:
            raise ImportError(
                "onnxruntime required for ONNX inference: pip install onnxruntime-gpu"
            ) from err
        # Accept 'cpu' as well as torch-style device objects/strings.
        if str(self.device).startswith('cpu'):
            providers = ['CPUExecutionProvider']
        else:
            providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        self.onnx_session = ort.InferenceSession(model_path, providers=providers)

    def _is_keyframe(self, frame_idx: int) -> bool:
        """Return True if 1-based frame_idx should run full detection.

        Frames 1, 1 + N, 1 + 2N, ... are keyframes; an interval of 0 or 1
        means every frame is a keyframe.
        """
        if self.keyframe_interval <= 1:
            return True
        return (frame_idx - 1) % self.keyframe_interval == 0

    @staticmethod
    def _to_rgb(frame: np.ndarray) -> np.ndarray:
        """Convert an OpenCV frame (BGR, BGRA or grayscale) to 3-channel RGB."""
        if frame.ndim == 2 or frame.shape[2] == 1:
            return cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
        if frame.shape[2] == 3:
            return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        if frame.shape[2] == 4:
            return cv2.cvtColor(frame, cv2.COLOR_BGRA2RGB)
        return frame

    @staticmethod
    def _unletterbox(boxes: np.ndarray, landmarks: Optional[np.ndarray],
                     scale: float, pad, width: int, height: int):
        """Map boxes/landmarks from model-input pixels back to the original frame.

        Modifies both arrays in place and also returns them. Boxes are clipped
        to the frame; landmarks are not. Only the first 10 landmark columns
        (5 interleaved x/y points) are transformed.
        """
        if len(boxes) > 0:
            boxes[:, [0, 2]] = (boxes[:, [0, 2]] - pad[0]) / scale
            boxes[:, [1, 3]] = (boxes[:, [1, 3]] - pad[1]) / scale
            boxes[:, [0, 2]] = np.clip(boxes[:, [0, 2]], 0, width)
            boxes[:, [1, 3]] = np.clip(boxes[:, [1, 3]], 0, height)

        if landmarks is not None and len(landmarks) > 0:
            landmarks[:, 0:10:2] = (landmarks[:, 0:10:2] - pad[0]) / scale
            landmarks[:, 1:10:2] = (landmarks[:, 1:10:2] - pad[1]) / scale
        return boxes, landmarks

    @torch.no_grad()
    def detect_frame(self, frame: np.ndarray) -> List['FaceDetection']:
        """
        Detect faces in a single frame.

        Args:
            frame: OpenCV image. 3-channel input is treated as BGR; 4-channel
                as BGRA; 2-D or 1-channel input as grayscale.

        Returns:
            List of FaceDetection objects. On non-keyframes (keyframe_interval
            > 1) the tracker is updated with no detections, or, without a
            tracker, the previous keyframe's list is returned unchanged.
        """
        self._frame_count += 1

        if not self._is_keyframe(self._frame_count):
            if self.tracker:
                # Track-only frame; what the tracker returns for an empty update
                # depends on ByteTracker (presumably predicted track positions).
                tracks = self.tracker.update(np.empty((0, 4)), np.empty(0), None)
                return self._tracks_to_detections(tracks)
            return self._last_detections

        rgb = self._to_rgb(frame)
        h_orig, w_orig = rgb.shape[:2]
        img, scale, pad = self._preprocess(rgb)

        if self.onnx_session is not None:
            boxes, scores, landmarks = self._infer_onnx(img)
        else:
            boxes, scores, landmarks = self._infer_pytorch(img)

        boxes, landmarks = self._unletterbox(boxes, landmarks, scale, pad, w_orig, h_orig)

        if self.use_tracking and self.tracker:
            tracks = self.tracker.update(boxes, scores, landmarks)
            detections = self._tracks_to_detections(tracks)
        else:
            # Without tracking, the detection index stands in for the track id.
            detections = [
                FaceDetection(
                    track_id=i,
                    bbox=box,
                    score=scores[i],
                    landmarks=landmarks[i] if landmarks is not None else None,
                )
                for i, box in enumerate(boxes)
            ]

        if self.use_smoothing and self.smoother:
            active_ids = set()
            for det in detections:
                det.bbox, det.score = self.smoother.smooth(det.track_id, det.bbox, det.score)
                active_ids.add(det.track_id)
            # Drop smoother state for ids not present in this frame.
            self.smoother.cleanup(active_ids)

        self._last_detections = detections
        return detections

    def process_video(self, source: Union[str, int],
                      callback: Optional[Callable] = None,
                      max_frames: int = -1,
                      output_path: Optional[str] = None,
                      show: bool = False) -> Dict:
        """
        Process a video file or stream.

        Tracker/smoother state is not reset here; call reset() between
        unrelated videos.

        Args:
            source: Video file path, webcam index (0), or RTSP URL
            callback: Optional per-frame callback(frame, detections, frame_idx),
                with frame_idx 1-based
            max_frames: Max frames to process (<= 0 for all)
            output_path: Save annotated video to this path (mp4v codec)
            show: Display annotated frames in a window; 'q' stops early

        Returns:
            dict with 'total_frames', 'avg_fps', 'avg_faces_per_frame', 'total_time'
            (fps and time cover detect_frame only).

        Raises:
            IOError: if the source or the output writer cannot be opened.
        """
        cap = cv2.VideoCapture(source)
        if not cap.isOpened():
            raise IOError(f"Cannot open video source: {source}")

        # Some streams report 0 FPS; fall back to 30 for the writer.
        fps = cap.get(cv2.CAP_PROP_FPS) or 30
        writer = None
        # Annotation is only needed when something consumes it.
        need_annotation = bool(output_path) or show

        total_frames = 0
        total_faces = 0
        total_time = 0.0

        try:
            while max_frames <= 0 or total_frames < max_frames:
                ret, frame = cap.read()
                if not ret:
                    break

                t0 = time.perf_counter()
                detections = self.detect_frame(frame)
                total_time += time.perf_counter() - t0

                total_frames += 1
                total_faces += len(detections)

                if callback:
                    callback(frame, detections, total_frames)

                if not need_annotation:
                    continue
                annotated = self._draw_detections(frame, detections)

                if output_path:
                    if writer is None:
                        # Size the writer from real frames: capture properties
                        # can report 0x0 for streams.
                        h, w = annotated.shape[:2]
                        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                        writer = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
                        if not writer.isOpened():
                            raise IOError(f"Cannot open video writer: {output_path}")
                    writer.write(annotated)

                if show:
                    cv2.imshow('FaceDet', annotated)
                    if cv2.waitKey(1) & 0xFF == ord('q'):
                        break
        finally:
            cap.release()
            if writer:
                writer.release()
            if show:
                cv2.destroyAllWindows()

        avg_fps = total_frames / max(total_time, 1e-6)
        avg_faces = total_faces / max(total_frames, 1)

        stats = {
            'total_frames': total_frames,
            'avg_fps': avg_fps,
            'avg_faces_per_frame': avg_faces,
            'total_time': total_time,
        }
        print(f"[VideoFaceDetector] {total_frames} frames, "
              f"{avg_fps:.1f} FPS, {avg_faces:.1f} faces/frame")
        return stats

    def _preprocess(self, image: np.ndarray):
        """Letterbox an RGB image into a square, mean-subtracted CHW float32 array.

        The image is resized so its longer side equals input_size and placed at
        the top-left of a zero canvas before the mean is subtracted, so the pad
        offset is always (0, 0).

        Returns:
            (chw_image, scale, pad) where scale maps original -> model pixels.
        """
        h, w = image.shape[:2]
        scale = self.input_size / max(h, w)
        # Clamp so extreme aspect ratios never produce a zero-sized side.
        new_h, new_w = max(1, int(h * scale)), max(1, int(w * scale))
        resized = cv2.resize(image, (new_w, new_h))

        padded = np.zeros((self.input_size, self.input_size, 3), dtype=np.float32)
        padded[:new_h, :new_w] = resized
        padded = padded - self.mean

        # HWC -> CHW for the network.
        return padded.transpose(2, 0, 1), scale, (0, 0)

    def _infer_pytorch(self, img: np.ndarray):
        """Run the PyTorch model on one CHW image; return numpy (boxes, scores, landmarks)."""
        tensor = torch.from_numpy(img).unsqueeze(0).float().to(self.device)
        r = self.model(tensor, targets=None)[0]
        boxes = r['boxes'].cpu().numpy()
        scores = r['scores'].cpu().numpy()
        landmarks = r.get('landmarks', None)
        if landmarks is not None:
            landmarks = landmarks.cpu().numpy()
        return boxes, scores, landmarks

    def _infer_onnx(self, img: np.ndarray):
        """Run the ONNX session on one CHW image.

        Assumes the exported graph's outputs are ordered (boxes, scores[,
        landmarks]) -- TODO confirm against the export script. Returns empty
        results if fewer than two outputs are produced.
        """
        inputs = {self.onnx_session.get_inputs()[0].name: img[np.newaxis].astype(np.float32)}
        outputs = self.onnx_session.run(None, inputs)
        if len(outputs) >= 2:
            landmarks = outputs[2] if len(outputs) > 2 else None
            return outputs[0], outputs[1], landmarks
        return np.empty((0, 4)), np.empty(0), None

    def _tracks_to_detections(self, tracks: list) -> List['FaceDetection']:
        """Convert tracker Track objects to FaceDetection objects."""
        return [
            FaceDetection(
                track_id=t.track_id,
                bbox=t.bbox,
                score=t.score,
                is_confirmed=t.is_confirmed,
                landmarks=t.landmarks,
            )
            for t in tracks
        ]

    @staticmethod
    def _draw_detections(frame: np.ndarray, detections: List['FaceDetection']) -> np.ndarray:
        """Return a copy of frame with boxes, track IDs/scores and landmarks drawn."""
        annotated = frame.copy()
        for det in detections:
            # Plain Python ints keep older OpenCV bindings happy.
            x1, y1, x2, y2 = map(int, det.bbox.astype(int))
            # Green for confirmed tracks, yellow (BGR) for tentative ones.
            color = (0, 255, 0) if det.is_confirmed else (0, 255, 255)
            cv2.rectangle(annotated, (x1, y1), (x2, y2), color, 2)
            label = f"ID:{det.track_id} {det.score:.2f}"
            cv2.putText(annotated, label, (x1, y1 - 5),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)

            if det.landmarks is not None and len(det.landmarks) >= 10:
                for i in range(5):
                    x = int(det.landmarks[i * 2])
                    y = int(det.landmarks[i * 2 + 1])
                    # Non-positive coordinates are skipped (not drawn).
                    if x > 0 and y > 0:
                        cv2.circle(annotated, (x, y), 2, (0, 0, 255), -1)
        return annotated

    def reset(self):
        """Reset tracker, smoother and frame-counter state (for a new video)."""
        if self.tracker:
            self.tracker.reset()
        if self.smoother:
            self.smoother.states.clear()
        self._frame_count = 0
        self._last_detections = []
|
|