File size: 3,586 Bytes
307c3fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
#!/usr/bin/env python3
"""
Video face detection with tracking and temporal smoothing.

Usage:
    # Process video file
    python scripts/detect_video.py \\
        --model scrfd_34g \\
        --checkpoint checkpoints/scrfd_34g_best.pth \\
        --input video.mp4 \\
        --output output.mp4

    # Webcam (real-time)
    python scripts/detect_video.py \\
        --model scrfd_2.5g \\
        --checkpoint checkpoints/scrfd_2.5g_best.pth \\
        --input 0 \\
        --show

    # RTSP stream
    python scripts/detect_video.py \\
        --model scrfd_2.5g \\
        --checkpoint checkpoints/scrfd_2.5g_best.pth \\
        --input rtsp://192.168.1.100/stream \\
        --output output.mp4
"""

import os
import sys
import argparse
from pathlib import Path

# Prepend the directory one level above this script's folder (the repository
# root, given the `scripts/detect_video.py` usage shown above) to the import
# path, so the project-local `models` and `engine` packages imported below
# resolve when the script is run directly.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from models.detector import build_detector
from engine.video_detector import VideoFaceDetector


def parse_args(argv=None):
    """Parse command-line options for video face detection.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) ``sys.argv[1:]`` is used, exactly as before.

    Returns:
        argparse.Namespace holding the parsed options.

    Exits via ``parser.error`` (status 2) when neither ``--checkpoint``
    nor ``--onnx`` is supplied, since no weights would be available.
    """
    parser = argparse.ArgumentParser(description='Video Face Detection')
    parser.add_argument('--model', type=str, default='scrfd_34g')
    # Only the PyTorch path reads the checkpoint; with --onnx the ONNX file
    # is used as the model path instead, so the checkpoint must not be
    # mandatory. The "one of the two" rule is enforced after parsing.
    parser.add_argument('--checkpoint', type=str, default=None,
                        help='PyTorch checkpoint (required unless --onnx is given)')
    parser.add_argument('--input', type=str, required=True,
                        help='Video file, webcam index (0), or RTSP URL')
    parser.add_argument('--output', type=str, default=None,
                        help='Output video path')
    parser.add_argument('--input-size', type=int, default=640)
    parser.add_argument('--score-thresh', type=float, default=0.4)
    parser.add_argument('--nms-thresh', type=float, default=0.4)
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--no-tracking', action='store_true')
    parser.add_argument('--no-smoothing', action='store_true')
    parser.add_argument('--keyframe-interval', type=int, default=0,
                        help='Run detector every N frames (0=every frame)')
    parser.add_argument('--max-frames', type=int, default=-1)
    parser.add_argument('--show', action='store_true',
                        help='Display video in window')
    parser.add_argument('--onnx', type=str, default=None,
                        help='Use ONNX model instead of PyTorch')

    args = parser.parse_args(argv)

    # Mirror main()'s truthiness test on --onnx so the two always agree on
    # which backend (and therefore which weights file) will be used.
    if not args.onnx and not args.checkpoint:
        parser.error('--checkpoint is required unless --onnx is given')
    return args


def main():
    args = parse_args()

    # Build model
    if args.onnx:
        model = None
        model_path = args.onnx
    else:
        model = build_detector(
            args.model,
            score_threshold=args.score_thresh,
            nms_threshold=args.nms_thresh,
        )
        model_path = args.checkpoint

    # Build video detector
    detector = VideoFaceDetector(
        model=model,
        model_path=model_path,
        model_name=args.model,
        device=args.device,
        score_threshold=args.score_thresh,
        nms_threshold=args.nms_thresh,
        input_size=args.input_size,
        use_tracking=not args.no_tracking,
        use_smoothing=not args.no_smoothing,
        keyframe_interval=args.keyframe_interval,
    )

    # Process video
    source = int(args.input) if args.input.isdigit() else args.input
    stats = detector.process_video(
        source=source,
        max_frames=args.max_frames,
        output_path=args.output,
        show=args.show,
    )

    print(f"\nProcessing complete:")
    print(f"  Frames: {stats['total_frames']}")
    print(f"  Average FPS: {stats['avg_fps']:.1f}")
    print(f"  Average faces/frame: {stats['avg_faces_per_frame']:.1f}")
    if args.output:
        print(f"  Output saved to: {args.output}")


# Entry point: run main() only when this file is executed as a script, not
# when it is imported as a module.
if __name__ == '__main__':
    main()