Spaces:
Running
Running
| """Media capture sources for live brain prediction. | |
| Provides webcam, screen capture, and file streaming sources that | |
| yield frames at a controlled rate for real-time inference. | |
| """ | |
| from __future__ import annotations | |
| import time | |
| import threading | |
| import logging | |
| from pathlib import Path | |
| from collections import deque | |
| from dataclasses import dataclass | |
| import numpy as np | |
| logger = logging.getLogger(__name__) | |
| class MediaFrame: | |
| """A single frame from any media source.""" | |
| video_frame: np.ndarray | None = None # (H, W, 3) RGB | |
| audio_chunk: np.ndarray | None = None # (samples,) float32 | |
| timestamp: float = 0.0 | |
| class BaseCapture: | |
| """Base class for media capture sources.""" | |
| def __init__(self, fps: float = 1.0): | |
| self.fps = fps | |
| self._running = False | |
| self._buffer: deque[MediaFrame] = deque(maxlen=300) | |
| self._thread: threading.Thread | None = None | |
| self._lock = threading.Lock() | |
| def start(self): | |
| self._running = True | |
| self._thread = threading.Thread(target=self._capture_loop, daemon=True) | |
| self._thread.start() | |
| def stop(self): | |
| self._running = False | |
| if self._thread: | |
| self._thread.join(timeout=3.0) | |
| def get_latest_frame(self) -> MediaFrame | None: | |
| with self._lock: | |
| return self._buffer[-1] if self._buffer else None | |
| def get_all_frames(self) -> list[MediaFrame]: | |
| with self._lock: | |
| frames = list(self._buffer) | |
| return frames | |
| def is_running(self) -> bool: | |
| return self._running | |
| def frame_count(self) -> int: | |
| return len(self._buffer) | |
| def _capture_loop(self): | |
| raise NotImplementedError | |
| class WebcamCapture(BaseCapture): | |
| """Capture frames from webcam using OpenCV.""" | |
| def __init__(self, camera_index: int = 0, fps: float = 1.0, resolution: tuple = (640, 480)): | |
| super().__init__(fps) | |
| self.camera_index = camera_index | |
| self.resolution = resolution | |
| def _capture_loop(self): | |
| try: | |
| import cv2 | |
| except ImportError: | |
| logger.error("OpenCV not installed. Run: pip install opencv-python") | |
| self._running = False | |
| return | |
| cap = cv2.VideoCapture(self.camera_index) | |
| if not cap.isOpened(): | |
| logger.error(f"Cannot open camera {self.camera_index}") | |
| self._running = False | |
| return | |
| cap.set(cv2.CAP_PROP_FRAME_WIDTH, self.resolution[0]) | |
| cap.set(cv2.CAP_PROP_FRAME_HEIGHT, self.resolution[1]) | |
| start_time = time.time() | |
| interval = 1.0 / self.fps | |
| try: | |
| while self._running: | |
| ret, frame = cap.read() | |
| if not ret: | |
| break | |
| # BGR -> RGB | |
| frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) | |
| media_frame = MediaFrame( | |
| video_frame=frame_rgb, | |
| timestamp=time.time() - start_time, | |
| ) | |
| with self._lock: | |
| self._buffer.append(media_frame) | |
| time.sleep(interval) | |
| finally: | |
| cap.release() | |
| class ScreenCapture(BaseCapture): | |
| """Capture screen frames using mss.""" | |
| def __init__(self, fps: float = 1.0, region: dict | None = None): | |
| super().__init__(fps) | |
| self.region = region # {"left": 0, "top": 0, "width": 1920, "height": 1080} | |
| def _capture_loop(self): | |
| try: | |
| import mss | |
| from PIL import Image | |
| except ImportError: | |
| logger.error("mss/PIL not installed. Run: pip install mss Pillow") | |
| self._running = False | |
| return | |
| start_time = time.time() | |
| interval = 1.0 / self.fps | |
| with mss.mss() as sct: | |
| monitor = self.region or sct.monitors[1] # Primary monitor | |
| while self._running: | |
| screenshot = sct.grab(monitor) | |
| img = Image.frombytes("RGB", screenshot.size, screenshot.bgra, "raw", "BGRX") | |
| frame = np.array(img) | |
| media_frame = MediaFrame( | |
| video_frame=frame, | |
| timestamp=time.time() - start_time, | |
| ) | |
| with self._lock: | |
| self._buffer.append(media_frame) | |
| time.sleep(interval) | |
| class FileStreamer(BaseCapture): | |
| """Stream a video file frame-by-frame at real-time speed.""" | |
| def __init__(self, file_path: str, fps: float = 1.0): | |
| super().__init__(fps) | |
| self.file_path = file_path | |
| def _capture_loop(self): | |
| try: | |
| import cv2 | |
| except ImportError: | |
| logger.error("OpenCV not installed. Run: pip install opencv-python") | |
| self._running = False | |
| return | |
| cap = cv2.VideoCapture(self.file_path) | |
| if not cap.isOpened(): | |
| logger.error(f"Cannot open video: {self.file_path}") | |
| self._running = False | |
| return | |
| video_fps = cap.get(cv2.CAP_PROP_FPS) or 30 | |
| # Skip frames to match our target FPS | |
| frame_skip = max(1, int(video_fps / self.fps)) | |
| frame_idx = 0 | |
| start_time = time.time() | |
| interval = 1.0 / self.fps | |
| try: | |
| while self._running: | |
| ret, frame = cap.read() | |
| if not ret: | |
| self._running = False | |
| break | |
| frame_idx += 1 | |
| if frame_idx % frame_skip != 0: | |
| continue | |
| frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) | |
| media_frame = MediaFrame( | |
| video_frame=frame_rgb, | |
| timestamp=time.time() - start_time, | |
| ) | |
| with self._lock: | |
| self._buffer.append(media_frame) | |
| time.sleep(interval) | |
| finally: | |
| cap.release() | |
| def get_capture_source(source_type: str, **kwargs) -> BaseCapture: | |
| """Factory function to create a capture source.""" | |
| sources = { | |
| "webcam": WebcamCapture, | |
| "screen": ScreenCapture, | |
| "file": FileStreamer, | |
| } | |
| if source_type not in sources: | |
| raise ValueError(f"Unknown source: {source_type}. Choose from {list(sources)}") | |
| return sources[source_type](**kwargs) | |