| |
| |
| |
| |
| """ |
| Enhanced Video Background Replacement (SAM2 + MatAnyone + AI Backgrounds) |
| - Strict tensor shapes for MatAnyone (image: 3xHxW, first-frame prob mask: 1xHxW) |
| - First frame uses PROB path (no idx_mask / objects) to avoid assertion |
| - Memory management & cleanup |
| - SDXL / Playground / OpenAI backgrounds |
| - Gradio UI with "CHAPTER" dividers |
| - FIXED: Enhanced positioning with debug logging and coordinate precision |
| """ |
|
|
| |
| |
| |
| import os |
| import sys |
| import gc |
| import cv2 |
| import psutil |
| import time |
| import json |
| import base64 |
| import random |
| import shutil |
| import logging |
| import traceback |
| import subprocess |
| import tempfile |
| import threading |
| from dataclasses import dataclass |
| from contextlib import contextmanager |
| from pathlib import Path |
| from typing import Optional, Tuple, List |
|
|
| import numpy as np |
| from PIL import Image |
| import gradio as gr |
| from moviepy.editor import VideoFileClip |
|
|
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Shared application logger used throughout this module.
logger = logging.getLogger("bgx")
|
|
| |
# Runtime tuning knobs. ``setdefault`` keeps any value already exported by the
# operator; OMP_NUM_THREADS is the one variable that is always overwritten.
# NOTE(review): numpy and cv2 are imported above this point, so the thread-count
# variables may be read too late to affect their already-initialised threading
# runtimes — confirm whether they actually take effect.
_ENV_DEFAULTS = {
    "CUDA_MODULE_LOADING": "LAZY",
    "TORCH_CUDNN_V8_API_ENABLED": "1",
    "PYTHONUNBUFFERED": "1",
    "MKL_NUM_THREADS": "4",
    "BFX_QUALITY": "max",
    "PYTORCH_CUDA_ALLOC_CONF": "max_split_size_mb:128,roundup_power2_divisions:16",
    "HYDRA_FULL_ERROR": "1",
}
for _env_key, _env_value in _ENV_DEFAULTS.items():
    os.environ.setdefault(_env_key, _env_value)
os.environ["OMP_NUM_THREADS"] = "2"
|
|
| |
# Directory layout, anchored next to this source file.
BASE_DIR = Path(__file__).resolve().parent
CHECKPOINTS = BASE_DIR.joinpath("checkpoints")  # cloned repos and model weights
TEMP_DIR = BASE_DIR.joinpath("temp")  # scratch files, usually registered for cleanup
OUT_DIR = BASE_DIR.joinpath("outputs")
BACKGROUND_DIR = OUT_DIR.joinpath("backgrounds")

# Create everything eagerly so later code can write without existence checks.
_RUNTIME_DIRS = (CHECKPOINTS, TEMP_DIR, OUT_DIR, BACKGROUND_DIR)
for p in _RUNTIME_DIRS:
    p.mkdir(exist_ok=True, parents=True)
|
|
| |
def _tune_torch_backends(torch_mod, cuda_ok: bool) -> None:
    """Best-effort backend tuning (TF32 matmul, cuDNN autotune, VRAM cap).

    Any failure is deliberately ignored; the flags are performance hints only.
    """
    try:
        if torch_mod.backends.cuda.is_built():
            torch_mod.backends.cuda.matmul.allow_tf32 = True
        if hasattr(torch_mod.backends, "cudnn"):
            torch_mod.backends.cudnn.benchmark = True
            torch_mod.backends.cudnn.deterministic = False
        if cuda_ok:
            # Cap this process at 80% of the device's memory.
            torch_mod.cuda.set_per_process_memory_fraction(0.8)
    except Exception:
        pass


# Optional PyTorch: everything GPU-related in this module keys off these flags.
try:
    import torch

    TORCH_AVAILABLE = True
    CUDA_AVAILABLE = torch.cuda.is_available()
    DEVICE = "cuda" if CUDA_AVAILABLE else "cpu"
    _tune_torch_backends(torch, CUDA_AVAILABLE)
except Exception:
    TORCH_AVAILABLE = False
    CUDA_AVAILABLE = False
    DEVICE = "cpu"
|
|
| |
| |
| |
# Two-stop vertical gradient presets: name -> (top colour, bottom colour).
# NOTE(review): the channel order looks inconsistent across entries. Read as
# BGR, "Blue Fade" and "Ocean" are blue (as RGB, "Ocean" would be orange);
# "Sunset", "Sunset Pink" and "Cool Blue" only match their names when read as
# RGB. Confirm which order the consumer (not visible here) expects.
GRADIENT_PRESETS = {
    "Blue Fade": ((128, 64, 0), (255, 128, 0)),
    "Sunset": ((255, 128, 0), (255, 0, 128)),
    "Green Field": ((64, 128, 64), (160, 255, 160)),
    "Slate": ((40, 40, 48), (96, 96, 112)),
    "Ocean": ((255, 140, 0), (255, 215, 0)),
    "Forest": ((34, 139, 34), (144, 238, 144)),
    "Sunset Pink": ((255, 182, 193), (255, 105, 180)),
    "Cool Blue": ((173, 216, 230), (0, 191, 255)),
}


# Prompt choices offered for AI background generation. The first entry is a
# label rather than a prompt — presumably the UI treats it as "use the
# free-text prompt instead"; verify against the UI code.
AI_PROMPT_SUGGESTIONS = [
    "Custom (write your own)",
    "modern minimalist office with soft lighting, clean desk, blurred background",
    "elegant conference room with large windows and city view",
    "contemporary workspace with plants and natural light",
    "luxury hotel lobby with marble floors and warm ambient lighting",
    "professional studio with clean white background and soft lighting",
    "modern corporate meeting room with glass walls and city skyline",
    "sophisticated home office with bookshelf and warm wood tones",
    "sleek coworking space with industrial design elements",
    "abstract geometric patterns in blue and gold, modern art style",
    "soft watercolor texture with pastel colors, dreamy atmosphere",
]
|
|
| def _make_vertical_gradient(width: int, height: int, c1, c2) -> np.ndarray: |
| width = max(1, int(width)) |
| height = max(1, int(height)) |
| top = np.array(c1, dtype=np.float32) |
| bot = np.array(c2, dtype=np.float32) |
| rows = np.linspace(top, bot, num=height, dtype=np.float32) |
| grad = np.repeat(rows[:, None, :], repeats=width, axis=1) |
| return np.clip(grad, 0, 255).astype(np.uint8) |
|
|
def run_ffmpeg(args: list, fail_ok=False) -> bool:
    """Run ffmpeg quietly with the given arguments.

    Args:
        args: arguments appended after the fixed
            ``ffmpeg -y -hide_banner -loglevel error`` prefix (any sequence).
        fail_ok: when True, failures are expected and are not logged.

    Returns:
        True if ffmpeg exited with status 0, False otherwise — including when
        the ffmpeg binary is missing or cannot be launched.
    """
    cmd = ["ffmpeg", "-y", "-hide_banner", "-loglevel", "error", *args]
    try:
        subprocess.run(cmd, check=True, capture_output=True)
    except subprocess.CalledProcessError as e:
        if not fail_ok:
            # stderr is captured above, so include it — otherwise the actual
            # reason for the failure never reaches the log.
            stderr = (e.stderr or b"").decode("utf-8", errors="replace").strip()
            logger.error(f"ffmpeg failed (exit {e.returncode}): {stderr or e}")
        return False
    except Exception as e:
        # Missing binary (FileNotFoundError), permission problems, etc.
        if not fail_ok:
            logger.error(f"ffmpeg failed: {e}")
        return False
    return True
|
|
def write_video_h264(clip, path: str, fps: Optional[int] = None, crf: int = 18, preset: str = "medium"):
    """Encode a moviepy clip to an audio-less H.264 MP4.

    Args:
        clip: object exposing ``write_videofile`` (a moviepy clip) and,
            optionally, an ``fps`` attribute.
        path: destination file.
        fps: output frame rate. When falsy, the clip's own fps rounded to an
            int (minimum 1) is used, or 24 if the clip reports none.
        crf: x264 constant rate factor (lower means higher quality).
        preset: x264 speed/size preset.
    """
    if not fps:
        source_fps = getattr(clip, "fps", None) or 24
        fps = max(1, int(round(source_fps)))
    # yuv420p + high profile for broad player compatibility; faststart moves
    # the moov atom to the front for progressive playback.
    extra_params = [
        "-crf", str(crf),
        "-pix_fmt", "yuv420p",
        "-profile:v", "high",
        "-movflags", "+faststart",
    ]
    clip.write_videofile(
        path,
        fps=fps,
        codec="libx264",
        preset=preset,
        audio=False,
        ffmpeg_params=extra_params,
        verbose=False,
        logger=None,
    )
|
|
def download_file(url: str, dest: Path, name: str) -> bool:
    """Download ``url`` to ``dest`` unless ``dest`` already exists.

    The response is streamed into a sibling ``<dest>.part`` file. That file is
    renamed onto ``dest`` only after the transfer completed. An interrupted
    download (network error, Ctrl-C, killed process) therefore never leaves a
    truncated ``dest`` that a later call would accept as "already exists".

    Args:
        url: source URL, fetched with ``requests``.
        dest: final file path (``str`` is accepted too); its parent directory
            is created if missing.
        name: human-readable label used in log messages.

    Returns:
        True if ``dest`` exists on return, False if the download failed.
    """
    dest = Path(dest)
    if dest.exists():
        logger.info(f"{name} already exists")
        return True
    part = dest.with_name(dest.name + ".part")
    try:
        import requests

        logger.info(f"Downloading {name} ...")
        dest.parent.mkdir(parents=True, exist_ok=True)
        with requests.get(url, stream=True, timeout=300) as r:
            r.raise_for_status()
            with open(part, "wb") as f:
                for chunk in r.iter_content(chunk_size=1024 * 1024):
                    if chunk:
                        f.write(chunk)
        # Atomic on the same filesystem: dest is either absent or complete.
        os.replace(part, dest)
        return True
    except Exception as e:
        logger.error(f"Failed to download {name}: {e}")
        return False
    finally:
        # After a successful os.replace the .part file no longer exists;
        # otherwise discard whatever partial data was written.
        try:
            part.unlink()
        except FileNotFoundError:
            pass
        except OSError as e:
            logger.warning(f"Could not remove partial download {part}: {e}")
|
|
def ensure_repo(repo_name: str, git_url: str) -> Optional[Path]:
    """Shallow-clone ``git_url`` once into ``CHECKPOINTS/<repo_name>_repo`` and
    put that directory at the front of ``sys.path``.

    If the clone fails (including a timeout, after which the killed git
    process may leave a half-written checkout), the partial directory is
    removed. Without that cleanup, the next call would see it as an existing
    repository and never retry.

    Returns:
        The repository path, or None if cloning failed.
    """
    repo_path = CHECKPOINTS / f"{repo_name}_repo"
    if not repo_path.exists():
        try:
            subprocess.run(
                ["git", "clone", "--depth", "1", git_url, str(repo_path)],
                check=True, timeout=300, capture_output=True,
            )
            logger.info(f"{repo_name} cloned")
        except Exception as e:
            detail = ""
            if isinstance(e, subprocess.CalledProcessError) and e.stderr:
                detail = " | " + e.stderr.decode("utf-8", errors="replace").strip()
            logger.error(f"Failed to clone {repo_name}: {e}{detail}")
            shutil.rmtree(repo_path, ignore_errors=True)
            return None
    repo_str = str(repo_path)
    if repo_str not in sys.path:
        sys.path.insert(0, repo_str)
    return repo_path
|
|
| def _reset_hydra(): |
| try: |
| from hydra.core.global_hydra import GlobalHydra |
| if GlobalHydra().is_initialized(): |
| GlobalHydra.instance().clear() |
| except Exception: |
| pass |
|
|
| |
| |
| |
@dataclass
class MemoryStats:
    """Point-in-time resource snapshot built by ``MemoryManager.get_memory_stats``.

    Field order matters: that method constructs instances positionally.
    """
    cpu_percent: float  # psutil.cpu_percent() sample (system-wide), percent
    cpu_memory_mb: float  # resident set size of this process, MiB
    gpu_memory_mb: float = 0.0  # torch.cuda.memory_allocated(), MiB; 0 without CUDA
    gpu_memory_reserved_mb: float = 0.0  # torch.cuda.memory_reserved(), MiB
    temp_files_count: int = 0  # registered temp paths that currently exist
    temp_files_size_mb: float = 0.0  # sum of os.path.getsize over those paths, MiB
|
|
class MemoryManager:
    """Tracks temporary paths and performs GC / CUDA-cache cleanup.

    ``mem_context`` brackets an operation with memory logging and a GC/CUDA
    cache flush. It does NOT delete registered temp files unless asked to
    (``cleanup_files=True``). Deleting them on every context exit destroyed
    results created or still needed inside it:
    * a generated background registered and then returned;
    * a trimmed input video deleted when a nested mask step finished;
    * earlier chunk outputs before the merge.
    It could also delete files belonging to another concurrent request.
    Call ``cleanup_temp_files()`` or ``aggressive_cleanup()`` explicitly once
    the files are no longer needed.
    """

    def __init__(self):
        self.temp_files: List[str] = []
        self.cleanup_lock = threading.Lock()
        self.torch_available = TORCH_AVAILABLE
        self.cuda_available = CUDA_AVAILABLE

    def _snapshot_temp_files(self) -> List[str]:
        """Copy of the registered paths, taken under the lock."""
        with self.cleanup_lock:
            return list(self.temp_files)

    def get_memory_stats(self) -> MemoryStats:
        """Sample CPU, process RSS, CUDA memory and registered temp-file usage.

        Blocks for about 0.1 s (the ``psutil.cpu_percent`` sampling interval).
        """
        process = psutil.Process()
        cpu_percent = psutil.cpu_percent(interval=0.1)
        cpu_memory_mb = process.memory_info().rss / (1024 * 1024)
        gpu_memory_mb = 0.0
        gpu_memory_reserved_mb = 0.0
        if self.torch_available and self.cuda_available:
            try:
                import torch
                gpu_memory_mb = torch.cuda.memory_allocated() / (1024 * 1024)
                gpu_memory_reserved_mb = torch.cuda.memory_reserved() / (1024 * 1024)
            except Exception:
                pass

        temp_count, temp_size_mb = 0, 0.0
        # Iterate a snapshot so a concurrent register/cleanup cannot mutate the
        # list while we walk it.
        for tf in self._snapshot_temp_files():
            if os.path.exists(tf):
                temp_count += 1
                try:
                    # NOTE: for directories this is the entry size, not the contents.
                    temp_size_mb += os.path.getsize(tf) / (1024 * 1024)
                except OSError:
                    pass
        return MemoryStats(cpu_percent, cpu_memory_mb, gpu_memory_mb, gpu_memory_reserved_mb, temp_count, temp_size_mb)

    def register_temp_file(self, path: str):
        """Remember ``path`` (file or directory) for a later ``cleanup_temp_files``."""
        with self.cleanup_lock:
            if path not in self.temp_files:
                self.temp_files.append(path)

    def cleanup_temp_files(self):
        """Delete every registered path and forget it (even if deletion fails)."""
        with self.cleanup_lock:
            pending = list(self.temp_files)
            self.temp_files.clear()
        cleaned = 0
        for tf in pending:
            try:
                if os.path.isdir(tf):
                    shutil.rmtree(tf, ignore_errors=True)
                    cleaned += 1
                elif os.path.exists(tf):
                    os.unlink(tf)
                    cleaned += 1
            except Exception as e:
                logger.warning(f"Failed to cleanup {tf}: {e}")
        if cleaned:
            logger.info(f"Cleaned {cleaned} temp paths")

    def aggressive_cleanup(self, cleanup_files: bool = True):
        """Run GC and release cached CUDA memory.

        Args:
            cleanup_files: also delete all registered temp paths (the original
                behaviour of this method, kept as the default).
        """
        logger.info("Aggressive cleanup...")
        gc.collect()
        if self.torch_available and self.cuda_available:
            try:
                import torch
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
            except Exception:
                pass
        if cleanup_files:
            self.cleanup_temp_files()
        gc.collect()

    @contextmanager
    def mem_context(self, name="op", cleanup_files: bool = False):
        """Log memory before/after an operation and flush GC/CUDA cache on exit.

        Args:
            name: label used in the log lines.
            cleanup_files: opt in to deleting all registered temp files on exit.
                Off by default because callers register files they return or
                still need after the context closes.
        """
        stats = self.get_memory_stats()
        logger.info(f"Start {name} | CPU {stats.cpu_memory_mb:.1f}MB, GPU {stats.gpu_memory_mb:.1f}MB")
        try:
            yield self
        finally:
            self.aggressive_cleanup(cleanup_files=cleanup_files)
            stats = self.get_memory_stats()
            logger.info(f"End {name} | CPU {stats.cpu_memory_mb:.1f}MB, GPU {stats.gpu_memory_mb:.1f}MB")


memory_manager = MemoryManager()
|
|
| |
| |
| |
class SystemState:
    """Process-wide capability flags plus readiness of the two matting models."""

    def __init__(self):
        # Capability flags copied from the module-level torch probe.
        self.torch_available = TORCH_AVAILABLE
        self.cuda_available = CUDA_AVAILABLE
        self.device = DEVICE
        # *_ready is set by the handlers after a successful initialize();
        # *_error holds the latest failure message (None until one occurs).
        self.sam2_ready = False
        self.matanyone_ready = False
        self.sam2_error = None
        self.matanyone_error = None

    @staticmethod
    def _flag(ok) -> str:
        """Two-state badge: available / unavailable."""
        return "✅" if ok else "❌"

    @staticmethod
    def _model_flag(ready, error) -> str:
        """Three-state badge: ready, failed, or still pending."""
        if ready:
            return "✅"
        return "❌" if error else "⏳"

    def status_text(self) -> str:
        """Render a multi-line, human-readable status and memory report."""
        stats = memory_manager.get_memory_stats()
        lines = [
            "=== SYSTEM STATUS ===",
            f"PyTorch: {self._flag(self.torch_available)}",
            f"CUDA: {self._flag(self.cuda_available)}",
            f"Device: {self.device}",
            f"SAM2: {self._model_flag(self.sam2_ready, self.sam2_error)}",
            f"MatAnyone: {self._model_flag(self.matanyone_ready, self.matanyone_error)}",
            "",
            "=== MEMORY ===",
            f"CPU: {stats.cpu_percent:.1f}% ({stats.cpu_memory_mb:.1f} MB)",
            f"GPU: {stats.gpu_memory_mb:.1f} MB (Reserved {stats.gpu_memory_reserved_mb:.1f} MB)",
            f"Temp: {stats.temp_files_count} files ({stats.temp_files_size_mb:.1f} MB)",
        ]
        return "\n".join(lines) + "\n"


state = SystemState()
|
|
| |
| |
| |
class SAM2Handler:
    """Loads SAM2 (hiera-large, CUDA only) and derives a first-frame subject mask."""

    def __init__(self):
        self.predictor = None
        self.initialized = False

    def initialize(self) -> bool:
        """Clone SAM2, download the checkpoint, build the predictor, smoke-test it.

        Idempotent: returns True immediately once initialised. On failure the
        reason is stored in ``state.sam2_error``, logged, and False is returned.
        """
        if self.initialized:
            return True
        if not (TORCH_AVAILABLE and CUDA_AVAILABLE):
            state.sam2_error = "SAM2 requires CUDA"
            return False

        with memory_manager.mem_context("SAM2 init"):
            try:
                _reset_hydra()
                repo_path = ensure_repo("sam2", "https://github.com/facebookresearch/segment-anything-2.git")
                if not repo_path:
                    state.sam2_error = "Clone failed"
                    return False

                ckpt = CHECKPOINTS / "sam2.1_hiera_large.pt"
                url = "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt"
                if not download_file(url, ckpt, "SAM2 Large"):
                    state.sam2_error = "SAM2 ckpt download failed"
                    return False

                from hydra.core.global_hydra import GlobalHydra
                from hydra import initialize_config_dir
                from sam2.build_sam import build_sam2
                from sam2.sam2_image_predictor import SAM2ImagePredictor

                # Clear any Hydra state (presumably set up while importing the
                # sam2 package above — confirm) and point Hydra at the cloned
                # repo's config directory so the relative config name resolves.
                config_dir = (repo_path / "sam2" / "configs").as_posix()
                if GlobalHydra().is_initialized():
                    GlobalHydra.instance().clear()
                # Not used as a context manager, so presumably the config stays
                # initialised for build_sam2 below — confirm for your Hydra version.
                initialize_config_dir(config_dir=config_dir, version_base=None)

                model = build_sam2("sam2.1/sam2.1_hiera_l.yaml", str(ckpt), device="cuda")
                self.predictor = SAM2ImagePredictor(model)

                # Smoke test: one centre click on a blank 64x64 image.
                test = np.zeros((64, 64, 3), dtype=np.uint8)
                self.predictor.set_image(test)
                masks, _scores, _ = self.predictor.predict(
                    point_coords=np.array([[32, 32]]),
                    point_labels=np.ones(1, dtype=np.int64),
                    multimask_output=True,
                )
                ok = masks is not None and len(masks) > 0
                self.initialized = ok
                state.sam2_ready = ok
                if not ok:
                    state.sam2_error = "SAM2 verify failed"
                return ok

            except Exception as e:
                state.sam2_error = f"SAM2 init error: {e}"
                # Previously only stored in state; log it so the traceback survives.
                logger.exception("SAM2 initialization failed")
                self.predictor = None
                self.initialized = False
                return False

    def create_mask(self, image_rgb: np.ndarray) -> Optional[np.ndarray]:
        """Segment the main subject of an RGB frame.

        Tries three click layouts (centre; upper third; two vertical points at
        1/3 and 2/3 height) and keeps the highest-scoring mask. This assumes
        the subject is roughly centred horizontally.

        Returns:
            uint8 HxW mask (0..255) after close/open morphology and a light
            Gaussian blur, or None when not initialised, no mask was produced,
            or SAM2 raised.
        """
        if not self.initialized:
            return None
        with memory_manager.mem_context("SAM2 mask"):
            try:
                self.predictor.set_image(image_rgb)
                h, w = image_rgb.shape[:2]
                strategies = [
                    np.array([[w // 2, h // 2]]),
                    np.array([[w // 2, h // 3]]),
                    np.array([[w // 2, h // 3], [w // 2, (2 * h) // 3]]),
                ]
                best, best_score = None, -1.0
                for pc in strategies:
                    masks, scores, _ = self.predictor.predict(
                        point_coords=pc,
                        point_labels=np.ones(len(pc), dtype=np.int64),
                        multimask_output=True,
                    )
                    if masks is not None and len(masks) > 0:
                        i = int(np.argmax(scores))
                        sc = float(scores[i])
                        if sc > best_score:
                            best_score, best = sc, masks[i]

                if best is None:
                    return None

                mask_u8 = (best * 255).astype(np.uint8)
                # Close small holes, remove specks, then soften the edge slightly.
                k = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
                mask_clean = cv2.morphologyEx(mask_u8, cv2.MORPH_CLOSE, k)
                mask_clean = cv2.morphologyEx(mask_clean, cv2.MORPH_OPEN, k)
                mask_clean = cv2.GaussianBlur(mask_clean, (3, 3), 1.0)
                return mask_clean
            except Exception as e:
                logger.error(f"SAM2 mask error: {e}")
                return None
|
|
| |
| |
| |
class MatAnyoneHandler:
    """Video matting via the local ``matanyone_fixed`` package.

    Frame 0 is seeded with a soft probability mask (e.g. the SAM2 mask).
    Later frames are propagated by ``InferenceCore.step(image)`` alone. The
    per-frame alphas are written as PNGs and encoded into ``alpha.mp4``.
    """

    def __init__(self):
        self.core = None
        self.initialized = False

    # ---- array / tensor conversion helpers --------------------------------

    def _to_chw_float(self, img01: np.ndarray) -> "torch.Tensor":
        """HxWx3 image in [0, 1] -> float tensor (3, H, W) on DEVICE (unbatched).

        Raises:
            ValueError: if ``img01`` is not HxWx3. This is an explicit check
                rather than ``assert``, which ``python -O`` strips.
        """
        if img01.ndim != 3 or img01.shape[2] != 3:
            raise ValueError(f"Expected HxWx3, got {img01.shape}")
        t = torch.from_numpy(img01.transpose(2, 0, 1)).contiguous().float()
        return t.to(DEVICE, non_blocking=CUDA_AVAILABLE)

    @staticmethod
    def _mask_to_unit_float(mask_u8: np.ndarray, w: int, h: int) -> np.ndarray:
        """Nearest-neighbour resize a uint8 mask to (h, w) if needed; scale to float32 [0, 1]."""
        if mask_u8.shape[0] != h or mask_u8.shape[1] != w:
            mask_u8 = cv2.resize(mask_u8, (w, h), interpolation=cv2.INTER_NEAREST)
        return mask_u8.astype(np.float32) / 255.0

    def _prob_hw_from_mask_u8(self, mask_u8: np.ndarray, w: int, h: int) -> "torch.Tensor":
        """uint8 HxW mask -> float (H, W) tensor in [0, 1] on DEVICE (no batch/channel)."""
        t = torch.from_numpy(self._mask_to_unit_float(mask_u8, w, h)).contiguous().float()
        return t.to(DEVICE, non_blocking=CUDA_AVAILABLE)

    def _prob_1hw_from_mask_u8(self, mask_u8: np.ndarray, w: int, h: int) -> "torch.Tensor":
        """uint8 HxW mask -> float (1, H, W) tensor in [0, 1] on DEVICE (unbatched)."""
        prob = self._mask_to_unit_float(mask_u8, w, h)[None, ...]
        t = torch.from_numpy(prob).contiguous().float()
        return t.to(DEVICE, non_blocking=CUDA_AVAILABLE)

    def _frame_to_chw(self, frame_bgr: np.ndarray) -> "torch.Tensor":
        """OpenCV BGR uint8 frame -> RGB float (3, H, W) tensor in [0, 1] on DEVICE."""
        rgb01 = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
        return self._to_chw_float(rgb01)

    def _alpha_to_u8_hw(self, alpha_like) -> np.ndarray:
        """Convert a step() output to a uint8 HxW alpha (0..255).

        Accepts a torch tensor, array-like, or a list/tuple. For a sequence of
        length > 1, element [1] is used. After squeezing, a 3-D result keeps
        channel 0.
        NOTE(review): if the core returns per-class probabilities with
        background first, channel 0 would be the background and the alpha
        inverted — confirm against matanyone_fixed's output layout.

        Raises:
            ValueError: if the result cannot be reduced to 2-D.
        """
        if isinstance(alpha_like, (list, tuple)) and len(alpha_like) > 1:
            alpha_like = alpha_like[1]

        if isinstance(alpha_like, torch.Tensor):
            t = alpha_like.detach()
            if t.is_cuda:
                t = t.cpu()
            a = t.float().clamp(0, 1).numpy()
        else:
            a = np.asarray(alpha_like, dtype=np.float32)
            a = np.clip(a, 0, 1)

        a = np.squeeze(a)
        if a.ndim == 3 and a.shape[0] >= 1:
            a = a[0]
        if a.ndim != 2:
            raise ValueError(f"Alpha must be HxW; got {a.shape}")

        return np.clip(a * 255.0, 0, 255).astype(np.uint8)

    # ---- model lifecycle ----------------------------------------------------

    def initialize(self) -> bool:
        """Load MatAnyone from ``BASE_DIR/matanyone_fixed`` and wrap it in InferenceCore.

        Idempotent: returns True immediately once initialised. On failure the
        reason is stored in ``state.matanyone_error`` and False is returned.
        """
        if self.initialized and self.core is not None:
            return True
        if not TORCH_AVAILABLE:
            state.matanyone_error = "PyTorch required"
            return False

        with memory_manager.mem_context("MatAnyone init"):
            try:
                local_matanyone = BASE_DIR / "matanyone_fixed"
                if not local_matanyone.exists():
                    state.matanyone_error = "matanyone_fixed directory not found"
                    return False

                # NOTE(review): this exposes generic top-level package names
                # (``inference``, ``utils``); a module of the same name already in
                # sys.modules would be reused instead — confirm no collisions.
                matanyone_str = str(local_matanyone)
                if matanyone_str not in sys.path:
                    sys.path.insert(0, matanyone_str)

                try:
                    from inference.inference_core import InferenceCore
                    from utils.get_default_model import get_matanyone_model
                except Exception as e:
                    state.matanyone_error = f"Import error: {e}"
                    logger.error(f"MatAnyone import failed: {e}")
                    return False

                ckpt = CHECKPOINTS / "matanyone.pth"
                if not ckpt.exists():
                    url = "https://github.com/pq-yang/MatAnyone/releases/download/v1.0.0/matanyone.pth"
                    if not download_file(url, ckpt, "MatAnyone"):
                        # NOTE(review): whether get_matanyone_model really falls back
                        # to random weights for a missing file is not visible here.
                        logger.warning("MatAnyone checkpoint download failed, using random weights")

                net = get_matanyone_model(str(ckpt), device=DEVICE)
                if net is None:
                    state.matanyone_error = "Model creation failed"
                    return False

                self.core = InferenceCore(net)
                self.initialized = True
                state.matanyone_ready = True
                logger.info("Fixed MatAnyone initialized successfully")
                return True

            except Exception as e:
                state.matanyone_error = f"MatAnyone init error: {e}"
                logger.error(f"MatAnyone initialization failed: {e}")
                return False

    # ---- per-frame stepping -------------------------------------------------

    def _try_step_variants_seed(self,
                                img_chw_t: "torch.Tensor",
                                prob_hw_t: "torch.Tensor",
                                prob_1hw_t: "torch.Tensor"):
        """First-frame step with the seed mask, trying (H,W) then (1,H,W) layouts.

        As a last resort it steps WITHOUT the seed. That fallback is now
        logged, because matting an unseeded first frame is almost certainly
        not what the caller wants.
        """
        try:
            return self.core.step(img_chw_t, prob_hw_t)
        except Exception as e:
            logger.warning(f"MatAnyone step(image, HxW prob) failed: {e}; retrying with 1xHxW prob")
        try:
            return self.core.step(img_chw_t, prob_1hw_t)
        except Exception as e:
            logger.warning(f"MatAnyone step(image, 1xHxW prob) failed: {e}; continuing WITHOUT the seed mask")
        return self.core.step(img_chw_t)

    def _try_step_variants_noseed(self, img_chw_t: "torch.Tensor"):
        """Propagation step for frames after the first."""
        return self.core.step(img_chw_t)

    def _write_alpha_png(self, frames_dir: Path, idx: int, out_prob) -> None:
        """Write one frame's alpha as ``<frames_dir>/<idx:06d>.png``."""
        cv2.imwrite(str(frames_dir / f"{idx:06d}.png"), self._alpha_to_u8_hw(out_prob))

    @staticmethod
    def _encode_alpha_frames(frames_dir: Path, count: int, fps: float, w: int, h: int, alpha_path: Path) -> None:
        """Encode ``000000.png`` .. ``count-1`` from ``frames_dir`` into a grey H.264 MP4."""
        list_file = frames_dir / "list.txt"
        with open(list_file, "w") as f:
            for i in range(count):
                frame_png = (frames_dir / f"{i:06d}.png").as_posix()
                f.write(f"file '{frame_png}'\n")

        # NOTE(review): libx264 with yuv420p rejects odd width/height, so an
        # odd-sized input makes this command fail — confirm callers guarantee
        # even dimensions (ChunkedVideoProcessor does; the direct path does not).
        cmd = [
            "ffmpeg", "-y", "-hide_banner", "-loglevel", "error",
            "-f", "concat", "-safe", "0",
            "-r", f"{fps:.6f}",
            "-i", str(list_file),
            "-vf", f"format=gray,scale={w}:{h}:flags=area",
            "-pix_fmt", "yuv420p",
            "-c:v", "libx264", "-preset", "medium", "-crf", "18",
            str(alpha_path),
        ]
        subprocess.run(cmd, check=True)

    def process_video(self, input_path: str, mask_path: str, output_path: str) -> str:
        """Produce a single-channel alpha MP4 matching the input's fps and frame size.

        Args:
            input_path: source video.
            mask_path: grayscale first-frame seed mask; resized (nearest) to the
                frame size if needed.
            output_path: directory that receives ``alpha.mp4`` (created if missing).

        Returns:
            Path of the written ``alpha.mp4``.

        Raises:
            RuntimeError: if not initialised, the video or mask cannot be read,
                or the first frame is empty.
            subprocess.CalledProcessError: if the final ffmpeg encode fails.
        """
        if not self.initialized or self.core is None:
            raise RuntimeError("MatAnyone not initialized")

        out_dir = Path(output_path)
        out_dir.mkdir(parents=True, exist_ok=True)
        alpha_path = out_dir / "alpha.mp4"

        cap = cv2.VideoCapture(input_path)
        if not cap.isOpened():
            raise RuntimeError("Could not open input video")

        # try/finally: the capture used to leak whenever a step() raised mid-loop.
        try:
            fps = cap.get(cv2.CAP_PROP_FPS) or 24.0

            seed_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
            if seed_mask is None:
                raise RuntimeError("Seed mask read failed")

            tmp_dir = TEMP_DIR / f"ma_{int(time.time())}_{random.randint(1000,9999)}"
            tmp_dir.mkdir(parents=True, exist_ok=True)
            memory_manager.register_temp_file(str(tmp_dir))

            ok, frame_bgr = cap.read()
            if not ok or frame_bgr is None:
                raise RuntimeError("Empty first frame")

            # Size everything from the decoded frame rather than container
            # metadata (which can be 0 or disagree), so the seed mask and the
            # encoded alpha match exactly what MatAnyone sees.
            h, w = frame_bgr.shape[:2]
            prob_hw_t = self._prob_hw_from_mask_u8(seed_mask, w, h)
            prob_1hw_t = self._prob_1hw_from_mask_u8(seed_mask, w, h)

            frame_idx = 0
            with torch.no_grad():
                out_prob = self._try_step_variants_seed(self._frame_to_chw(frame_bgr), prob_hw_t, prob_1hw_t)
            self._write_alpha_png(tmp_dir, frame_idx, out_prob)
            frame_idx += 1

            while True:
                ok, frame_bgr = cap.read()
                if not ok or frame_bgr is None:
                    break
                with torch.no_grad():
                    out_prob = self._try_step_variants_noseed(self._frame_to_chw(frame_bgr))
                self._write_alpha_png(tmp_dir, frame_idx, out_prob)
                frame_idx += 1
        finally:
            cap.release()

        self._encode_alpha_frames(tmp_dir, frame_idx, fps, w, h, alpha_path)
        return str(alpha_path)
|
|
| |
| |
| |
| def _maybe_enable_xformers(pipe): |
| try: |
| pipe.enable_xformers_memory_efficient_attention() |
| except Exception: |
| pass |
|
|
def _setup_memory_efficient_pipeline(pipe, require_gpu: bool):
    """Apply memory-saving options to a diffusers pipeline.

    xFormers attention is always attempted. Unless ``require_gpu`` is set, the
    pipeline also gets attention slicing plus exactly ONE CPU-offload
    strategy. diffusers documents model offload and sequential offload as
    alternatives; the old code installed both, one after the other. Model
    offload is tried first; sequential offload is used only if model offload
    is unavailable or raises. Each step is independent, so a failure in one
    (e.g. attention slicing) no longer skips the others.
    """
    _maybe_enable_xformers(pipe)
    if require_gpu:
        return

    if hasattr(pipe, "enable_attention_slicing"):
        try:
            pipe.enable_attention_slicing("auto")
        except Exception as e:
            logger.debug(f"attention slicing unavailable: {e}")

    for method_name in ("enable_model_cpu_offload", "enable_sequential_cpu_offload"):
        offload = getattr(pipe, method_name, None)
        if offload is None:
            continue
        try:
            offload()
            return
        except Exception as e:
            # Typically: no accelerator present or accelerate not installed.
            logger.debug(f"{method_name} unavailable: {e}")
|
|
def generate_sdxl_background(width:int, height:int, prompt:str, steps:int=30, guidance:float=7.0,
                             seed:Optional[int]=None, require_gpu:bool=False) -> str:
    """Generate a background with SDXL base 1.0 and return its JPEG path (in TEMP_DIR).

    diffusers requires height and width divisible by 8; arbitrary video
    sizes often are not, and previously such sizes simply raised. The image
    is now rendered at the nearest lower multiple of 8 (minimum 8) and
    resized (LANCZOS) to exactly ``width`` x ``height``.

    Args:
        width, height: requested output size in pixels.
        prompt: user prompt; a fixed quality suffix is appended.
        steps, guidance: diffusion steps and classifier-free guidance scale.
        seed: RNG seed; a random one is drawn when None.
        require_gpu: fail instead of running on CPU, and skip CPU offload.

    Raises:
        RuntimeError: torch/diffusers missing, or GPU required but unavailable.
    """
    if not TORCH_AVAILABLE:
        raise RuntimeError("PyTorch required for SDXL")
    with memory_manager.mem_context("SDXL background"):
        try:
            from diffusers import StableDiffusionXLPipeline
        except ImportError as e:
            raise RuntimeError("Install diffusers/transformers/accelerate") from e

        if require_gpu and not CUDA_AVAILABLE:
            raise RuntimeError("Force GPU enabled but CUDA not available")

        device = "cuda" if CUDA_AVAILABLE else "cpu"
        torch_dtype = torch.float16 if CUDA_AVAILABLE else torch.float32

        generator = torch.Generator(device=device)
        if seed is None:
            seed = random.randint(0, 2**31 - 1)
        generator.manual_seed(int(seed))

        out_w, out_h = max(1, int(width)), max(1, int(height))
        gen_w, gen_h = max(8, out_w - out_w % 8), max(8, out_h - out_h % 8)

        pipe = StableDiffusionXLPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            torch_dtype=torch_dtype,
            add_watermarker=False,
        ).to(device)
        _setup_memory_efficient_pipeline(pipe, require_gpu)

        enhanced = f"{prompt}, professional studio lighting, high detail, clean composition"
        img = pipe(
            prompt=enhanced,
            height=gen_h,
            width=gen_w,
            num_inference_steps=int(steps),
            guidance_scale=float(guidance),
            generator=generator
        ).images[0]
        if img.size != (out_w, out_h):
            img = img.resize((out_w, out_h), Image.LANCZOS)

        out = TEMP_DIR / f"sdxl_bg_{int(time.time())}_{seed or 0:08d}.jpg"
        img.save(out, quality=95, optimize=True)
        memory_manager.register_temp_file(str(out))
        del pipe, img
        return str(out)
|
|
def generate_playground_v25_background(width:int, height:int, prompt:str, steps:int=30, guidance:float=7.0,
                                       seed:Optional[int]=None, require_gpu:bool=False) -> str:
    """Generate a background with Playground v2.5 and return its JPEG path (in TEMP_DIR).

    Renders at the nearest lower multiple of 8 (minimum 8), which the
    diffusers pipeline requires, and resizes (LANCZOS) to exactly ``width`` x
    ``height``. Previously, sizes not divisible by 8 raised.

    Args:
        width, height: requested output size in pixels.
        prompt: user prompt; a fixed quality suffix is appended.
        steps, guidance: diffusion steps and guidance scale.
        seed: RNG seed; a random one is drawn when None.
        require_gpu: fail instead of running on CPU, and skip CPU offload.

    Raises:
        RuntimeError: torch/diffusers missing, or GPU required but unavailable.
    """
    if not TORCH_AVAILABLE:
        raise RuntimeError("PyTorch required for Playground v2.5")
    with memory_manager.mem_context("Playground v2.5 background"):
        try:
            from diffusers import DiffusionPipeline
        except ImportError as e:
            raise RuntimeError("Install diffusers/transformers/accelerate") from e

        if require_gpu and not CUDA_AVAILABLE:
            raise RuntimeError("Force GPU enabled but CUDA not available")

        device = "cuda" if CUDA_AVAILABLE else "cpu"
        torch_dtype = torch.float16 if CUDA_AVAILABLE else torch.float32

        generator = torch.Generator(device=device)
        if seed is None:
            seed = random.randint(0, 2**31 - 1)
        generator.manual_seed(int(seed))

        out_w, out_h = max(1, int(width)), max(1, int(height))
        gen_w, gen_h = max(8, out_w - out_w % 8), max(8, out_h - out_h % 8)

        repo_id = "playgroundai/playground-v2.5-1024px-aesthetic"
        pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch_dtype).to(device)
        _setup_memory_efficient_pipeline(pipe, require_gpu)

        enhanced = f"{prompt}, professional quality, soft light, minimal distractions"
        img = pipe(
            prompt=enhanced,
            height=gen_h,
            width=gen_w,
            num_inference_steps=int(steps),
            guidance_scale=float(guidance),
            generator=generator
        ).images[0]
        if img.size != (out_w, out_h):
            img = img.resize((out_w, out_h), Image.LANCZOS)

        out = TEMP_DIR / f"pg25_bg_{int(time.time())}_{seed or 0:08d}.jpg"
        img.save(out, quality=95, optimize=True)
        memory_manager.register_temp_file(str(out))
        del pipe, img
        return str(out)
|
|
def generate_sd15_background(width:int, height:int, prompt:str, steps:int=25, guidance:float=7.5,
                             seed:Optional[int]=None, require_gpu:bool=False) -> str:
    """Generate a background with Stable Diffusion 1.5 and return its JPEG path (in TEMP_DIR).

    Renders at the nearest lower multiple of 8 (minimum 8), which the
    diffusers pipeline requires, and resizes (LANCZOS) to exactly ``width`` x
    ``height``. Previously, sizes not divisible by 8 raised.

    The weights are loaded from ``runwayml/stable-diffusion-v1-5`` (works
    when locally cached). If that fails, the community mirror
    ``stable-diffusion-v1-5/stable-diffusion-v1-5`` is tried.
    NOTE(review): the runwayml repo is reported to have been removed from
    the Hub — confirm and consider making the mirror primary.

    Raises:
        RuntimeError: torch/diffusers missing, or GPU required but unavailable.
        Exception: from diffusers if neither repo id can be loaded.
    """
    if not TORCH_AVAILABLE:
        raise RuntimeError("PyTorch required for SD 1.5")
    with memory_manager.mem_context("SD1.5 background"):
        try:
            from diffusers import StableDiffusionPipeline
        except ImportError as e:
            raise RuntimeError("Install diffusers/transformers/accelerate") from e

        if require_gpu and not CUDA_AVAILABLE:
            raise RuntimeError("Force GPU enabled but CUDA not available")

        device = "cuda" if CUDA_AVAILABLE else "cpu"
        torch_dtype = torch.float16 if CUDA_AVAILABLE else torch.float32

        generator = torch.Generator(device=device)
        if seed is None:
            seed = random.randint(0, 2**31 - 1)
        generator.manual_seed(int(seed))

        out_w, out_h = max(1, int(width)), max(1, int(height))
        gen_w, gen_h = max(8, out_w - out_w % 8), max(8, out_h - out_h % 8)

        pipe = None
        last_error = None
        for repo_id in ("runwayml/stable-diffusion-v1-5", "stable-diffusion-v1-5/stable-diffusion-v1-5"):
            try:
                pipe = StableDiffusionPipeline.from_pretrained(
                    repo_id,
                    torch_dtype=torch_dtype,
                    safety_checker=None,
                    requires_safety_checker=False
                ).to(device)
                break
            except Exception as e:
                last_error = e
                logger.warning(f"SD1.5 load from {repo_id} failed: {e}")
        if pipe is None:
            raise last_error

        _setup_memory_efficient_pipeline(pipe, require_gpu)

        enhanced = f"{prompt}, professional background, clean composition"
        img = pipe(
            prompt=enhanced,
            height=gen_h,
            width=gen_w,
            num_inference_steps=int(steps),
            guidance_scale=float(guidance),
            generator=generator
        ).images[0]
        if img.size != (out_w, out_h):
            img = img.resize((out_w, out_h), Image.LANCZOS)

        out = TEMP_DIR / f"sd15_bg_{int(time.time())}_{seed or 0:08d}.jpg"
        img.save(out, quality=95, optimize=True)
        memory_manager.register_temp_file(str(out))
        del pipe, img
        return str(out)
|
|
def generate_openai_background(width:int, height:int, prompt:str, api_key:str, model:str="gpt-image-1") -> str:
    """Generate a background with the OpenAI Images API and return its JPEG path (in TEMP_DIR).

    For ``gpt-image*`` models the request size follows the target aspect
    ratio: 1536x1024 for landscape, 1024x1536 for portrait, otherwise
    1024x1024. The decoded image is then scaled and centre-cropped to exactly
    ``width`` x ``height`` (``ImageOps.fit``) instead of being stretched.
    Decoding happens in memory, so no intermediate temp file can leak.

    Raises:
        RuntimeError: missing/short API key, a non-200 response, or a
            response without image data.
    """
    if not api_key or not isinstance(api_key, str) or len(api_key) < 10:
        raise RuntimeError("Missing or invalid OpenAI API key")
    import io
    from PIL import ImageOps

    with memory_manager.mem_context("OpenAI background"):
        out_w, out_h = max(1, int(width)), max(1, int(height))
        target = "1024x1024"
        if str(model).startswith("gpt-image"):
            aspect = out_w / out_h
            # 1.22 ~ sqrt(1.5): midpoint (in log space) between 1:1 and 3:2.
            if aspect > 1.22:
                target = "1536x1024"
            elif aspect < 1 / 1.22:
                target = "1024x1536"

        url = "https://api.openai.com/v1/images/generations"
        headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
        body = {"model": model, "prompt": f"{prompt}, professional background, studio lighting, minimal distractions, high detail",
                "size": target, "n": 1, "quality": "high"}
        import requests
        r = requests.post(url, headers=headers, data=json.dumps(body), timeout=120)
        if r.status_code != 200:
            raise RuntimeError(f"OpenAI API error: {r.status_code} {r.text}")
        try:
            b64 = r.json()["data"][0]["b64_json"]
            raw = base64.b64decode(b64)
        except (ValueError, KeyError, IndexError, TypeError) as e:
            raise RuntimeError(f"OpenAI API returned no image data: {r.text[:500]}") from e

        with Image.open(io.BytesIO(raw)) as src:
            img = ImageOps.fit(src.convert("RGB"), (out_w, out_h), Image.LANCZOS)
        out = TEMP_DIR / f"openai_bg_{int(time.time())}_{random.randint(1000,9999)}.jpg"
        img.save(out, quality=95, optimize=True)
        memory_manager.register_temp_file(str(out))
        return str(out)
|
|
def generate_ai_background_router(width:int, height:int, prompt:str, model:str="SDXL",
                                  steps:int=30, guidance:float=7.0, seed:Optional[int]=None,
                                  openai_key:Optional[str]=None, require_gpu:bool=False) -> str:
    """Generate a background with the requested model, degrading gracefully.

    Fallback chain: requested model -> SD 1.5 (``require_gpu=False``) -> a
    light-grey vertical gradient. SD 1.5 is not retried when it was the model
    that just failed with identical arguments; that would only reload the
    model and most likely fail again. Always returns a path; never raises for
    generation failures.
    """
    # Any label not matched below is served by SD 1.5.
    primary_is_sd15 = model not in ("OpenAI (gpt-image-1)", "Playground v2.5", "SDXL")
    try:
        if model == "OpenAI (gpt-image-1)":
            if not openai_key:
                raise RuntimeError("OpenAI API key not provided")
            return generate_openai_background(width, height, prompt, openai_key, model="gpt-image-1")
        if model == "Playground v2.5":
            return generate_playground_v25_background(width, height, prompt, steps, guidance, seed, require_gpu)
        if model == "SDXL":
            return generate_sdxl_background(width, height, prompt, steps, guidance, seed, require_gpu)
        return generate_sd15_background(width, height, prompt, steps, guidance, seed, require_gpu)
    except Exception as e:
        logger.warning(f"{model} generation failed: {e}; falling back to SD1.5/gradient")

    # The SD 1.5 retry differs from the first attempt only if the primary was
    # another model, or if the primary SD 1.5 call forced the GPU.
    if not primary_is_sd15 or require_gpu:
        try:
            return generate_sd15_background(width, height, prompt, steps, guidance, seed, require_gpu=False)
        except Exception as e:
            logger.warning(f"SD1.5 fallback failed: {e}; using gradient background")

    grad = _make_vertical_gradient(width, height, (235, 240, 245), (210, 220, 230))
    out = TEMP_DIR / f"bg_fallback_{int(time.time())}_{random.randint(1000,9999)}.jpg"
    cv2.imwrite(str(out), grad)
    memory_manager.register_temp_file(str(out))
    return str(out)
|
|
| |
| |
| |
class ChunkedVideoProcessor:
    """Split a video into fixed-size frame chunks, process each, then concatenate."""

    def __init__(self, chunk_size_frames: int = 60):
        self.chunk_size = int(chunk_size_frames)
        if self.chunk_size <= 0:
            raise ValueError(f"chunk_size_frames must be positive, got {chunk_size_frames}")

    def _extract_chunk(self, video_path: str, start_frame: int, end_frame: int, fps: float) -> str:
        """Re-encode frames [start_frame, end_frame) of ``video_path`` into a new MP4.

        The chunk is bounded by frame count (``-frames:v``) rather than by a
        float duration. For constant-frame-rate input, consecutive chunks
        therefore neither overlap nor drop a frame at the seams. Dimensions
        are truncated to even values, as libx264 / 4:2:0 requires.
        """
        chunk_path = str(TEMP_DIR / f"chunk_{start_frame}_{end_frame}_{random.randint(1000,9999)}.mp4")
        start_time = start_frame / fps
        n_frames = max(1, end_frame - start_frame)
        cmd = [
            "ffmpeg", "-y", "-hide_banner", "-loglevel", "error",
            "-ss", f"{start_time:.6f}", "-i", video_path,
            "-frames:v", str(n_frames),
            "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2",
            "-c:v", "libx264", "-preset", "veryfast", "-crf", "20",
            "-an", chunk_path
        ]
        subprocess.run(cmd, check=True)
        return chunk_path

    @staticmethod
    def _concat_quote(path: str) -> str:
        """Quote a path for an ffmpeg concat list (a ``'`` becomes ``'\\''``)."""
        return "'" + str(path).replace("'", "'\\''") + "'"

    def _merge_chunks(self, chunk_paths: List[str], fps: float, width: int, height: int) -> str:
        """Stream-copy concatenate ``chunk_paths`` into one MP4.

        ``fps``, ``width`` and ``height`` are unused; they are kept for
        interface compatibility. A single chunk is returned as-is.

        Raises:
            ValueError: if ``chunk_paths`` is empty.
        """
        if not chunk_paths:
            raise ValueError("No chunks to merge")
        if len(chunk_paths) == 1:
            return chunk_paths[0]
        concat_file = TEMP_DIR / f"concat_{random.randint(1000,9999)}.txt"
        memory_manager.register_temp_file(str(concat_file))
        with open(concat_file, "w") as f:
            for c in chunk_paths:
                f.write(f"file {self._concat_quote(c)}\n")
        out = TEMP_DIR / f"merged_{random.randint(1000,9999)}.mp4"
        cmd = ["ffmpeg", "-y", "-hide_banner", "-loglevel", "error",
               "-f", "concat", "-safe", "0", "-i", str(concat_file),
               "-c", "copy", str(out)]
        subprocess.run(cmd, check=True)
        return str(out)

    def process_video_chunks(self, video_path: str, processor_func, **kwargs) -> str:
        """Run ``processor_func(chunk_path, **kwargs)`` on each chunk and merge the results.

        ``processor_func`` must return the path of a video; chunk inputs and
        outputs are registered as temp files.

        Raises:
            RuntimeError: if the video cannot be opened.
            ValueError: if the frame count is unknown or zero.
        """
        cap = cv2.VideoCapture(video_path)
        try:
            if not cap.isOpened():
                raise RuntimeError(f"Could not open video: {video_path}")
            total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            fps = cap.get(cv2.CAP_PROP_FPS) or 24.0
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        finally:
            cap.release()
        if total <= 0:
            raise ValueError(f"Could not determine frame count for chunked processing: {video_path}")

        processed: List[str] = []
        for start in range(0, total, self.chunk_size):
            end = min(start + self.chunk_size, total)
            with memory_manager.mem_context(f"chunk {start}-{end}"):
                ch = self._extract_chunk(video_path, start, end, fps)
                memory_manager.register_temp_file(ch)
                out = processor_func(ch, **kwargs)
                memory_manager.register_temp_file(out)
                processed.append(out)
        return self._merge_chunks(processed, fps, width, height)
|
|
| |
| |
| |
def _close_clips(*clips) -> None:
    """Close every non-None moviepy clip; log (do not raise) close failures."""
    for clip in clips:
        if clip is None:
            continue
        try:
            clip.close()
        except Exception as e:
            logger.warning(f"Clip close failed: {e}")


def _load_background_rgb(background_path: Optional[str], w: int, h: int, messages: List[str]) -> np.ndarray:
    """Return the compositing background as float32 RGB in [0, 1] with shape (h, w, 3).

    Uses ``background_path`` when it exists and OpenCV can decode it. Otherwise it
    falls back to the default vertical gradient, so an unreadable file degrades
    gracefully instead of aborting the pipeline. A user-facing status line is
    appended to ``messages``.
    """
    if background_path and os.path.exists(background_path):
        bg_bgr = cv2.imread(background_path)
        # cv2.imread signals "cannot decode" by returning None rather than raising.
        if bg_bgr is not None:
            messages.append("🖼️ Using background file")
            bg_bgr = cv2.resize(bg_bgr, (w, h))
            return cv2.cvtColor(bg_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
        logger.warning(f"Could not decode background image {background_path}; using gradient")
        messages.append("⚠️ Background file unreadable, falling back to gradient")
    messages.append("🖼️ Using gradient background")
    grad = _make_vertical_gradient(w, h, (200, 205, 215), (160, 170, 190))
    return cv2.cvtColor(grad, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0


def process_video_main(
    video_path: str,
    background_path: Optional[str] = None,
    trim_duration: Optional[float] = None,
    crf: int = 18,
    preserve_audio_flag: bool = True,
    placement: Optional[dict] = None,
    use_chunked_processing: bool = False,
    progress: gr.Progress = gr.Progress(track_tqdm=True),
) -> Tuple[Optional[str], str]:
    """Run the full background-replacement pipeline on one video.

    Pipeline steps:
    1. Initialize SAM2 and MatAnyone.
    2. Optionally trim the input.
    3. Build a person mask from the first frame with SAM2.
    4. Matte the clip with MatAnyone, producing an alpha video.
    5. Composite the scaled and positioned subject (with feathered alpha) over
       the background, then encode with ``write_video_h264``.
    6. Optionally mux the source audio back in with ffmpeg.

    Args:
        video_path: Source video file.
        background_path: Background image. If missing or undecodable, a
            vertical gradient is used instead.
        trim_duration: When > 0, only the first ``trim_duration`` seconds
            (capped at the clip duration) are processed.
        crf: H.264 CRF passed to ``write_video_h264``.
        preserve_audio_flag: Mux the first audio stream of ``video_path``
            (if any) into the output.
        placement: Optional dict with these keys:
            ``x``/``y``: subject centre as a fraction of the frame, clamped to 0..1.
            ``scale``: clamped to 0.3..2.0.
            ``feather``: alpha blur in px, clamped to 0..50.
        use_chunked_processing: Matte in 60-frame chunks via ``ChunkedVideoProcessor``.
        progress: Gradio progress reporter.

    Returns:
        ``(output_path, status_text)`` on success, ``(None, error_text)`` on
        failure. Errors are reported in the returned text rather than raised.
    """
    messages: List[str] = []
    # Declared up front so the `finally` below can always release them.
    original_clip = alpha_clip = final_clip = None
    with memory_manager.mem_context("Pipeline"):
        try:
            progress(0, desc="Initializing models")
            sam2 = SAM2Handler()
            matanyone = MatAnyoneHandler()

            if not sam2.initialize():
                return None, f"SAM2 init failed: {state.sam2_error}"
            if not matanyone.initialize():
                return None, f"MatAnyone init failed: {state.matanyone_error}"
            messages.append("✅ SAM2 & MatAnyone initialized")

            progress(0.1, desc="Preparing video")
            input_video = video_path

            # Optional trim: re-encode the first `d` seconds to a temp file.
            if trim_duration and float(trim_duration) > 0:
                trimmed = TEMP_DIR / f"trimmed_{int(time.time())}_{random.randint(1000,9999)}.mp4"
                memory_manager.register_temp_file(str(trimmed))
                with VideoFileClip(video_path) as clip:
                    d = min(float(trim_duration), float(clip.duration or trim_duration))
                    sub = clip.subclip(0, d)
                    try:
                        write_video_h264(sub, str(trimmed), crf=int(crf))
                    finally:
                        sub.close()
                input_video = str(trimmed)
                messages.append(f"✂️ Trimmed to {d:.1f}s")
            else:
                with VideoFileClip(video_path) as clip:
                    messages.append(f"🎞️ Full video: {clip.duration:.1f}s")

            progress(0.2, desc="Creating SAM2 mask")
            cap = cv2.VideoCapture(input_video)
            ret, first_frame = cap.read()
            cap.release()
            if not ret or first_frame is None:
                return None, "Could not read video"
            h, w = first_frame.shape[:2]
            rgb0 = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
            mask = sam2.create_mask(rgb0)
            if mask is None:
                return None, "SAM2 mask failed"

            mask_path = TEMP_DIR / f"mask_{int(time.time())}_{random.randint(1000,9999)}.png"
            memory_manager.register_temp_file(str(mask_path))
            cv2.imwrite(str(mask_path), mask)
            messages.append("✅ Person mask created")

            progress(0.35, desc="Matting video")
            if use_chunked_processing:
                def _matte_chunk(chunk_path, **_kwargs):
                    # Give each chunk its own existing, registered output dir,
                    # mirroring the non-chunked branch below.
                    # NOTE(review): every chunk reuses the mask from the first
                    # frame of the whole clip, not of that chunk — confirm intended.
                    chunk_out = TEMP_DIR / f"matanyone_chunk_{int(time.time())}_{random.randint(1000,9999)}"
                    chunk_out.mkdir(parents=True, exist_ok=True)
                    memory_manager.register_temp_file(str(chunk_out))
                    return matanyone.process_video(
                        input_path=chunk_path,
                        mask_path=str(mask_path),
                        output_path=str(chunk_out),
                    )

                chunker = ChunkedVideoProcessor(chunk_size_frames=60)
                alpha_video = chunker.process_video_chunks(input_video, _matte_chunk)
                if alpha_video:
                    memory_manager.register_temp_file(alpha_video)
            else:
                out_dir = TEMP_DIR / f"matanyone_out_{int(time.time())}_{random.randint(1000,9999)}"
                out_dir.mkdir(parents=True, exist_ok=True)
                memory_manager.register_temp_file(str(out_dir))
                alpha_video = matanyone.process_video(
                    input_path=input_video,
                    mask_path=str(mask_path),
                    output_path=str(out_dir),
                )

            if not alpha_video or not os.path.exists(alpha_video):
                return None, "MatAnyone did not produce alpha video"
            messages.append("✅ Alpha video generated")

            progress(0.55, desc="Preparing background")
            original_clip = VideoFileClip(input_video)
            alpha_clip = VideoFileClip(alpha_video)
            bg_rgb = _load_background_rgb(background_path, w, h, messages)

            # Placement parameters, clamped to the ranges the compositor supports.
            placement = placement or {}
            px = max(0.0, min(1.0, float(placement.get("x", 0.5))))
            py = max(0.0, min(1.0, float(placement.get("y", 0.75))))
            ps = max(0.3, min(2.0, float(placement.get("scale", 1.0))))
            feather_px = max(0, min(50, int(placement.get("feather", 3))))

            logger.info(f"POSITIONING DEBUG: px={px:.3f}, py={py:.3f}, ps={ps:.3f}, feather={feather_px}")
            logger.info(f"VIDEO DIMENSIONS: {w}x{h}")
            logger.info(f"TARGET CENTER: ({int(px * w)}, {int(py * h)})")

            frame_count = 0

            def composite_frame(get_frame, t):
                """Composite the subject at time t over bg_rgb; return a uint8 RGB frame."""
                nonlocal frame_count, bg_rgb
                frame_count += 1

                frame = get_frame(t).astype(np.float32) / 255.0
                hh, ww = frame.shape[:2]

                # bg_rgb was sized from OpenCV's first frame. If moviepy decodes
                # a different size (e.g. rotation metadata handled differently),
                # resize once here to avoid a broadcast error below.
                if bg_rgb.shape[:2] != (hh, ww):
                    logger.warning(f"Background size mismatch: {bg_rgb.shape[:2]} vs {(hh, ww)}, resizing...")
                    bg_rgb = cv2.resize(bg_rgb, (ww, hh), interpolation=cv2.INTER_LINEAR)

                # Clamp the sample time into the alpha clip, which may be
                # marginally shorter than the source after re-encoding.
                alpha_duration = getattr(alpha_clip, 'duration', None)
                if alpha_duration and alpha_duration > 0:
                    alpha_t = max(0.0, min(t, alpha_duration - 0.01))
                else:
                    alpha_t = 0.0

                try:
                    a = alpha_clip.get_frame(alpha_t)
                    if a.ndim == 3:
                        a = a[:, :, 0]
                    a = a.astype(np.float32) / 255.0
                    if a.shape != (hh, ww):
                        logger.warning(f"Alpha size mismatch: {a.shape} vs {(hh, ww)}, resizing...")
                        a = cv2.resize(a, (ww, hh), interpolation=cv2.INTER_LINEAR)
                except Exception as e:
                    logger.error(f"Alpha frame error at t={t:.3f}: {e}")
                    return (bg_rgb * 255).astype(np.uint8)

                # Scale subject and alpha together; INTER_AREA when shrinking.
                sw = max(1, round(ww * ps))
                sh = max(1, round(hh * ps))
                interp = cv2.INTER_AREA if ps < 1.0 else cv2.INTER_LINEAR
                try:
                    fg_scaled = cv2.resize(frame, (sw, sh), interpolation=interp)
                    a_scaled = cv2.resize(a, (sw, sh), interpolation=interp)
                except Exception as e:
                    logger.error(f"Scaling error: {e}")
                    return (bg_rgb * 255).astype(np.uint8)

                fg_canvas = np.zeros_like(frame, dtype=np.float32)
                a_canvas = np.zeros((hh, ww), dtype=np.float32)

                # Top-left of the scaled subject so that its centre lands on (px, py).
                cx = round(px * ww)
                cy = round(py * hh)
                x0 = cx - sw // 2
                y0 = cy - sh // 2

                if frame_count <= 3:
                    logger.info(f"FRAME {frame_count}: scaled_size=({sw}, {sh}), center=({cx}, {cy}), top_left=({x0}, {y0})")

                # Destination rectangle clipped to the frame.
                xs0 = max(0, x0)
                ys0 = max(0, y0)
                xs1 = min(ww, x0 + sw)
                ys1 = min(hh, y0 + sh)

                if xs1 <= xs0 or ys1 <= ys0:
                    if frame_count <= 3:
                        logger.warning(f"Subject outside bounds: dest=({xs0},{ys0})-({xs1},{ys1})")
                    return (bg_rgb * 255).astype(np.uint8)

                # Matching source rectangle inside the scaled subject.
                src_x0 = xs0 - x0
                src_y0 = ys0 - y0
                src_x1 = src_x0 + (xs1 - xs0)
                src_y1 = src_y0 + (ys1 - ys0)

                if (src_x1 > sw or src_y1 > sh or src_x0 < 0 or src_y0 < 0 or
                        src_x1 <= src_x0 or src_y1 <= src_y0):
                    if frame_count <= 3:
                        logger.error(f"Invalid source region: ({src_x0},{src_y0})-({src_x1},{src_y1}) for {sw}x{sh} scaled")
                    return (bg_rgb * 255).astype(np.uint8)

                try:
                    fg_canvas[ys0:ys1, xs0:xs1, :] = fg_scaled[src_y0:src_y1, src_x0:src_x1, :]
                    a_canvas[ys0:ys1, xs0:xs1] = a_scaled[src_y0:src_y1, src_x0:src_x1]
                except Exception as e:
                    logger.error(f"Canvas placement failed: {e}")
                    logger.error(f"Dest: [{ys0}:{ys1}, {xs0}:{xs1}], Src: [{src_y0}:{src_y1}, {src_x0}:{src_x1}]")
                    return (bg_rgb * 255).astype(np.uint8)

                # Soften the matte edge; kernel is always odd (2*feather + 1, min 3).
                if feather_px > 0:
                    kernel_size = max(3, feather_px * 2 + 1)
                    if kernel_size % 2 == 0:
                        kernel_size += 1
                    try:
                        a_canvas = cv2.GaussianBlur(a_canvas, (kernel_size, kernel_size), feather_px / 3.0)
                    except Exception as e:
                        logger.warning(f"Feathering failed: {e}")

                a3 = np.expand_dims(a_canvas, axis=2)
                comp = a3 * fg_canvas + (1.0 - a3) * bg_rgb
                return np.clip(comp * 255, 0, 255).astype(np.uint8)

            progress(0.7, desc="Compositing")
            final_clip = original_clip.fl(composite_frame)
            # moviepy 1.x may evaluate the new frame function once while building
            # the derived clip (to infer its size). Reset so the reported count
            # reflects frames rendered during encoding.
            frame_count = 0

            output_path = OUT_DIR / f"processed_{int(time.time())}_{random.randint(1000,9999)}.mp4"
            temp_video_path = TEMP_DIR / f"temp_video_{int(time.time())}_{random.randint(1000,9999)}.mp4"
            memory_manager.register_temp_file(str(temp_video_path))

            write_video_h264(final_clip, str(temp_video_path), crf=int(crf))
            # Release decoders before the ffmpeg mux; `finally` skips them once None.
            _close_clips(original_clip, alpha_clip, final_clip)
            original_clip = alpha_clip = final_clip = None

            progress(0.85, desc="Merging audio")
            if preserve_audio_flag:
                # "1:a:0?" makes the audio map optional, so silent sources still succeed.
                success = run_ffmpeg([
                    "-i", str(temp_video_path),
                    "-i", video_path,
                    "-map", "0:v:0",
                    "-map", "1:a:0?",
                    "-c:v", "copy",
                    "-c:a", "aac",
                    "-b:a", "192k",
                    "-shortest",
                    str(output_path),
                ], fail_ok=True)
                if success:
                    messages.append("🔊 Original audio preserved")
                else:
                    shutil.copy2(str(temp_video_path), str(output_path))
                    messages.append("⚠️ Audio merge failed, saved w/o audio")
            else:
                shutil.copy2(str(temp_video_path), str(output_path))
                messages.append("🔇 Saved without audio")

            messages.append("✅ Done")
            stats = memory_manager.get_memory_stats()
            messages.append(f"📊 CPU {stats.cpu_memory_mb:.1f}MB, GPU {stats.gpu_memory_mb:.1f}MB")
            messages.append(f"🎯 Processed {frame_count} frames with placement ({px:.2f}, {py:.2f}) @ {ps:.2f}x scale")
            progress(1.0, desc="Done")
            return str(output_path), "\n".join(messages)

        except Exception as e:
            logger.exception("Video processing pipeline failed")
            err = f"Processing failed: {str(e)}\n\n{traceback.format_exc()}"
            return None, err
        finally:
            # Covers early returns and exceptions after the clips were opened.
            _close_clips(final_clip, alpha_clip, original_clip)
|
|
| |
| |
| |
def create_interface():
    """Build and return the Gradio Blocks UI for the background-replacement app.

    Layout, top to bottom:
    - Left column: input video, background method (upload / gradient / AI),
      AI generation controls, processing options, subject placement, system tools.
    - Right column: processed video, download file, status text.

    Event handlers are nested closures wired to the components inside the
    ``gr.Blocks`` context. The final "Process Video" handler delegates to
    ``process_video_main``.
    """
    def diag():
        # Text snapshot of the shared `state` object for the diagnostics box.
        return state.status_text()

    def cleanup():
        # Force a memory cleanup, then report CPU/GPU usage and temp-file count.
        memory_manager.aggressive_cleanup()
        s = memory_manager.get_memory_stats()
        return f"🧹 Cleanup\nCPU: {s.cpu_memory_mb:.1f}MB\nGPU: {s.gpu_memory_mb:.1f}MB\nTemp: {s.temp_files_count} files"

    def preload(ai_model, openai_key, force_gpu, progress=gr.Progress()):
        """Warm up the selected local diffusion model plus SAM2 and MatAnyone.

        Diffusion preload is a tiny 64x64, 2-step generation. Its failure is
        reported in the text but does not stop SAM2/MatAnyone initialization.
        ``openai_key`` is accepted for signature parity but unused here.
        Returns a human-readable status string.
        """
        try:
            progress(0, desc="Preloading...")
            msg = ""
            if ai_model in ("SDXL", "Playground v2.5", "SD 1.5 (fallback)"):
                try:
                    if ai_model == "SDXL":
                        _ = generate_sdxl_background(64, 64, "plain", steps=2, guidance=3.5, seed=42, require_gpu=bool(force_gpu))
                    elif ai_model == "Playground v2.5":
                        _ = generate_playground_v25_background(64, 64, "plain", steps=2, guidance=3.5, seed=42, require_gpu=bool(force_gpu))
                    else:
                        _ = generate_sd15_background(64, 64, "plain", steps=2, guidance=3.5, seed=42, require_gpu=bool(force_gpu))
                    msg += f"{ai_model} preloaded.\n"
                except Exception as e:
                    msg += f"{ai_model} preload failed: {e}\n"

            # Hydra global state is reset before each model's initialize() call.
            _reset_hydra()
            s, m = SAM2Handler(), MatAnyoneHandler()
            ok_s = s.initialize()
            _reset_hydra()
            ok_m = m.initialize()
            progress(1.0, desc="Preload complete")
            return f"✅ Preload\n{msg}SAM2: {'ready' if ok_s else 'failed'}\nMatAnyone: {'ready' if ok_m else 'failed'}"
        except Exception as e:
            return f"❌ Preload error: {e}"

    def generate_background_safe(video_file, ai_prompt, ai_steps, ai_guidance, ai_seed,
                                 ai_model, openai_key, force_gpu, progress=gr.Progress()):
        """Generate an AI background sized to the input video's first frame.

        Returns a 4-tuple: (image path or None, status text, update for the
        "Approve" button's visibility, path stored in `last_generated_bg`).
        Errors are reported in the status text, not raised.
        """
        if not video_file:
            return None, "Upload a video first", gr.update(visible=False), None
        with memory_manager.mem_context("Background generation"):
            try:
                video_path = video_file.name if hasattr(video_file, 'name') else str(video_file)
                if not os.path.exists(video_path):
                    return None, "Video not found", gr.update(visible=False), None
                cap = cv2.VideoCapture(video_path)
                # NOTE(review): `cap` is not released on this early return.
                if not cap.isOpened():
                    return None, "Could not open video", gr.update(visible=False), None
                ret, frame = cap.read()
                cap.release()
                if not ret or frame is None:
                    return None, "Could not read frame", gr.update(visible=False), None
                h, w = int(frame.shape[0]), int(frame.shape[1])

                # Clamp UI inputs to safe ranges; falsy values fall back to defaults.
                steps = max(1, min(50, int(ai_steps or 30)))
                guidance = max(1.0, min(15.0, float(ai_guidance or 7.0)))
                try:
                    # NOTE(review): a seed of 0 is falsy and is treated as "no seed".
                    seed_val = int(ai_seed) if ai_seed and str(ai_seed).strip() else None
                except Exception:
                    seed_val = None

                progress(0.1, desc=f"Generating {ai_model}")
                bg_path = generate_ai_background_router(
                    width=w, height=h, prompt=str(ai_prompt or "professional office background").strip(),
                    model=str(ai_model or "SDXL"), steps=steps, guidance=guidance,
                    seed=seed_val, openai_key=openai_key, require_gpu=bool(force_gpu)
                )
                progress(1.0, desc="Background ready")
                if bg_path and os.path.exists(bg_path):
                    return bg_path, f"AI background generated with {ai_model}", gr.update(visible=True), bg_path
                else:
                    return None, "No output file", gr.update(visible=False), None
            except Exception as e:
                logger.error(f"Background generation error: {e}")
                return None, f"Background generation failed: {str(e)}", gr.update(visible=False), None

    def approve_background(bg_path):
        """Copy the generated background into BACKGROUND_DIR under a unique name.

        Returns (approved path or None, status text, hide-the-Approve-button update).
        The original extension is kept; ".jpg" is used when there is none.
        """
        try:
            if not bg_path or not (isinstance(bg_path, str) and os.path.exists(bg_path)):
                return None, "Generate a background first", gr.update(visible=False)
            ext = os.path.splitext(bg_path)[1].lower() or ".jpg"
            safe_name = f"approved_{int(time.time())}_{random.randint(1000,9999)}{ext}"
            dest = BACKGROUND_DIR / safe_name
            shutil.copy2(bg_path, dest)
            return str(dest), f"✅ Background approved → {dest.name}", gr.update(visible=False)
        except Exception as e:
            return None, f"⚠️ Approve failed: {e}", gr.update(visible=False)

    # Custom styling; `.process-button` and `.memory-info` are referenced below.
    css = """
    .gradio-container { font-size: 16px !important; }
    label { font-size: 18px !important; font-weight: 600 !important; color: #2d3748 !important; }
    .process-button { font-size: 20px !important; font-weight: 700 !important; padding: 16px 28px !important; }
    .memory-info { background: #f8fafc; border: 1px solid #e2e8f0; border-radius: 8px; padding: 12px; }
    """

    # Component creation order and `with` nesting determine the rendered layout.
    with gr.Blocks(title="Enhanced Video Background Replacement", theme=gr.themes.Soft(), css=css) as interface:
        gr.Markdown("# 🎬 Enhanced Video Background Replacement")
        gr.Markdown("_SAM2 + MatAnyone + AI Backgrounds — with strict tensor shapes & memory management_")

        # Static device/runtime banner, rendered once at build time.
        gr.HTML(f"""
        <div class='memory-info'>
        <strong>Device:</strong> {DEVICE}
        <strong>PyTorch:</strong> {'✅' if TORCH_AVAILABLE else '❌'}
        <strong>CUDA:</strong> {'✅' if CUDA_AVAILABLE else '❌'}
        </div>
        """)

        with gr.Row():
            # ---- Left column: inputs and controls ----
            with gr.Column(scale=1):
                video_input = gr.Video(label="Input Video")

                gr.Markdown("### Background")
                bg_method = gr.Radio(choices=["Upload Image", "Gradients", "AI Generated"],
                                     value="AI Generated", label="Background Method")

                # Shown only when bg_method == "Upload Image".
                with gr.Group(visible=False) as upload_group:
                    upload_img = gr.Image(label="Background Image", type="filepath")

                # Shown only when bg_method == "Gradients".
                with gr.Group(visible=False) as gradient_group:
                    gradient_choice = gr.Dropdown(label="Gradient Style",
                                                  choices=list(GRADIENT_PRESETS.keys()),
                                                  value="Slate")

                # Shown only when bg_method == "AI Generated" (the default).
                with gr.Group(visible=True) as ai_group:
                    prompt_suggestions = gr.Dropdown(label="💡 Prompt Inspiration",
                                                     choices=AI_PROMPT_SUGGESTIONS,
                                                     value="Custom (write your own)")
                    ai_prompt = gr.Textbox(label="Background Description",
                                           value="professional office background", lines=3)
                    ai_model = gr.Radio(["SDXL", "Playground v2.5", "SD 1.5 (fallback)", "OpenAI (gpt-image-1)"],
                                        value="SDXL", label="AI Model")
                    with gr.Accordion("Connect services (optional)", open=False):
                        openai_api_key = gr.Textbox(label="OpenAI API Key", type="password",
                                                    placeholder="sk-... (kept only in this session)")
                    with gr.Row():
                        ai_steps = gr.Slider(10, 50, value=30, step=1, label="Quality (steps)")
                        ai_guidance = gr.Slider(1.0, 15.0, value=7.0, step=0.1, label="Guidance")
                    ai_seed = gr.Number(label="Seed (optional)", precision=0)
                    force_gpu_ai = gr.Checkbox(value=True, label="Force GPU for AI background")
                    preload_btn = gr.Button("📦 Preload Models")
                    preload_status = gr.Textbox(label="Preload Status", lines=4)
                    generate_bg_btn = gr.Button("Generate AI Background", variant="primary")
                    ai_generated_bg = gr.Image(label="Generated Background", type="filepath")
                    approve_bg_btn = gr.Button("✅ Approve Background", visible=False)
                    # Per-session state: approved copy (set by approve) and latest
                    # generated file (set by generate).
                    approved_background_path = gr.State(value=None)
                    last_generated_bg = gr.State(value=None)
                    ai_status = gr.Textbox(label="Generation Status", lines=2)

                gr.Markdown("### Processing")
                with gr.Row():
                    trim_enabled = gr.Checkbox(label="Trim Video", value=False)
                    trim_seconds = gr.Number(label="Trim Duration (seconds)", value=5, precision=1)
                with gr.Row():
                    crf_value = gr.Slider(0, 30, value=18, step=1, label="Quality (CRF - lower=better)")
                    audio_enabled = gr.Checkbox(label="Preserve Audio", value=True)
                with gr.Row():
                    use_chunked = gr.Checkbox(label="Use Chunked Processing", value=False)

                # Slider ranges match process_video_main's clamps for x/y/scale.
                # The feather slider (0-15) is narrower than the 0-50 clamp.
                gr.Markdown("### Subject Placement")
                with gr.Row():
                    place_x = gr.Slider(0.0, 1.0, value=0.5, step=0.01, label="Horizontal")
                    place_y = gr.Slider(0.0, 1.0, value=0.75, step=0.01, label="Vertical")
                with gr.Row():
                    place_scale = gr.Slider(0.3, 2.0, value=1.0, step=0.01, label="Scale")
                    place_feather = gr.Slider(0, 15, value=3, step=1, label="Edge feather (px)")

                process_btn = gr.Button("🚀 Process Video", variant="primary", elem_classes=["process-button"])

                gr.Markdown("### System")
                with gr.Row():
                    diagnostics_btn = gr.Button("📊 System Diagnostics")
                    cleanup_btn = gr.Button("🧹 Memory Cleanup")
                diagnostics_output = gr.Textbox(label="System Status", lines=10)

            # ---- Right column: outputs ----
            with gr.Column(scale=1):
                output_video = gr.Video(label="Processed Video")
                download_file = gr.File(label="Download Processed Video")
                status_output = gr.Textbox(label="Processing Status", lines=20)

        def update_background_visibility(method):
            # Show exactly one background-option group, ordered (upload, gradient, ai).
            return (
                gr.update(visible=(method == "Upload Image")),
                gr.update(visible=(method == "Gradients")),
                gr.update(visible=(method == "AI Generated")),
            )

        def update_prompt_from_suggestion(suggestion):
            # "Custom" clears the prompt box; any other suggestion is copied into it.
            if suggestion == "Custom (write your own)":
                return gr.update(value="", placeholder="Describe the background you want...")
            return gr.update(value=suggestion)

        bg_method.change(
            update_background_visibility,
            inputs=[bg_method],
            outputs=[upload_group, gradient_group, ai_group]
        )
        prompt_suggestions.change(update_prompt_from_suggestion, inputs=[prompt_suggestions], outputs=[ai_prompt])

        preload_btn.click(preload,
                          inputs=[ai_model, openai_api_key, force_gpu_ai],
                          outputs=[preload_status],
                          show_progress=True
                          )

        generate_bg_btn.click(
            generate_background_safe,
            inputs=[video_input, ai_prompt, ai_steps, ai_guidance, ai_seed, ai_model, openai_api_key, force_gpu_ai],
            outputs=[ai_generated_bg, ai_status, approve_bg_btn, last_generated_bg],
            show_progress=True
        )
        approve_bg_btn.click(
            approve_background,
            inputs=[ai_generated_bg],
            outputs=[approved_background_path, ai_status, approve_bg_btn]
        )

        diagnostics_btn.click(diag, outputs=[diagnostics_output])
        cleanup_btn.click(cleanup, outputs=[diagnostics_output])

        def process_video(
            video_file,
            bg_method,
            upload_img,
            gradient_choice,
            approved_background_path,
            last_generated_bg,
            trim_enabled, trim_seconds, crf_value, audio_enabled,
            use_chunked,
            place_x, place_y, place_scale, place_feather,
            progress=gr.Progress(track_tqdm=True),
        ):
            """Resolve the background for the chosen method and run process_video_main.

            Background resolution:
            - Upload: the uploaded file path.
            - Gradients: the preset rendered at the first-frame size into a temp JPEG.
            - AI Generated: the approved path, else the last generated file.
            If nothing resolves, ``bg_path`` stays None and process_video_main
            uses its default gradient.

            Returns (video for the player, file for download, status text).
            """
            try:
                if not video_file:
                    return None, None, "Please upload a video file"
                video_path = video_file.name if hasattr(video_file, 'name') else str(video_file)

                bg_path = None
                try:
                    if bg_method == "Upload Image" and upload_img:
                        bg_path = upload_img if isinstance(upload_img, str) else getattr(upload_img, "name", None)
                    elif bg_method == "Gradients":
                        cap = cv2.VideoCapture(video_path)
                        ret, frame = cap.read(); cap.release()
                        if ret and frame is not None:
                            h, w = frame.shape[:2]
                            if gradient_choice in GRADIENT_PRESETS:
                                grad = _make_vertical_gradient(w, h, *GRADIENT_PRESETS[gradient_choice])
                                # delete=False: the file must outlive this call; it is
                                # registered with memory_manager as a temp file instead.
                                tmp_bg = tempfile.NamedTemporaryFile(suffix=".jpg", delete=False, dir=TEMP_DIR).name
                                cv2.imwrite(tmp_bg, grad)
                                memory_manager.register_temp_file(tmp_bg)
                                bg_path = tmp_bg
                    else:
                        # NOTE(review): an older approved background wins over a newer
                        # generated one that has not been approved — confirm intended.
                        if approved_background_path:
                            bg_path = approved_background_path
                        elif last_generated_bg and isinstance(last_generated_bg, str) and os.path.exists(last_generated_bg):
                            bg_path = last_generated_bg
                except Exception as e:
                    logger.error(f"Background setup error: {e}")
                    return None, None, f"Background setup failed: {str(e)}"

                # NOTE(review): if the trim Number field can be empty (None),
                # float(trim_seconds) raises here when trimming is enabled.
                result_path, status = process_video_main(
                    video_path=video_path,
                    background_path=bg_path,
                    trim_duration=float(trim_seconds) if (trim_enabled and float(trim_seconds) > 0) else None,
                    crf=int(crf_value),
                    preserve_audio_flag=bool(audio_enabled),
                    placement=dict(x=float(place_x), y=float(place_y), scale=float(place_scale), feather=int(place_feather)),
                    use_chunked_processing=bool(use_chunked),
                    progress=progress,
                )

                if result_path and os.path.exists(result_path):
                    return result_path, result_path, f"✅ Success\n\n{status}"
                else:
                    return None, None, f"❌ Failed\n\n{status or 'Unknown error'}"
            except Exception as e:
                tb = traceback.format_exc()
                return None, None, f"❌ Crash: {e}\n\n{tb}"

        # Input order must match process_video's positional parameters.
        process_btn.click(
            process_video,
            inputs=[
                video_input,
                bg_method,
                upload_img,
                gradient_choice,
                approved_background_path, last_generated_bg,
                trim_enabled, trim_seconds, crf_value, audio_enabled,
                use_chunked,
                place_x, place_y, place_scale, place_feather,
            ],
            outputs=[output_video, download_file, status_output],
            show_progress=True
        )

    return interface
|
|
| |
| |
| |
def main():
    """Launch the Gradio app on 0.0.0.0:7860; always clean temp files and memory on exit."""
    logger.info("Starting Enhanced Background Replacement")
    initial = memory_manager.get_memory_stats()
    logger.info(f"Initial memory: CPU {initial.cpu_memory_mb:.1f}MB, GPU {initial.gpu_memory_mb:.1f}MB")

    app = create_interface()
    # Bound the pending-request queue to 3 jobs.
    app.queue(max_size=3)

    launch_options = dict(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        inbrowser=False,
        show_error=True,
    )
    try:
        app.launch(**launch_options)
    finally:
        # Runs whether launch returns normally or is interrupted.
        logger.info("Shutting down - cleanup")
        memory_manager.cleanup_temp_files()
        memory_manager.aggressive_cleanup()


if __name__ == "__main__":
    main()