File size: 13,295 Bytes
afda97e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
"""
Video Face Detector — End-to-end video inference with tracking and smoothing.

Combines:
1. SCRFD detector (per-frame face detection)
2. ByteTrack tracker (cross-frame identity association)
3. Temporal smoother (jitter reduction)
4. Optional keyframe strategy (run full detection every N frames,
   track-only on intermediate frames for speed)

Supports:
- Live webcam streams
- Video files (MP4, AVI, etc.)
- RTSP/RTMP streams
- Image directory sequences
- ONNX runtime for deployment
"""

import os
import time
import numpy as np
import cv2
from typing import Optional, Callable, List, Dict, Union
from dataclasses import dataclass

import torch
import torch.nn.functional as F

from .tracker import ByteTracker, Track
from .temporal import TemporalSmoother


@dataclass
class FaceDetection:
    """Single face detection result, in original-image pixel coordinates."""
    # Stable identity assigned by the tracker; when tracking is disabled,
    # this is just the detection's per-frame index.
    track_id: int
    bbox: np.ndarray          # [x1, y1, x2, y2]
    # Detection confidence (possibly EMA-smoothed downstream).
    score: float
    landmarks: Optional[np.ndarray] = None  # [10] = 5 x (x, y)
    # False while the underlying track is still tentative (drawn in a
    # different color by the visualizer).
    is_confirmed: bool = True

class VideoFaceDetector:
    """
    Production video face detection pipeline.

    Combines a per-frame SCRFD detector with optional ByteTrack identity
    tracking and EMA temporal smoothing. A keyframe strategy can skip the
    full detector on intermediate frames (tracker prediction only) for speed.

    Usage:
        detector = VideoFaceDetector(model_path='scrfd_34g.pth', model_name='scrfd_34g')

        def on_frame(frame, faces, frame_idx):
            for face in faces:
                print(f"Track {face.track_id}: bbox={face.bbox}, score={face.score:.2f}")

        stats = detector.process_video('input.mp4', callback=on_frame)

    Args:
        model: SCRFD model instance (or None to build/load from model_path)
        model_path: Path to checkpoint (.pth) or ONNX model (.onnx)
        model_name: Model variant name for building from scratch
        device: 'cuda' or 'cpu'
        score_threshold: Min detection confidence
        nms_threshold: NMS IoU threshold
        input_size: Model input resolution (square, pixels)
        use_tracking: Enable ByteTrack temporal tracking
        use_smoothing: Enable EMA temporal smoothing
        keyframe_interval: Run full detection every N frames
            (0 or 1 = detect on every frame)
    """

    def __init__(self,
                 model=None,
                 model_path: Optional[str] = None,
                 model_name: str = 'scrfd_34g',
                 device: str = 'cuda',
                 score_threshold: float = 0.3,
                 nms_threshold: float = 0.4,
                 input_size: int = 640,
                 use_tracking: bool = True,
                 use_smoothing: bool = True,
                 keyframe_interval: int = 0):

        self.device = device
        self.input_size = input_size
        self.score_threshold = score_threshold
        self.use_tracking = use_tracking
        self.use_smoothing = use_smoothing
        self.keyframe_interval = keyframe_interval
        # Per-channel mean for input normalization.
        # NOTE(review): these are the classic Caffe BGR means, but they are
        # subtracted AFTER the BGR->RGB conversion in detect_frame — confirm
        # this ordering matches the model's training preprocessing.
        self.mean = np.array([104.0, 117.0, 123.0], dtype=np.float32)

        # Resolve backend: explicit model > ONNX file > build from scratch.
        self.onnx_session = None
        if model is not None:
            self.model = model
        elif model_path and model_path.endswith('.onnx'):
            self._load_onnx(model_path)
            self.model = None
        else:
            from models.detector import build_detector
            self.model = build_detector(
                model_name,
                score_threshold=score_threshold,
                nms_threshold=nms_threshold,
            )
            if model_path:
                checkpoint = torch.load(model_path, map_location='cpu')
                # Accept both raw state dicts and full training checkpoints.
                state_dict = checkpoint.get('model_state_dict', checkpoint)
                self.model.load_state_dict(state_dict, strict=False)

            self.model.to(device)
            self.model.eval()

        # Temporal components (None when the corresponding feature is off).
        self.tracker = ByteTracker() if use_tracking else None
        self.smoother = TemporalSmoother() if use_smoothing else None

        self._frame_count = 0        # frames seen since construction/reset()
        self._last_detections = []   # fallback result for skipped frames

    def _load_onnx(self, model_path: str):
        """Load an ONNX model for deployment inference.

        Prefers CUDA when self.device is not 'cpu'; onnxruntime silently
        falls back to CPU if the CUDA provider is unavailable.

        Raises:
            ImportError: if onnxruntime is not installed.
        """
        try:
            import onnxruntime as ort
            providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
            if self.device == 'cpu':
                providers = ['CPUExecutionProvider']
            self.onnx_session = ort.InferenceSession(model_path, providers=providers)
        except ImportError:
            raise ImportError("onnxruntime required for ONNX inference: pip install onnxruntime-gpu")

    @torch.no_grad()
    def detect_frame(self, frame: np.ndarray) -> List[FaceDetection]:
        """
        Detect faces in a single frame.

        Args:
            frame: HxWx3 BGR image (OpenCV format) or HxW grayscale array.

        Returns:
            List of FaceDetection objects in original-image coordinates.
        """
        self._frame_count += 1

        # Keyframe strategy: on non-keyframes run only the tracker's motion
        # prediction and skip the (expensive) detector. Frame 1 is always a
        # keyframe; an interval of 0 or 1 means detect on every frame.
        # (The previous `frame_count % interval != 1` test wrongly skipped
        # detection on EVERY frame after the first when interval == 1,
        # because x % 1 is always 0.)
        if (self.keyframe_interval > 1 and
                (self._frame_count - 1) % self.keyframe_interval != 0):
            if self.tracker:
                tracks = self.tracker.update(
                    np.empty((0, 4)), np.empty(0), None
                )
                return self._tracks_to_detections(tracks)
            return self._last_detections

        # Color handling. NOTE(review): a channel-count check cannot tell
        # BGR from RGB; callers passing RGB will get channels swapped —
        # confirm the expected input format with callers.
        if frame.ndim == 2:
            # Grayscale input: replicate to 3 channels so the 3-channel
            # preprocessing (mean subtraction, padding) works.
            rgb = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
        elif frame.shape[2] == 3:
            rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        else:
            rgb = frame
        h_orig, w_orig = rgb.shape[:2]
        img, scale, pad = self._preprocess(rgb)

        # Run detection on the selected backend.
        if self.onnx_session:
            boxes, scores, landmarks = self._infer_onnx(img)
        else:
            boxes, scores, landmarks = self._infer_pytorch(img)

        # Undo the preprocessing transform: shift out padding, divide by the
        # resize scale, then clamp to the original image bounds.
        if len(boxes) > 0:
            boxes[:, [0, 2]] = (boxes[:, [0, 2]] - pad[0]) / scale
            boxes[:, [1, 3]] = (boxes[:, [1, 3]] - pad[1]) / scale
            boxes[:, [0, 2]] = np.clip(boxes[:, [0, 2]], 0, w_orig)
            boxes[:, [1, 3]] = np.clip(boxes[:, [1, 3]], 0, h_orig)

            if landmarks is not None and len(landmarks) > 0:
                for i in range(5):
                    landmarks[:, i*2] = (landmarks[:, i*2] - pad[0]) / scale
                    landmarks[:, i*2+1] = (landmarks[:, i*2+1] - pad[1]) / scale

        # Cross-frame identity association.
        if self.use_tracking and self.tracker:
            tracks = self.tracker.update(boxes, scores, landmarks)
            detections = self._tracks_to_detections(tracks)
        else:
            # Without tracking, track_id is just the per-frame index.
            detections = [
                FaceDetection(
                    track_id=i,
                    bbox=boxes[i],
                    score=scores[i],
                    landmarks=landmarks[i] if landmarks is not None else None,
                )
                for i in range(len(boxes))
            ]

        # EMA smoothing to reduce bbox/score jitter; prune state for tracks
        # that disappeared this frame.
        if self.use_smoothing and self.smoother:
            active_ids = set()
            for det in detections:
                det.bbox, det.score = self.smoother.smooth(
                    det.track_id, det.bbox, det.score
                )
                active_ids.add(det.track_id)
            self.smoother.cleanup(active_ids)

        self._last_detections = detections
        return detections

    def process_video(self, source: Union[str, int],
                      callback: Optional[Callable] = None,
                      max_frames: int = -1,
                      output_path: Optional[str] = None,
                      show: bool = False) -> Dict:
        """
        Process a video file or stream to completion.

        Call reset() first when reusing this detector on a new, unrelated
        video so tracker/smoother state does not leak across sources.

        Args:
            source: Video file path, webcam index (0), or RTSP URL
            callback: Optional per-frame callback(frame, detections, frame_idx)
            max_frames: Max frames to process (-1 for all)
            output_path: Save annotated video to this path
            show: Display annotated frames in a window ('q' quits)

        Returns:
            dict with 'total_frames', 'avg_fps', 'avg_faces_per_frame',
            'total_time'

        Raises:
            IOError: if the video source cannot be opened.
        """
        cap = cv2.VideoCapture(source)
        if not cap.isOpened():
            raise IOError(f"Cannot open video source: {source}")

        # Some backends report 0 FPS (e.g. live streams); default to 30.
        fps = cap.get(cv2.CAP_PROP_FPS) or 30
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        writer = None
        if output_path:
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

        total_frames = 0
        total_faces = 0
        total_time = 0

        try:
            while True:
                ret, frame = cap.read()
                if not ret or (max_frames > 0 and total_frames >= max_frames):
                    break

                # Time only the detection step, not drawing/writing, so
                # avg_fps reflects model throughput.
                t0 = time.time()
                detections = self.detect_frame(frame)
                dt = time.time() - t0

                total_frames += 1
                total_faces += len(detections)
                total_time += dt

                if callback:
                    callback(frame, detections, total_frames)

                annotated = self._draw_detections(frame, detections)

                if writer:
                    writer.write(annotated)

                if show:
                    cv2.imshow('FaceDet', annotated)
                    if cv2.waitKey(1) & 0xFF == ord('q'):
                        break

        finally:
            # Always release OS resources, even on callback exceptions.
            cap.release()
            if writer:
                writer.release()
            if show:
                cv2.destroyAllWindows()

        avg_fps = total_frames / max(total_time, 1e-6)
        avg_faces = total_faces / max(total_frames, 1)

        stats = {
            'total_frames': total_frames,
            'avg_fps': avg_fps,
            'avg_faces_per_frame': avg_faces,
            'total_time': total_time,
        }
        print(f"[VideoFaceDetector] {total_frames} frames, "
              f"{avg_fps:.1f} FPS, {avg_faces:.1f} faces/frame")
        return stats

    def _preprocess(self, image: np.ndarray):
        """Resize (aspect-preserving) + pad + normalize for model input.

        Returns:
            (chw_array, scale, pad): CHW float32 input, the resize scale
            applied, and the (pad_x, pad_y) offset of the image within the
            padded canvas.
        """
        h, w = image.shape[:2]
        scale = self.input_size / max(h, w)
        new_h, new_w = int(h * scale), int(w * scale)
        resized = cv2.resize(image, (new_w, new_h))

        # Place the resized image at the top-left of a square black canvas.
        padded = np.zeros((self.input_size, self.input_size, 3), dtype=np.float32)
        padded[:new_h, :new_w] = resized

        # Mean subtraction (see NOTE in __init__ about channel order).
        padded = padded - self.mean

        # HWC → CHW for the model.
        padded = padded.transpose(2, 0, 1)

        pad = (0, 0)  # (pad_x, pad_y) = 0 since we place image at top-left
        return padded, scale, pad

    def _infer_pytorch(self, img: np.ndarray):
        """Run PyTorch inference on one preprocessed CHW image.

        Returns:
            (boxes, scores, landmarks) as numpy arrays; landmarks may be None.
        """
        tensor = torch.from_numpy(img).unsqueeze(0).float().to(self.device)
        results = self.model(tensor, targets=None)
        r = results[0]
        boxes = r['boxes'].cpu().numpy()
        scores = r['scores'].cpu().numpy()
        landmarks = r.get('landmarks', None)
        if landmarks is not None:
            landmarks = landmarks.cpu().numpy()
        return boxes, scores, landmarks

    def _infer_onnx(self, img: np.ndarray):
        """Run ONNX inference on one preprocessed CHW image.

        Returns:
            (boxes, scores, landmarks); empty arrays when the exported
            model's output layout is unrecognized.
        """
        inputs = {self.onnx_session.get_inputs()[0].name: img[np.newaxis].astype(np.float32)}
        outputs = self.onnx_session.run(None, inputs)
        # ONNX output format depends on export — handle common patterns
        if len(outputs) >= 2:
            boxes = outputs[0]
            scores = outputs[1]
            landmarks = outputs[2] if len(outputs) > 2 else None
            return boxes, scores, landmarks
        return np.empty((0, 4)), np.empty(0), None

    def _tracks_to_detections(self, tracks: list) -> List[FaceDetection]:
        """Convert tracker Track objects to FaceDetection objects."""
        return [
            FaceDetection(
                track_id=t.track_id,
                bbox=t.bbox,
                score=t.score,
                is_confirmed=t.is_confirmed,
                landmarks=t.landmarks,
            )
            for t in tracks
        ]

    @staticmethod
    def _draw_detections(frame: np.ndarray, detections: List[FaceDetection]) -> np.ndarray:
        """Return a copy of frame with boxes, track IDs, and landmarks drawn.

        Confirmed tracks are green; tentative ones yellow.
        """
        annotated = frame.copy()
        for det in detections:
            x1, y1, x2, y2 = det.bbox.astype(int)
            color = (0, 255, 0) if det.is_confirmed else (0, 255, 255)
            cv2.rectangle(annotated, (x1, y1), (x2, y2), color, 2)
            label = f"ID:{det.track_id} {det.score:.2f}"
            cv2.putText(annotated, label, (x1, y1 - 5),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)

            # Landmarks are 5 (x, y) pairs; skip non-positive coordinates,
            # which mark missing/invalid points.
            if det.landmarks is not None and len(det.landmarks) >= 10:
                for i in range(5):
                    x = int(det.landmarks[i * 2])
                    y = int(det.landmarks[i * 2 + 1])
                    if x > 0 and y > 0:
                        cv2.circle(annotated, (x, y), 2, (0, 0, 255), -1)

        return annotated

    def reset(self):
        """Reset tracker and smoother state (call before a new video)."""
        if self.tracker:
            self.tracker.reset()
        if self.smoother:
            # NOTE(review): reaches into smoother internals; a
            # TemporalSmoother.reset() API would be cleaner — confirm none exists.
            self.smoother.states.clear()
        self._frame_count = 0
        self._last_detections = []