"""
BackgroundFX Pro - Model Loading & Utilities (Hardened)
=======================================================
- Avoids heavy CUDA/Hydra work at import time
- Adds timeouts to subprocess probes
- Safer sys.path wiring for third_party repos
- MatAnyone loader is probe-only here; actual run happens in matanyone_loader.MatAnyoneSession

Changes (2025-09-16):
- Aligned with torch==2.3.1+cu121 and MatAnyone v1.0.0
- Updated load_matany to apply T=1 squeeze patch before InferenceCore import
- Added patch status logging and MatAnyone version
- Added InferenceCore attributes logging for debugging
- Fixed InferenceCore import path to matanyone.inference.inference_core
"""
|
|
| from __future__ import annotations |
|
|
| import os |
| import sys |
| import cv2 |
| import subprocess |
| import inspect |
| import logging |
| import importlib.metadata |
| from pathlib import Path |
| from typing import Optional, Tuple, Dict, Any, Union, Callable |
|
|
| import numpy as np |
| import yaml |
|
|
| |
| try: |
| import torch |
| except ImportError: |
| torch = None |
|
|
| |
| |
| |
# Module logger: attach a stream handler only when the embedding application
# has not already configured one, so we never double-log.
logger = logging.getLogger("backgroundfx_pro")
if not logger.handlers:
    _h = logging.StreamHandler()
    _h.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s"))
    logger.addHandler(_h)
    logger.setLevel(logging.INFO)
|
|
| |
# Cap OpenCV's internal thread pool (CV_THREADS env override, default 1) to
# avoid CPU oversubscription when running next to torch data loaders.
try:
    cv_threads = int(os.environ.get("CV_THREADS", "1"))
    if hasattr(cv2, "setNumThreads"):
        cv2.setNumThreads(cv_threads)
except Exception:
    # Best-effort tuning only; never fail module import over it.
    pass
|
|
| |
| |
| |
# Optional MediaPipe dependency; only used by fallback_mask() below.
try:
    import mediapipe as mp
    _HAS_MEDIAPIPE = True
except Exception:
    _HAS_MEDIAPIPE = False
|
|
| |
| |
| |
# Repository root (two levels up from this file) and vendored third_party
# checkout locations; both can be overridden via environment variables.
ROOT = Path(__file__).resolve().parent.parent
TP_SAM2 = Path(os.environ.get("THIRD_PARTY_SAM2_DIR", ROOT / "third_party" / "sam2")).resolve()
TP_MATANY = Path(os.environ.get("THIRD_PARTY_MATANY_DIR", ROOT / "third_party" / "matanyone")).resolve()
|
|
| def _add_sys_path(p: Path) -> None: |
| if p.exists(): |
| p_str = str(p) |
| if p_str not in sys.path: |
| sys.path.insert(0, p_str) |
| else: |
| logger.warning(f"third_party path not found: {p}") |
|
|
# Make the vendored SAM2 / MatAnyone repos importable before any loader runs.
_add_sys_path(TP_SAM2)
_add_sys_path(TP_MATANY)
|
|
| |
| |
| |
| def _torch(): |
| try: |
| import torch |
| return torch |
| except Exception as e: |
| logger.warning(f"[models.safe-torch] import failed: {e}") |
| return None |
|
|
def _has_cuda() -> bool:
    """True when torch imports cleanly and reports an available CUDA device."""
    t = _torch()
    if t is None:
        return False
    try:
        available = t.cuda.is_available()
    except Exception as e:
        logger.warning(f"[models.safe-torch] cuda.is_available() failed: {e}")
        return False
    return bool(available)
|
|
def _pick_device(env_key: str) -> str:
    """Resolve 'cuda' vs 'cpu' for the device env var *env_key*.

    Policy: a visible GPU wins unless the caller explicitly asked for 'cpu';
    otherwise honor an explicit request, falling back to auto-detection.
    """
    requested = os.environ.get(env_key, "").strip().lower()
    has_cuda = _has_cuda()

    # Snapshot the CUDA-related environment for post-mortem debugging.
    names = (
        'FORCE_CUDA_DEVICE', 'CUDA_MEMORY_FRACTION', 'PYTORCH_CUDA_ALLOC_CONF',
        'REQUIRE_CUDA', 'SAM2_DEVICE', 'MATANY_DEVICE',
    )
    cuda_env_vars = {name: os.environ.get(name, '') for name in names}
    logger.info(f"CUDA environment variables: {cuda_env_vars}")

    logger.info(f"_pick_device({env_key}): requested='{requested}', has_cuda={has_cuda}")

    if has_cuda and requested != "cpu":
        logger.info(f"FORCING CUDA device (GPU available, requested='{requested}')")
        return "cuda"
    if requested in ("cuda", "cpu"):
        logger.info(f"Using explicitly requested device: {requested}")
        return requested

    result = "cuda" if has_cuda else "cpu"
    logger.info(f"Auto-selected device: {result}")
    return result
|
|
| |
| |
| |
| def _ffmpeg_bin() -> str: |
| return os.environ.get("FFMPEG_BIN", "ffmpeg") |
|
|
def _probe_ffmpeg(timeout: int = 2) -> bool:
    """Return True when ffmpeg answers `-version` within *timeout* seconds."""
    cmd = [_ffmpeg_bin(), "-version"]
    try:
        subprocess.run(
            cmd,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            check=True,
            timeout=timeout,
        )
    except Exception:
        # Missing binary, non-zero exit, or timeout all mean "not usable".
        return False
    return True
|
|
| def _ensure_dir(p: Path) -> None: |
| p.mkdir(parents=True, exist_ok=True) |
|
|
def _cv_read_first_frame(video_path: Union[str, Path]) -> Tuple[Optional[np.ndarray], int, Tuple[int, int]]:
    """Open *video_path* and return (first_frame_bgr_or_None, fps, (width, height)).

    Returns (None, 0, (0, 0)) when the file cannot be opened; fps defaults
    to 25 when the container reports none.
    """
    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        return None, 0, (0, 0)
    fps = int(round(cap.get(cv2.CAP_PROP_FPS) or 25))
    ok, frame = cap.read()
    size = (
        int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) or 0),
        int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) or 0),
    )
    cap.release()
    return (frame if ok else None), fps, size
|
|
def _save_mask_png(mask: np.ndarray, path: Union[str, Path]) -> str:
    """Write *mask* to *path* as an 8-bit image; bool masks become 0/255.

    Returns the path as a string.
    """
    if mask.dtype == bool:
        out = mask.astype(np.uint8) * 255
    elif mask.dtype == np.uint8:
        out = mask
    else:
        out = np.clip(mask, 0, 255).astype(np.uint8)
    target = str(path)
    cv2.imwrite(target, out)
    return target
|
|
def _resize_keep_ar(image: np.ndarray, target_wh: Tuple[int, int]) -> np.ndarray:
    """Letterbox-resize *image* to (target_w, target_h), preserving aspect ratio.

    The image is scaled to fit inside the target and centered on a black
    canvas. Degenerate sizes (zero width/height on either side) return the
    input unchanged.

    Args:
        image: HxWxC color or HxW grayscale array.
        target_wh: (width, height) of the output canvas.

    Returns:
        Canvas of shape (target_h, target_w[, C]) with the resized image centered.
    """
    tw, th = target_wh
    h, w = image.shape[:2]
    if h == 0 or w == 0 or tw == 0 or th == 0:
        return image
    scale = min(tw / w, th / h)
    nw = max(1, int(round(w * scale)))
    nh = max(1, int(round(h * scale)))
    resized = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_CUBIC)
    # Canvas matches the input's channel layout (previously hard-coded to
    # 3 channels, which raised a broadcast error on grayscale inputs).
    canvas = np.zeros((th, tw) + image.shape[2:], dtype=resized.dtype)
    x0 = (tw - nw) // 2
    y0 = (th - nh) // 2
    canvas[y0:y0 + nh, x0:x0 + nw] = resized
    return canvas
|
|
def _video_writer(out_path: Path, fps: int, size: Tuple[int, int]) -> cv2.VideoWriter:
    """Create an mp4v cv2.VideoWriter at *out_path*; fps is clamped to >= 1."""
    codec = cv2.VideoWriter_fourcc(*"mp4v")
    safe_fps = max(1, fps)
    return cv2.VideoWriter(str(out_path), codec, safe_fps, size)
|
|
def _mux_audio(src_video: Union[str, Path], silent_video: Union[str, Path], out_path: Union[str, Path]) -> bool:
    """Copy video from silent_video + audio from src_video into out_path (AAC).

    Returns False (and logs) when ffmpeg fails, so callers can fall back to
    the silent output.
    """
    cmd = [_ffmpeg_bin(), "-y"]
    cmd += ["-i", str(silent_video)]          # input 0: composited video
    cmd += ["-i", str(src_video)]             # input 1: original (audio source)
    cmd += ["-map", "0:v:0", "-map", "1:a:0?"]  # '?' tolerates missing audio
    cmd += ["-c:v", "copy"]
    cmd += ["-c:a", "aac", "-b:a", "192k"]
    cmd += ["-shortest", str(out_path)]
    try:
        subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except Exception as e:
        logger.warning(f"Audio mux failed; returning silent video. Reason: {e}")
        return False
    return True
|
|
| |
| |
| |
def _refine_alpha(alpha: np.ndarray, erode_px: int = 1, dilate_px: int = 2, blur_px: float = 1.5) -> np.ndarray:
    """Clean a soft alpha matte: normalize to [0, 1], erode/dilate, then feather.

    Args:
        alpha: Matte as uint8 (0..255), or float (0..1); a non-float32 array
            whose max exceeds 1.0 is treated as 0..255 and rescaled.
        erode_px: Elliptical erosion kernel size in pixels (0 disables).
        dilate_px: Elliptical dilation kernel size in pixels (0 disables).
        blur_px: Gaussian feather radius in pixels (0 disables).

    Returns:
        float32 matte clipped to [0, 1].
    """
    # Normalize to float32 in [0, 1]. NOTE(review): a float32 matte already in
    # 0..255 is NOT rescaled here — callers appear to pass uint8 or 0..1 floats.
    if alpha.dtype != np.float32:
        a = alpha.astype(np.float32)
        if a.max() > 1.0:
            a = a / 255.0
    else:
        a = alpha.copy()

    # Morphology runs on an 8-bit view so OpenCV's uint8 erode/dilate apply.
    a_u8 = np.clip(np.round(a * 255.0), 0, 255).astype(np.uint8)
    if erode_px > 0:
        k = max(1, int(erode_px))
        a_u8 = cv2.erode(a_u8, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k)), iterations=1)
    if dilate_px > 0:
        k = max(1, int(dilate_px))
        a_u8 = cv2.dilate(a_u8, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k)), iterations=1)
    a = a_u8.astype(np.float32) / 255.0

    # `rad | 1` forces an odd kernel size, which GaussianBlur requires.
    if blur_px and blur_px > 0:
        rad = max(1, int(round(blur_px)))
        a = cv2.GaussianBlur(a, (rad | 1, rad | 1), 0)

    return np.clip(a, 0.0, 1.0)
|
|
| def _to_linear(rgb: np.ndarray, gamma: float = 2.2) -> np.ndarray: |
| x = np.clip(rgb.astype(np.float32) / 255.0, 0.0, 1.0) |
| return np.power(x, gamma) |
|
|
| def _to_srgb(lin: np.ndarray, gamma: float = 2.2) -> np.ndarray: |
| x = np.clip(lin, 0.0, 1.0) |
| return np.clip(np.power(x, 1.0 / gamma) * 255.0, 0, 255).astype(np.uint8) |
|
|
def _light_wrap(bg_rgb: np.ndarray, alpha01: np.ndarray, radius: int = 5, amount: float = 0.18) -> np.ndarray:
    """Additive light-wrap term: blurred background light spilling onto FG edges."""
    r = max(1, int(radius))
    # Blur the inverted matte so background light bleeds across the silhouette.
    inv_blur = cv2.GaussianBlur(1.0 - alpha01, (r | 1, r | 1), 0)
    return bg_rgb.astype(np.float32) * inv_blur[..., None] * float(amount)
|
|
def _despill_edges(fg_rgb: np.ndarray, alpha01: np.ndarray, amount: float = 0.35) -> np.ndarray:
    """Desaturate foreground pixels near the matte edge to suppress color spill."""
    # Edge weight peaks at alpha == 0.5 and falls to 0 at fully FG/BG pixels.
    edge_w = np.clip(1.0 - 2.0 * np.abs(alpha01 - 0.5), 0.0, 1.0)
    hsv = cv2.cvtColor(fg_rgb.astype(np.uint8), cv2.COLOR_RGB2HSV).astype(np.float32)
    hue, sat, val = cv2.split(hsv)
    sat = sat * (1.0 - amount * edge_w)
    merged = cv2.merge([hue, np.clip(sat, 0, 255), val])
    return cv2.cvtColor(merged.astype(np.uint8), cv2.COLOR_HSV2RGB)
|
|
def _composite_frame_pro(
    fg_rgb: np.ndarray, alpha: np.ndarray, bg_rgb: np.ndarray,
    erode_px: Optional[int] = None, dilate_px: Optional[int] = None, blur_px: Optional[float] = None,
    lw_radius: Optional[int] = None, lw_amount: Optional[float] = None, despill_amount: Optional[float] = None
) -> np.ndarray:
    """Composite one RGB frame over a background with edge-quality treatments.

    Pipeline: refine the alpha (erode/dilate/feather), despill edge colors,
    then blend in linear light with an additive light-wrap term.

    Args:
        fg_rgb: Foreground frame (RGB uint8).
        alpha: Matte (uint8 0..255 or float 0..1).
        bg_rgb: Background frame (RGB uint8), same size as fg_rgb.
        erode_px/dilate_px/blur_px/lw_radius/lw_amount/despill_amount:
            Tuning knobs; any None value falls back to the matching
            EDGE_* / LIGHTWRAP_* / DESPILL_* environment variable.

    Returns:
        Composited RGB uint8 frame.
    """
    erode_px = erode_px if erode_px is not None else int(os.environ.get("EDGE_ERODE", "1"))
    dilate_px = dilate_px if dilate_px is not None else int(os.environ.get("EDGE_DILATE", "2"))
    blur_px = blur_px if blur_px is not None else float(os.environ.get("EDGE_BLUR", "1.5"))
    lw_radius = lw_radius if lw_radius is not None else int(os.environ.get("LIGHTWRAP_RADIUS", "5"))
    lw_amount = lw_amount if lw_amount is not None else float(os.environ.get("LIGHTWRAP_AMOUNT", "0.18"))
    despill_amount = despill_amount if despill_amount is not None else float(os.environ.get("DESPILL_AMOUNT", "0.35"))

    a = _refine_alpha(alpha, erode_px=erode_px, dilate_px=dilate_px, blur_px=blur_px)
    fg_rgb = _despill_edges(fg_rgb, a, amount=despill_amount)

    # Blend in linear light to avoid dark gamma-space halos at soft edges.
    fg_lin = _to_linear(fg_rgb)
    bg_lin = _to_linear(bg_rgb)
    lw = _light_wrap(bg_rgb, a, radius=lw_radius, amount=lw_amount)
    lw_lin = _to_linear(np.clip(lw, 0, 255).astype(np.uint8))

    comp_lin = fg_lin * a[..., None] + bg_lin * (1.0 - a[..., None]) + lw_lin
    comp = _to_srgb(comp_lin)
    return comp
|
|
| |
| |
| |
def _resolve_sam2_cfg(cfg_str: str) -> str:
    """Resolve SAM2 config path - return relative path for Hydra compatibility."""
    logger.info(f"_resolve_sam2_cfg called with cfg_str={cfg_str}")

    tp_sam2 = os.environ.get("THIRD_PARTY_SAM2_DIR", "/home/user/app/third_party/sam2")
    logger.info(f"TP_SAM2 = {tp_sam2}")

    candidate = os.path.join(tp_sam2, cfg_str)
    logger.info(f"Candidate path: {candidate}")
    logger.info(f"Candidate exists: {os.path.exists(candidate)}")

    if os.path.exists(candidate):
        # Hydra resolves configs relative to the sam2 package root.
        relative_path = (
            cfg_str.replace("sam2/configs/", "configs/")
            if cfg_str.startswith("sam2/configs/")
            else cfg_str
        )
        logger.info(f"Returning Hydra-compatible relative path: {relative_path}")
        return relative_path

    for fallback in (os.path.join(tp_sam2, "sam2", cfg_str),
                     os.path.join(tp_sam2, "configs", cfg_str)):
        logger.info(f"Trying fallback: {fallback}")
        if os.path.exists(fallback) and "configs/" in fallback:
            relative_path = "configs/" + fallback.split("configs/")[-1]
            logger.info(f"Returning fallback relative path: {relative_path}")
            return relative_path

    logger.warning(f"Config not found, returning original: {cfg_str}")
    return cfg_str
|
|
def _find_hiera_config_if_hieradet(cfg_path: str) -> Optional[str]:
    """If config references 'hieradet', try to find a 'hiera' config.

    Parses the YAML at *cfg_path*; when its image-encoder trunk targets a
    'hieradet' class, scans every YAML under TP_SAM2 for one whose trunk
    targets a '.hiera.' class instead and returns that path. Returns None
    when no substitute is found or on any parse error (best-effort only).
    """
    try:
        with open(cfg_path, "r") as f:
            data = yaml.safe_load(f)
        model = data.get("model", {}) or {}
        enc = model.get("image_encoder") or {}
        trunk = enc.get("trunk") or {}
        # Hydra configs use '_target_'; tolerate a bare 'target' key too.
        target = trunk.get("_target_") or trunk.get("target")
        if isinstance(target, str) and "hieradet" in target:
            for y in TP_SAM2.rglob("*.yaml"):
                try:
                    with open(y, "r") as f2:
                        d2 = yaml.safe_load(f2) or {}
                    e2 = (d2.get("model", {}) or {}).get("image_encoder") or {}
                    t2 = (e2.get("trunk") or {})
                    tgt2 = t2.get("_target_") or t2.get("target")
                    if isinstance(tgt2, str) and ".hiera." in tgt2:
                        logger.info(f"SAM2: switching config from 'hieradet' → 'hiera': {y}")
                        return str(y)
                except Exception:
                    # Unreadable/invalid candidate YAML — keep scanning.
                    continue
    except Exception:
        pass
    return None
|
|
def load_sam2() -> Tuple[Optional[object], bool, Dict[str, Any]]:
    """Robust SAM2 loader with config resolution and error handling.

    Returns (predictor_or_None, ok, meta) where meta records import/init
    status, the chosen device, and — indirectly via logs — GPU memory use.
    Config and checkpoint come from SAM2_MODEL_CFG / SAM2_CHECKPOINT env vars.
    """
    meta = {"sam2_import_ok": False, "sam2_init_ok": False}
    try:
        from sam2.build_sam import build_sam2
        from sam2.sam2_image_predictor import SAM2ImagePredictor
        meta["sam2_import_ok"] = True
    except Exception as e:
        logger.warning(f"SAM2 import failed: {e}")
        return None, False, meta

    # Log pre-load GPU memory for debugging; `torch` may be None (import guard
    # at module top), so check it before touching cuda.
    if torch and torch.cuda.is_available():
        mem_before = torch.cuda.memory_allocated() / 1024**3
        logger.info(f"🔍 GPU memory before SAM2 load: {mem_before:.2f}GB")

    device = _pick_device("SAM2_DEVICE")
    cfg_env = os.environ.get("SAM2_MODEL_CFG", "sam2/configs/sam2/sam2_hiera_l.yaml")
    cfg = _resolve_sam2_cfg(cfg_env)
    ckpt = os.environ.get("SAM2_CHECKPOINT", "")

    def _try_build(cfg_path: str):
        # build_sam2's signature varies across SAM2 releases; introspect it and
        # pass whichever parameter names this version accepts.
        logger.info(f"_try_build called with cfg_path: {cfg_path}")
        params = set(inspect.signature(build_sam2).parameters.keys())
        logger.info(f"build_sam2 parameters: {list(params)}")
        kwargs = {}
        if "config_file" in params:
            kwargs["config_file"] = cfg_path
            logger.info(f"Using config_file parameter: {cfg_path}")
        elif "model_cfg" in params:
            kwargs["model_cfg"] = cfg_path
            logger.info(f"Using model_cfg parameter: {cfg_path}")
        if ckpt:
            if "checkpoint" in params:
                kwargs["checkpoint"] = ckpt
            elif "ckpt_path" in params:
                kwargs["ckpt_path"] = ckpt
            elif "weights" in params:
                kwargs["weights"] = ckpt
        if "device" in params:
            kwargs["device"] = device
        try:
            logger.info(f"Calling build_sam2 with kwargs: {kwargs}")
            result = build_sam2(**kwargs)
            logger.info(f"build_sam2 succeeded with kwargs")
            # Best-effort device logging; attribute layout differs by version.
            if hasattr(result, 'device'):
                logger.info(f"SAM2 model device: {result.device}")
            elif hasattr(result, 'image_encoder') and hasattr(result.image_encoder, 'device'):
                logger.info(f"SAM2 model device: {result.image_encoder.device}")
            return result
        except TypeError as e:
            # Keyword mismatch: retry positionally (cfg[, ckpt][, device]).
            logger.info(f"build_sam2 kwargs failed: {e}, trying positional args")
            pos = [cfg_path]
            if ckpt:
                pos.append(ckpt)
            if "device" not in kwargs:
                pos.append(device)
            logger.info(f"Calling build_sam2 with positional args: {pos}")
            result = build_sam2(*pos)
            logger.info(f"build_sam2 succeeded with positional args")
            return result

    try:
        try:
            sam = _try_build(cfg)
        except Exception:
            # Some configs reference a 'hieradet' trunk this install lacks;
            # retry with an equivalent 'hiera' config when one can be found.
            alt_cfg = _find_hiera_config_if_hieradet(cfg)
            if alt_cfg:
                sam = _try_build(alt_cfg)
            else:
                raise

        if sam is not None:
            predictor = SAM2ImagePredictor(sam)
            meta["sam2_init_ok"] = True
            meta["sam2_device"] = device
            return predictor, True, meta
        else:
            return None, False, meta

    except Exception as e:
        logger.error(f"SAM2 loading failed: {e}")
        return None, False, meta
|
|
def run_sam2_mask(predictor: object,
                  first_frame_bgr: np.ndarray,
                  point: Optional[Tuple[int, int]] = None,
                  auto: bool = False) -> Tuple[Optional[np.ndarray], bool]:
    """Return (mask_uint8_0_255, ok)."""
    if predictor is None:
        return None, False
    try:
        rgb = cv2.cvtColor(first_frame_bgr, cv2.COLOR_BGR2RGB)
        predictor.set_image(rgb)
        h, w = rgb.shape[:2]

        if auto:
            # Near-full-frame box prompt for automatic segmentation.
            seed_box = np.array([int(0.05 * w), int(0.05 * h), int(0.95 * w), int(0.95 * h)])
            masks, _, _ = predictor.predict(box=seed_box)
        elif point is not None:
            # Single positive point prompt at the user-selected location.
            coords = np.array([[int(point[0]), int(point[1])]], dtype=np.int32)
            labels = np.array([1], dtype=np.int32)
            masks, _, _ = predictor.predict(point_coords=coords, point_labels=labels)
        else:
            # Default: a slightly tighter center box prompt.
            seed_box = np.array([int(0.1 * w), int(0.1 * h), int(0.9 * w), int(0.9 * h)])
            masks, _, _ = predictor.predict(box=seed_box)

        if masks is None or len(masks) == 0:
            return None, False
        return masks[0].astype(np.uint8) * 255, True
    except Exception as e:
        logger.warning(f"SAM2 mask failed: {e}")
        return None, False
|
|
def _refine_mask_grabcut(image_bgr: np.ndarray,
                         mask_u8: np.ndarray,
                         iters: Optional[int] = None,
                         trimap_erode: Optional[int] = None,
                         trimap_dilate: Optional[int] = None) -> np.ndarray:
    """Use SAM2 seed as initialization for GrabCut refinement.

    Args:
        image_bgr: Source frame (BGR).
        mask_u8: Seed mask; values > 127 count as foreground.
        iters: GrabCut iterations (default: REFINE_GRABCUT_ITERS env, 2).
        trimap_erode: Erosion size defining sure-foreground
            (default: REFINE_TRIMAP_ERODE env, 3).
        trimap_dilate: Erosion size applied to the inverted mask to define
            sure-background (default: REFINE_TRIMAP_DILATE env, 6).

    Returns:
        uint8 mask (0/255); the binarized seed mask on GrabCut failure.
    """
    iters = int(os.environ.get("REFINE_GRABCUT_ITERS", "2")) if iters is None else int(iters)
    # Named erode_sz/dilate_sz: the previous single-letter 'e' was shadowed by
    # the `except ... as e` clause below.
    erode_sz = int(os.environ.get("REFINE_TRIMAP_ERODE", "3")) if trimap_erode is None else int(trimap_erode)
    dilate_sz = int(os.environ.get("REFINE_TRIMAP_DILATE", "6")) if trimap_dilate is None else int(trimap_dilate)

    h, w = mask_u8.shape[:2]
    m = (mask_u8 > 127).astype(np.uint8) * 255

    # Eroding the mask gives confident FG; eroding its inverse gives confident BG.
    sure_fg = cv2.erode(m, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (max(1, erode_sz), max(1, erode_sz))), iterations=1)
    sure_bg = cv2.erode(255 - m, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (max(1, dilate_sz), max(1, dilate_sz))), iterations=1)

    # Everything not confidently FG/BG stays "probably background" for GrabCut.
    gc_mask = np.full((h, w), cv2.GC_PR_BGD, dtype=np.uint8)
    gc_mask[sure_bg > 0] = cv2.GC_BGD
    gc_mask[sure_fg > 0] = cv2.GC_FGD

    bgdModel = np.zeros((1, 65), np.float64)
    fgdModel = np.zeros((1, 65), np.float64)
    try:
        cv2.grabCut(image_bgr, gc_mask, None, bgdModel, fgdModel, iters, cv2.GC_INIT_WITH_MASK)
        out = np.where((gc_mask == cv2.GC_FGD) | (gc_mask == cv2.GC_PR_FGD), 255, 0).astype(np.uint8)
        out = cv2.medianBlur(out, 5)
        return out
    except Exception as exc:
        logger.warning(f"GrabCut refinement failed; using original mask. Reason: {exc}")
        return m
|
|
| |
| |
| |
def load_matany() -> Tuple[Optional[object], bool, Dict[str, Any]]:
    """
    Probe MatAnyone availability with T=1 squeeze patch for conv2d compatibility.
    Returns (None, available, meta); actual instantiation happens in MatAnyoneSession.
    """
    meta = {"matany_import_ok": False, "matany_init_ok": False}
    # Opt-out switch: any falsy spelling of ENABLE_MATANY disables the probe.
    enable_env = os.environ.get("ENABLE_MATANY", "1").strip().lower()
    if enable_env in {"0", "false", "off", "no"}:
        logger.info("MatAnyone disabled by ENABLE_MATANY=0.")
        meta["disabled"] = True
        return None, False, meta

    # Apply the compatibility patch BEFORE importing InferenceCore, otherwise
    # the unpatched conv2d path is already bound (see module changelog).
    try:
        from .matany_compat_patch import apply_matany_t1_squeeze_guard
        if apply_matany_t1_squeeze_guard():
            logger.info("[MatAnyCompat] T=1 squeeze guard applied")
            meta["patch_applied"] = True
        else:
            logger.warning("[MatAnyCompat] T=1 squeeze patch failed; conv2d errors may occur")
            meta["patch_applied"] = False
    except Exception as e:
        logger.warning(f"[MatAnyCompat] Patch import failed: {e}")
        meta["patch_applied"] = False

    # Import-only probe: InferenceCore is never instantiated here.
    try:
        from matanyone.inference.inference_core import InferenceCore
        meta["matany_import_ok"] = True

        try:
            version = importlib.metadata.version("matanyone")
            logger.info(f"[MATANY] MatAnyone version: {version}")
        except Exception:
            logger.info("[MATANY] MatAnyone version unknown")
        logger.debug(f"[MATANY] InferenceCore attributes: {dir(InferenceCore)}")
        device = _pick_device("MATANY_DEVICE")
        repo_id = os.environ.get("MATANY_REPO_ID", "PeiqingYang/MatAnyone")
        meta["matany_repo_id"] = repo_id
        meta["matany_device"] = device
        return None, True, meta
    except Exception as e:
        logger.warning(f"MatAnyone import failed: {e}")
        return None, False, meta
|
|
| |
| |
| |
def fallback_mask(first_frame_bgr: np.ndarray) -> np.ndarray:
    """Prefer MediaPipe; fallback to GrabCut. Returns uint8 mask 0/255."""
    h, w = first_frame_bgr.shape[:2]

    if _HAS_MEDIAPIPE:
        try:
            rgb = cv2.cvtColor(first_frame_bgr, cv2.COLOR_BGR2RGB)
            with mp.solutions.selfie_segmentation.SelfieSegmentation(model_selection=1) as segmenter:
                res = segmenter.process(rgb)
            # Threshold the soft probability map at 0.5 and denoise.
            m = (np.clip(res.segmentation_mask, 0, 1) > 0.5).astype(np.uint8) * 255
            return cv2.medianBlur(m, 5)
        except Exception as e:
            logger.warning(f"MediaPipe fallback failed: {e}")

    # GrabCut fallback seeded with a generous center rectangle.
    mask = np.zeros((h, w), np.uint8)
    rect = (int(0.1 * w), int(0.1 * h), int(0.8 * w), int(0.8 * h))
    bgdModel = np.zeros((1, 65), np.float64)
    fgdModel = np.zeros((1, 65), np.float64)
    try:
        cv2.grabCut(first_frame_bgr, mask, rect, bgdModel, fgdModel, 5, cv2.GC_INIT_WITH_RECT)
    except Exception as e:
        logger.warning(f"GrabCut failed: {e}")
        return np.zeros((h, w), dtype=np.uint8)
    return np.where((mask == cv2.GC_FGD) | (mask == cv2.GC_PR_FGD), 255, 0).astype(np.uint8)
|
|
def composite_video(fg_path: Union[str, Path],
                    alpha_path: Union[str, Path],
                    bg_image_path: Union[str, Path],
                    out_path: Union[str, Path],
                    fps: int,
                    size: Tuple[int, int]) -> bool:
    """Blend MatAnyone FG+ALPHA over background using pro compositor.

    Reads the foreground and alpha videos frame-by-frame, composites each
    frame over the (letterboxed) background image, and writes the result.
    When ffmpeg is available, writes to a temporary mp4v file first and
    re-encodes to H.264 for broad playback compatibility.

    Returns:
        True if at least one frame was written.
    """
    fg_cap = cv2.VideoCapture(str(fg_path))
    al_cap = cv2.VideoCapture(str(alpha_path))
    if not fg_cap.isOpened() or not al_cap.isOpened():
        return False

    w, h = size
    bg = cv2.imread(str(bg_image_path), cv2.IMREAD_COLOR)
    if bg is None:
        # Missing/unreadable background: fall back to a mid-gray fill.
        bg = np.full((h, w, 3), 127, dtype=np.uint8)
    bg_f = _resize_keep_ar(bg, (w, h))

    # Two-pass write (mp4v temp -> libx264) when ffmpeg exists; direct otherwise.
    if _probe_ffmpeg():
        tmp_out = Path(str(out_path) + ".tmp.mp4")
        writer = _video_writer(tmp_out, fps, (w, h))
        post_h264 = True
    else:
        writer = _video_writer(Path(out_path), fps, (w, h))
        post_h264 = False

    ok_any = False
    try:
        while True:
            ok_fg, fg = fg_cap.read()
            ok_al, al = al_cap.read()
            if not ok_fg or not ok_al:
                break
            fg = cv2.resize(fg, (w, h), interpolation=cv2.INTER_CUBIC)
            al_gray = cv2.cvtColor(cv2.resize(al, (w, h)), cv2.COLOR_BGR2GRAY)

            # Compositor works in RGB; OpenCV I/O is BGR.
            comp = _composite_frame_pro(
                cv2.cvtColor(fg, cv2.COLOR_BGR2RGB),
                al_gray,
                cv2.cvtColor(bg_f, cv2.COLOR_BGR2RGB)
            )
            writer.write(cv2.cvtColor(comp, cv2.COLOR_RGB2BGR))
            ok_any = True
    finally:
        fg_cap.release()
        al_cap.release()
        writer.release()

    if post_h264 and ok_any:
        try:
            cmd = [
                _ffmpeg_bin(), "-y",
                "-i", str(tmp_out),
                "-c:v", "libx264", "-pix_fmt", "yuv420p", "-movflags", "+faststart",
                str(out_path)
            ]
            subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            tmp_out.unlink(missing_ok=True)
        except Exception as e:
            # Re-encode failed: promote the mp4v temp file to the final output.
            logger.warning(f"ffmpeg finalize failed: {e}")
            Path(out_path).unlink(missing_ok=True)
            tmp_out.replace(out_path)

    return ok_any
|
|
def fallback_composite(video_path: Union[str, Path],
                       mask_path: Union[str, Path],
                       bg_image_path: Union[str, Path],
                       out_path: Union[str, Path]) -> bool:
    """Static-mask compositing using pro compositor.

    Applies a single fixed mask (e.g. from fallback_mask) to every frame of
    the source video, compositing over the background image. Mirrors
    composite_video's two-pass H.264 finalize when ffmpeg is present.

    Returns:
        True if at least one frame was written.
    """
    mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
    cap = cv2.VideoCapture(str(video_path))
    if mask is None or not cap.isOpened():
        return False

    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) or 0)
    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) or 0)
    fps = int(round(cap.get(cv2.CAP_PROP_FPS) or 25))

    bg = cv2.imread(str(bg_image_path), cv2.IMREAD_COLOR)
    if bg is None:
        # Missing/unreadable background: fall back to a mid-gray fill.
        bg = np.full((h, w, 3), 127, dtype=np.uint8)

    # NEAREST keeps the binary mask's hard edges; _refine_alpha feathers later.
    mask_resized = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
    bg_f = _resize_keep_ar(bg, (w, h))

    # Two-pass write (mp4v temp -> libx264) when ffmpeg exists; direct otherwise.
    if _probe_ffmpeg():
        tmp_out = Path(str(out_path) + ".tmp.mp4")
        writer = _video_writer(tmp_out, fps, (w, h))
        use_post_ffmpeg = True
    else:
        writer = _video_writer(Path(out_path), fps, (w, h))
        use_post_ffmpeg = False

    ok_any = False
    try:
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            comp = _composite_frame_pro(
                cv2.cvtColor(frame, cv2.COLOR_BGR2RGB),
                mask_resized,
                cv2.cvtColor(bg_f, cv2.COLOR_BGR2RGB)
            )
            writer.write(cv2.cvtColor(comp, cv2.COLOR_RGB2BGR))
            ok_any = True
    finally:
        cap.release()
        writer.release()

    if use_post_ffmpeg and ok_any:
        try:
            cmd = [
                _ffmpeg_bin(), "-y",
                "-i", str(tmp_out),
                "-c:v", "libx264", "-pix_fmt", "yuv420p", "-movflags", "+faststart",
                str(out_path)
            ]
            subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            tmp_out.unlink(missing_ok=True)
        except Exception as e:
            # Re-encode failed: promote the mp4v temp file to the final output.
            logger.warning(f"ffmpeg H.264 finalize failed: {e}")
            Path(out_path).unlink(missing_ok=True)
            tmp_out.replace(out_path)

    return ok_any
|
|
| |
| |
| |
| def _checkerboard_bg(w: int, h: int, tile: int = 32) -> np.ndarray: |
| y, x = np.mgrid[0:h, 0:w] |
| c = ((x // tile) + (y // tile)) % 2 |
| a = np.where(c == 0, 200, 150).astype(np.uint8) |
| return np.stack([a, a, a], axis=-1) |
|
|
def _build_stage_a_rgba_vp9_from_fg_alpha(
    fg_path: Union[str, Path],
    alpha_path: Union[str, Path],
    out_webm: Union[str, Path],
    fps: int,
    size: Tuple[int, int],
    src_audio: Optional[Union[str, Path]] = None,
) -> bool:
    """Merge FG + alpha videos into a transparent VP9 .webm (optional audio)."""
    if not _probe_ffmpeg():
        return False
    w, h = size
    # Scale both streams, force the alpha to grayscale, then alphamerge.
    filter_graph = (
        f"[1:v]format=gray,scale={w}:{h},fps={fps}[al];"
        f"[0:v]scale={w}:{h},fps={fps}[fg];"
        f"[fg][al]alphamerge[outv]"
    )
    try:
        cmd = [_ffmpeg_bin(), "-y", "-i", str(fg_path), "-i", str(alpha_path)]
        if src_audio:
            cmd += ["-i", str(src_audio)]
        cmd += ["-filter_complex", filter_graph, "-map", "[outv]"]
        if src_audio:
            cmd += ["-map", "2:a:0?", "-c:a", "libopus", "-b:a", "128k"]
        cmd += [
            "-c:v", "libvpx-vp9", "-pix_fmt", "yuva420p",
            "-crf", os.environ.get("STAGEA_VP9_CRF", "28"),
            "-b:v", "0", "-row-mt", "1", "-shortest", str(out_webm),
        ]
        subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        return True
    except Exception as e:
        logger.warning(f"Stage-A VP9(alpha) build failed: {e}")
        return False
|
|
def _build_stage_a_rgba_vp9_from_mask(
    video_path: Union[str, Path],
    mask_png: Union[str, Path],
    out_webm: Union[str, Path],
    fps: int,
    size: Tuple[int, int],
) -> bool:
    """Merge a video with one looping mask PNG into a transparent VP9 .webm."""
    if not _probe_ffmpeg():
        return False
    w, h = size
    # Loop the static mask as a second input, gray it, and alphamerge.
    filter_graph = (
        f"[1:v]format=gray,scale={w}:{h},fps={fps}[al];"
        f"[0:v]scale={w}:{h},fps={fps}[fg];"
        f"[fg][al]alphamerge[outv]"
    )
    cmd = [
        _ffmpeg_bin(), "-y",
        "-i", str(video_path),
        "-loop", "1", "-i", str(mask_png),
        "-filter_complex", filter_graph,
        "-map", "[outv]",
        "-c:v", "libvpx-vp9", "-pix_fmt", "yuva420p",
        "-crf", os.environ.get("STAGEA_VP9_CRF", "28"),
        "-b:v", "0", "-row-mt", "1", "-shortest", str(out_webm),
    ]
    try:
        subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        return True
    except Exception as e:
        logger.warning(f"Stage-A VP9(alpha) (mask) build failed: {e}")
        return False
|
|
def _build_stage_a_checkerboard_from_fg_alpha(
    fg_path: Union[str, Path],
    alpha_path: Union[str, Path],
    out_mp4: Union[str, Path],
    fps: int,
    size: Tuple[int, int],
) -> bool:
    """Preview: composite FG over a checkerboard using the alpha video; write mp4."""
    fg_cap = cv2.VideoCapture(str(fg_path))
    al_cap = cv2.VideoCapture(str(alpha_path))
    if not (fg_cap.isOpened() and al_cap.isOpened()):
        return False
    w, h = size
    writer = _video_writer(Path(out_mp4), fps, (w, h))
    board = _checkerboard_bg(w, h)
    wrote_frame = False
    try:
        while True:
            got_fg, fg = fg_cap.read()
            got_al, al = al_cap.read()
            if not (got_fg and got_al):
                break
            fg_rgb = cv2.cvtColor(cv2.resize(fg, (w, h)), cv2.COLOR_BGR2RGB)
            alpha = cv2.cvtColor(cv2.resize(al, (w, h)), cv2.COLOR_BGR2GRAY)
            comp = _composite_frame_pro(fg_rgb, alpha, board)
            writer.write(cv2.cvtColor(comp, cv2.COLOR_RGB2BGR))
            wrote_frame = True
    finally:
        fg_cap.release()
        al_cap.release()
        writer.release()
    return wrote_frame
|
|
def _build_stage_a_checkerboard_from_mask(
    video_path: Union[str, Path],
    mask_png: Union[str, Path],
    out_mp4: Union[str, Path],
    fps: int,
    size: Tuple[int, int],
) -> bool:
    """Preview: composite every frame over a checkerboard using a static mask."""
    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        return False
    w, h = size
    seed = cv2.imread(str(mask_png), cv2.IMREAD_GRAYSCALE)
    if seed is None:
        return False
    # NEAREST preserves the mask's hard edges at preview resolution.
    seed = cv2.resize(seed, (w, h), interpolation=cv2.INTER_NEAREST)
    writer = _video_writer(Path(out_mp4), fps, (w, h))
    board = _checkerboard_bg(w, h)
    wrote_frame = False
    try:
        while True:
            got, frame = cap.read()
            if not got:
                break
            frame_rgb = cv2.cvtColor(cv2.resize(frame, (w, h)), cv2.COLOR_BGR2RGB)
            comp = _composite_frame_pro(frame_rgb, seed, board)
            writer.write(cv2.cvtColor(comp, cv2.COLOR_RGB2BGR))
            wrote_frame = True
    finally:
        cap.release()
        writer.release()
    return wrote_frame
|
|
| |
| |
| |
def run_matany(
    video_path: Union[str, Path],
    mask_path: Optional[Union[str, Path]],
    out_dir: Union[str, Path],
    device: Optional[str] = None,
    progress_callback: Optional[Callable[[float, str], None]] = None,
) -> Tuple[Path, Path]:
    """
    Run MatAnyone streaming matting via our shape-guarded adapter.
    Returns (alpha_mp4_path, fg_mp4_path).
    Raises MatAnyError on failure.

    Args:
        video_path: Input video to matte.
        mask_path: Optional seed mask; None lets the session handle seeding.
        out_dir: Directory for the produced alpha/fg outputs.
        device: Torch device override; None lets the session auto-select.
        progress_callback: Optional (fraction, message) progress hook.
    """
    # Imported lazily so this module stays importable without MatAnyone installed.
    from .matanyone_loader import MatAnyoneSession, MatAnyError

    session = MatAnyoneSession(device=device, precision="auto")
    alpha_p, fg_p = session.process_stream(
        video_path=Path(video_path),
        seed_mask_path=Path(mask_path) if mask_path else None,
        out_dir=Path(out_dir),
        progress_cb=progress_callback,
    )
    return alpha_p, fg_p