SID2000 committed on
Commit
488e9e1
·
verified ·
1 Parent(s): 8965444

Upload live_capture.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. live_capture.py +208 -0
live_capture.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Media capture sources for live brain prediction.
2
+
3
+ Provides webcam, screen capture, and file streaming sources that
4
+ yield frames at a controlled rate for real-time inference.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import time
10
+ import threading
11
+ import logging
12
+ from pathlib import Path
13
+ from collections import deque
14
+ from dataclasses import dataclass
15
+
16
+ import numpy as np
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+ @dataclass
22
+ class MediaFrame:
23
+ """A single frame from any media source."""
24
+ video_frame: np.ndarray | None = None # (H, W, 3) RGB
25
+ audio_chunk: np.ndarray | None = None # (samples,) float32
26
+ timestamp: float = 0.0
27
+
28
+
29
class BaseCapture:
    """Base class for media capture sources.

    A subclass implements ``_capture_loop``; ``start`` runs that loop on a
    daemon thread. Frames are kept in a bounded buffer (the newest 300),
    and every access to the buffer through this class holds ``_lock``.
    """

    def __init__(self, fps: float = 1.0):
        """Set up buffer, lock and thread state.

        Args:
            fps: Target number of frames per second. Must be positive.

        Raises:
            ValueError: If ``fps`` is zero, negative or NaN.
        """
        # Reject bad rates up front: subclasses compute ``1.0 / fps`` inside
        # the capture thread, where a ZeroDivisionError would go unnoticed.
        # ``not fps > 0`` also rejects NaN.
        if not fps > 0:
            raise ValueError(f"fps must be positive, got {fps!r}")
        self.fps = fps
        self._running = False
        self._buffer: deque[MediaFrame] = deque(maxlen=300)
        self._thread: threading.Thread | None = None
        self._lock = threading.Lock()

    def start(self):
        """Start the capture thread; a no-op if one is still alive."""
        # A second thread would append to the same buffer and open the same
        # device, so never run two loops at once.
        if self._thread is not None and self._thread.is_alive():
            return
        self._running = True
        self._thread = threading.Thread(target=self._capture_loop, daemon=True)
        self._thread.start()

    def stop(self):
        """Ask the capture loop to exit and wait up to 3 seconds for it."""
        self._running = False
        thread = self._thread
        # Joining the current thread raises RuntimeError, so skip the join
        # when stop() is invoked from inside the capture loop itself.
        if thread is not None and thread is not threading.current_thread():
            thread.join(timeout=3.0)

    def get_latest_frame(self) -> MediaFrame | None:
        """Return the newest buffered frame, or None if the buffer is empty."""
        with self._lock:
            return self._buffer[-1] if self._buffer else None

    def get_all_frames(self) -> list[MediaFrame]:
        """Return a snapshot list of all buffered frames, oldest first."""
        with self._lock:
            return list(self._buffer)

    @property
    def is_running(self) -> bool:
        """Whether the capture loop is currently flagged as running."""
        return self._running

    @property
    def frame_count(self) -> int:
        """Number of frames currently held in the buffer."""
        # Take the lock like the other accessors for a consistent view.
        with self._lock:
            return len(self._buffer)

    def _capture_loop(self):
        """Produce frames until ``_running`` is cleared (subclass hook)."""
        raise NotImplementedError
68
+
69
+
70
class WebcamCapture(BaseCapture):
    """Capture frames from a webcam using OpenCV."""

    def __init__(self, camera_index: int = 0, fps: float = 1.0, resolution: tuple = (640, 480)):
        """Store the camera settings.

        Args:
            camera_index: Device index handed to ``cv2.VideoCapture``.
            fps: Target number of frames per second.
            resolution: Requested ``(width, height)`` of the capture.
        """
        super().__init__(fps)
        self.camera_index = camera_index
        self.resolution = resolution

    def _capture_loop(self):
        """Read frames from the camera until stopped or a read fails.

        Each frame is converted from BGR to RGB and appended to the shared
        buffer with a timestamp in seconds since the loop started. Every
        exit caused by a failure clears ``_running`` so ``is_running``
        reflects that capture has ended.
        """
        try:
            import cv2
        except ImportError:
            logger.error("OpenCV not installed. Run: pip install opencv-python")
            self._running = False
            return

        # Validate before touching the device so a bad rate cannot leave the
        # camera opened and unreleased.
        if not self.fps > 0:
            logger.error("fps must be positive, got %r", self.fps)
            self._running = False
            return
        interval = 1.0 / self.fps

        cap = cv2.VideoCapture(self.camera_index)
        if not cap.isOpened():
            logger.error("Cannot open camera %s", self.camera_index)
            self._running = False
            return

        try:
            cap.set(cv2.CAP_PROP_FRAME_WIDTH, self.resolution[0])
            cap.set(cv2.CAP_PROP_FRAME_HEIGHT, self.resolution[1])
            start_time = time.time()
            next_due = time.monotonic()

            while self._running:
                ret, frame = cap.read()
                if not ret:
                    # Same convention as FileStreamer: a dead source clears
                    # the running flag instead of leaving it stale.
                    logger.warning("Camera %s stopped delivering frames", self.camera_index)
                    self._running = False
                    break
                # BGR -> RGB
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                media_frame = MediaFrame(
                    video_frame=frame_rgb,
                    timestamp=time.time() - start_time,
                )
                with self._lock:
                    self._buffer.append(media_frame)

                # Deadline-based pacing: subtract the time spent capturing so
                # the effective rate does not drift below the target. If we
                # fell behind, restart the schedule from now (no burst).
                next_due = max(next_due + interval, time.monotonic())
                # Sleep in short slices so stop() is honoured promptly even
                # when the interval is longer than its join timeout.
                while self._running:
                    remaining = next_due - time.monotonic()
                    if remaining <= 0:
                        break
                    time.sleep(min(remaining, 0.1))
        finally:
            cap.release()
113
+
114
+
115
class ScreenCapture(BaseCapture):
    """Capture screen frames using mss."""

    def __init__(self, fps: float = 1.0, region: dict | None = None):
        """Store the capture settings.

        Args:
            fps: Target number of frames per second.
            region: Area to grab, e.g.
                ``{"left": 0, "top": 0, "width": 1920, "height": 1080}``.
                When falsy, ``sct.monitors[1]`` is used instead.
        """
        super().__init__(fps)
        self.region = region  # {"left": 0, "top": 0, "width": 1920, "height": 1080}

    def _capture_loop(self):
        """Grab the screen repeatedly until stopped.

        Each grab is converted to an RGB array and appended to the shared
        buffer with a timestamp in seconds since the loop started. Any
        failure is logged and clears ``_running`` so the thread does not
        die silently with ``is_running`` still True.
        """
        try:
            import mss
            from PIL import Image
        except ImportError:
            logger.error("mss/PIL not installed. Run: pip install mss Pillow")
            self._running = False
            return

        if not self.fps > 0:
            logger.error("fps must be positive, got %r", self.fps)
            self._running = False
            return
        interval = 1.0 / self.fps

        start_time = time.time()
        next_due = time.monotonic()

        try:
            with mss.mss() as sct:
                monitor = self.region or sct.monitors[1]  # Primary monitor
                while self._running:
                    screenshot = sct.grab(monitor)
                    # Decode the raw BGRA bytes into an RGB image, dropping
                    # the fourth channel.
                    img = Image.frombytes("RGB", screenshot.size, screenshot.bgra, "raw", "BGRX")
                    frame = np.array(img)
                    media_frame = MediaFrame(
                        video_frame=frame,
                        timestamp=time.time() - start_time,
                    )
                    with self._lock:
                        self._buffer.append(media_frame)

                    # Deadline-based pacing: account for grab/convert time so
                    # the rate does not drift; never burst after a stall.
                    next_due = max(next_due + interval, time.monotonic())
                    # Short sleep slices keep stop() responsive at low fps.
                    while self._running:
                        remaining = next_due - time.monotonic()
                        if remaining <= 0:
                            break
                        time.sleep(min(remaining, 0.1))
        except Exception:
            # Thread entry point: log the failure and mark capture as ended
            # rather than letting the thread vanish with the flag still set.
            logger.exception("Screen capture failed")
            self._running = False
147
+
148
+
149
class FileStreamer(BaseCapture):
    """Stream a video file frame-by-frame at real-time speed."""

    def __init__(self, file_path: str, fps: float = 1.0):
        """Store the streaming settings.

        Args:
            file_path: Path of the video file to read.
            fps: Target number of emitted frames per second.
        """
        super().__init__(fps)
        self.file_path = file_path

    def _capture_loop(self):
        """Decode the file and emit every ``frame_skip``-th frame in real time.

        Emission is paced against the video's own timeline: a selected frame
        is released once its position in the file (frame index divided by
        the file's frame rate) has elapsed on the wall clock. Playback
        therefore stays at real-time speed regardless of decode cost or of
        how the target fps compares with the file's frame rate. Reaching
        the end of the file clears ``_running``.
        """
        try:
            import cv2
        except ImportError:
            logger.error("OpenCV not installed. Run: pip install opencv-python")
            self._running = False
            return

        if not self.fps > 0:
            logger.error("fps must be positive, got %r", self.fps)
            self._running = False
            return

        # str() so that a pathlib.Path works as well as a plain string.
        cap = cv2.VideoCapture(str(self.file_path))
        if not cap.isOpened():
            logger.error("Cannot open video: %s", self.file_path)
            self._running = False
            return

        try:
            video_fps = cap.get(cv2.CAP_PROP_FPS)
            # Fall back to 30 when the container reports 0, a negative value
            # or NaN (``not x > 0`` covers all three).
            if not video_fps > 0:
                video_fps = 30.0
            # Skip frames to match our target FPS
            frame_skip = max(1, int(video_fps / self.fps))
            frame_idx = 0
            start_time = time.time()
            pace_start = time.monotonic()

            while self._running:
                ret, frame = cap.read()
                if not ret:
                    self._running = False
                    break
                frame_idx += 1
                if frame_idx % frame_skip != 0:
                    continue

                # Wait until this frame's position in the video has elapsed.
                # Short slices keep stop() responsive at low frame rates.
                due = pace_start + (frame_idx - 1) / video_fps
                while self._running:
                    remaining = due - time.monotonic()
                    if remaining <= 0:
                        break
                    time.sleep(min(remaining, 0.1))
                if not self._running:
                    break

                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                media_frame = MediaFrame(
                    video_frame=frame_rgb,
                    timestamp=time.time() - start_time,
                )
                with self._lock:
                    self._buffer.append(media_frame)
        finally:
            cap.release()
197
+
198
+
199
def get_capture_source(source_type: str, **kwargs) -> BaseCapture:
    """Build the capture source registered under ``source_type``.

    Args:
        source_type: One of ``"webcam"``, ``"screen"`` or ``"file"``.
        **kwargs: Passed unchanged to the selected class's constructor.

    Returns:
        A new instance of the matching capture class.

    Raises:
        ValueError: If ``source_type`` is not a registered name.
    """
    registry = {
        "webcam": WebcamCapture,
        "screen": ScreenCapture,
        "file": FileStreamer,
    }
    capture_cls = registry.get(source_type)
    if capture_cls is None:
        raise ValueError(f"Unknown source: {source_type}. Choose from {list(registry)}")
    return capture_cls(**kwargs)