| |
| |
| """ |
| Quiet, one-shot startup self-check for HF Spaces. |
| |
| What it does: |
| - loads SAM2Loader + MatAnyoneLoader (device from env or cuda/cpu auto) |
| - runs a minimal first-frame path (synthetic frame) to validate |
| - caches status in module state for later UI queries |
| - does NOT print unless failure; logs via `BackgroundFX`/root logger |
| |
| Control via env: |
| - DISABLE_SELF_CHECK=1 → skip entirely |
| - SELF_CHECK_DEVICE=cpu|cuda → override device |
| - SELF_CHECK_TIMEOUT=seconds → default 45 (soft limit: exceeding it only logs a warning, nothing is aborted) |
| """ |
|
|
| from __future__ import annotations |
| import os, time, threading, logging |
| from typing import Optional, Dict, Any |
|
|
| import numpy as np |
| import cv2 |
| import torch |
|
|
| |
| from models.loaders.sam2_loader import SAM2Loader |
| from models.loaders.matanyone_loader import MatAnyoneLoader |
| from processing.two_stage.two_stage_processor import TwoStageProcessor |
|
|
# Shared application logger. logging.getLogger() always returns a Logger
# instance (never a falsy value), so an `or` fallback would be unreachable.
logger = logging.getLogger("BackgroundFX")

# Cached self-check state. All reads/writes go through _SELF_CHECK_LOCK.
_SELF_CHECK_LOCK = threading.Lock()
_SELF_CHECK_DONE = False  # becomes True once the check finished or was disabled
_SELF_CHECK_OK = False  # True when the check passed (or was disabled)
_SELF_CHECK_MSG = "Self-check did not run yet."  # human-readable status text
_SELF_CHECK_DURATION = 0.0  # wall time of the check, in seconds
|
|
def _pick_device() -> str:
    """Return the device string ("cpu" or "cuda") the self-check should use.

    An explicit ``SELF_CHECK_DEVICE`` environment value of ``cpu`` or ``cuda``
    (case-insensitive, surrounding whitespace ignored) takes precedence.
    Any other value falls back to auto-detection through
    ``torch.cuda.is_available()``.
    """
    requested = os.environ.get("SELF_CHECK_DEVICE", "")
    normalized = requested.strip().lower()
    if normalized == "cpu" or normalized == "cuda":
        return normalized
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
|
|
def _synth_frame(w=640, h=360) -> np.ndarray:
    """
    Build a synthetic ``h`` x ``w`` BGR test frame (uint8, 3 channels).

    The frame is dark grey with a solid green band over the right 35% of the
    width and a 'person-like' figure (filled ellipse above a filled rectangle)
    centred at one third of the width. It only has to be a plausible image;
    realism is irrelevant for the self-check.
    """
    green = (0, 255, 0)
    figure_color = (60, 60, 200)
    filled = -1  # negative thickness makes OpenCV fill the shape

    # Uniform dark-grey canvas: every channel of every pixel is 40.
    frame = np.full((h, w, 3), 40, dtype=np.uint8)

    # Green band covering the right-hand part of the frame.
    band_left = int(0.65 * w)
    cv2.rectangle(frame, (band_left, 0), (w, h), green, filled)

    # Figure: an elliptical "head" above a rectangular "torso".
    center_x = w // 3
    center_y = h // 2
    cv2.ellipse(frame, (center_x, center_y - 40), (35, 45), 0, 0, 360, figure_color, filled)
    cv2.rectangle(
        frame,
        (center_x - 40, center_y - 10),
        (center_x + 40, center_y + 80),
        figure_color,
        filled,
    )
    return frame
|
|
def _run_once(timeout_s: float = 45.0) -> tuple[bool, str, float]:
    """
    Run the self-check one time and report the outcome.

    Steps: load SAM2, predict on a synthetic frame, load MatAnyone, run it on
    the first frame + mask, then construct a TwoStageProcessor from both.

    Args:
        timeout_s: Soft time budget in seconds. Exceeding it only logs a
            warning; the check is neither interrupted nor failed because of it.

    Returns:
        ``(ok, message, seconds)``. ``message`` is exactly ``"OK"`` on success
        and a description of the failure otherwise.

    Exceptions raised by any step are caught and reported as a failed result.
    """
    t0 = time.monotonic()

    def _elapsed() -> float:
        return time.monotonic() - t0

    try:
        # Device selection is inside the try so a failure there is reported
        # as a failed check instead of escaping the worker thread.
        device = _pick_device()

        sam = SAM2Loader(device=device).load("auto")
        if sam is None:
            return False, "SAM2 failed to load", _elapsed()

        bgr = _synth_frame()
        h, w = bgr.shape[:2]
        sam.set_image(bgr)
        # NOTE(review): assumes predict() returns a dict-like object with a
        # "masks" entry — confirm against the SAM2Loader wrapper.
        out = sam.predict(point_coords=None, point_labels=None)
        masks = out.get("masks", None)
        if masks is None or len(masks) == 0:
            # An empty prediction on a synthetic image is tolerated: fall back
            # to a full-frame mask so the matting stage can still be exercised.
            logger.warning("Self-check: SAM2 returned no masks; accepting fallback.")
            mask0 = np.ones((h, w), np.float32)
        else:
            mask0 = masks[0].astype(np.float32)
            if mask0.shape != (h, w):
                mask0 = cv2.resize(mask0, (w, h), interpolation=cv2.INTER_LINEAR)

        session = MatAnyoneLoader(device=device).load()
        if session is None:
            return False, "MatAnyone failed to load", _elapsed()

        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        alpha0 = session(rgb, mask0)
        if not isinstance(alpha0, np.ndarray) or alpha0.shape != (h, w):
            return (
                False,
                f"MatAnyone alpha shape unexpected: {getattr(alpha0, 'shape', None)}",
                _elapsed(),
            )

        # Constructing the processor validates that both models wire together.
        _ = TwoStageProcessor(sam2_predictor=sam, matanyone_model=session)

        return True, "OK", _elapsed()

    except Exception as e:
        # Keep normal logs quiet; the traceback is available at debug level.
        logger.debug("Self-check traceback", exc_info=True)
        return False, f"Self-check error: {e}", _elapsed()
    finally:
        # Only observe the time budget here. A `return` in `finally` would
        # override the results above and could turn failures into "OK".
        dur = _elapsed()
        if dur > timeout_s:
            logger.warning(f"Self-check exceeded timeout {timeout_s:.1f}s (took {dur:.2f}s)")
|
|
def _runner(timeout_s: float):
    """
    Thread target: run the self-check once and publish its outcome.

    Stores done/ok/message/duration in the module-level state under
    ``_SELF_CHECK_LOCK`` and logs one line: info on success, error on failure.
    """
    global _SELF_CHECK_DONE, _SELF_CHECK_OK, _SELF_CHECK_MSG, _SELF_CHECK_DURATION

    result_ok, message, elapsed = _run_once(timeout_s=timeout_s)
    # Success needs both a truthy flag and the literal "OK" message.
    passed = bool(result_ok and message == "OK")

    with _SELF_CHECK_LOCK:
        _SELF_CHECK_DONE = True
        _SELF_CHECK_OK = passed
        _SELF_CHECK_MSG = message
        _SELF_CHECK_DURATION = float(elapsed)

    if not passed:
        logger.error(f"❌ Startup self-check FAILED in {elapsed:.2f}s: {message}")
    else:
        logger.info(f"✅ Startup self-check OK in {elapsed:.2f}s")
|
def launch_self_check_async(timeout_s: Optional[float] = None):
    """
    Fire-and-forget startup check. No effect if disabled or already started.

    When ``DISABLE_SELF_CHECK=1`` the cached status is marked done/ok with the
    message "Disabled" and no thread is started. Otherwise a daemon thread
    running ``_runner`` is started at most once per process.

    Args:
        timeout_s: Soft time budget in seconds, used when the
            ``SELF_CHECK_TIMEOUT`` environment variable is not set. ``None``
            (or 0) means 45 seconds. A set environment variable takes
            precedence; if it cannot be parsed as a number it is ignored with
            a warning instead of raising.
    """
    global _SELF_CHECK_DONE, _SELF_CHECK_OK, _SELF_CHECK_MSG, _SELF_CHECK_DURATION

    if os.environ.get("DISABLE_SELF_CHECK", "0") == "1":
        logger.info("Self-check disabled via DISABLE_SELF_CHECK=1")
        with _SELF_CHECK_LOCK:
            _SELF_CHECK_DONE = True
            _SELF_CHECK_OK = True
            _SELF_CHECK_MSG = "Disabled"
            _SELF_CHECK_DURATION = 0.0
        return

    default_timeout = float(timeout_s or 45.0)
    raw_timeout = os.environ.get("SELF_CHECK_TIMEOUT")
    if raw_timeout is None:
        effective_timeout = default_timeout
    else:
        try:
            effective_timeout = float(raw_timeout)
        except ValueError:
            # A malformed env value must not crash application startup.
            logger.warning(
                "Ignoring invalid SELF_CHECK_TIMEOUT=%r; using %.1fs",
                raw_timeout,
                default_timeout,
            )
            effective_timeout = default_timeout

    # The "_started" function attribute is the once-only guard; it is checked
    # and set under the lock so concurrent callers cannot both start a thread.
    with _SELF_CHECK_LOCK:
        if getattr(launch_self_check_async, "_started", False):
            return
        launch_self_check_async._started = True

    worker = threading.Thread(target=_runner, args=(effective_timeout,), daemon=True)
    worker.start()
|
|
def get_self_check_status() -> Dict[str, Any]:
    """
    Return a snapshot of the cached self-check state.

    The four values are read together under ``_SELF_CHECK_LOCK`` and returned
    as a dict with the keys ``done``, ``ok``, ``message`` and ``duration``.
    """
    with _SELF_CHECK_LOCK:
        done, ok, message, duration = (
            _SELF_CHECK_DONE,
            _SELF_CHECK_OK,
            _SELF_CHECK_MSG,
            _SELF_CHECK_DURATION,
        )
    return dict(done=done, ok=ok, message=message, duration=duration)
|
|