"""Sapiens2 human-matting Gradio Space. Image -> soft alpha matte + pre-multiplied foreground. The primary output is an interactive slider comparing the predicted foreground on green against a thresholded black/white alpha mask. """ import os import sys import tempfile sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) import cv2 import gradio as gr import numpy as np import spaces import torch import torch.nn.functional as F from huggingface_hub import hf_hub_download from PIL import Image from sapiens.dense.models import MattingEstimator, init_model # registers MattingEstimator _ = MattingEstimator # ----------------------------------------------------------------------------- # Config ASSETS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets") CONFIGS_DIR = os.path.join(ASSETS_DIR, "configs") MATTING_MODEL = { "repo": "facebook/sapiens2-matting-1b", "filename": "sapiens2_1b_matting.safetensors", "config": os.path.join( CONFIGS_DIR, "sapiens2_1b_matting_gss_p3m_metasim-1024x768.py" ), } GREEN_BACKGROUND_RGB = np.array([0.0, 177.0 / 255.0, 64.0 / 255.0], dtype=np.float32) # ----------------------------------------------------------------------------- # Model cache _matting_model = None _matting_model_device = None def _get_matting_model(): global _matting_model, _matting_model_device device = "cuda" if torch.cuda.is_available() else "cpu" if _matting_model is None or _matting_model_device != device: ckpt = hf_hub_download( repo_id=MATTING_MODEL["repo"], filename=MATTING_MODEL["filename"] ) _matting_model = init_model(MATTING_MODEL["config"], ckpt, device=device) _matting_model_device = device return _matting_model print("[startup] Sapiens2-1B matting app ready; model loads on first GPU request.") # ----------------------------------------------------------------------------- # Inference helpers def _estimate_matting(image_bgr: np.ndarray, model) -> tuple[np.ndarray, np.ndarray]: h0, w0 = image_bgr.shape[:2] data = model.pipeline(dict(img=image_bgr)) data = model.data_preprocessor(data) inputs = data["inputs"] with torch.no_grad(): outputs = model(inputs) # 1 x 4 x H x W: [pre-multiplied fgr RGB, alpha] outputs = F.interpolate( outputs, size=(h0, w0), mode="bilinear", align_corners=False, ) outputs = outputs.squeeze(0).float().cpu().numpy() fgr_rgb = outputs[:3].clip(0.0, 1.0).transpose(1, 2, 0) alpha = outputs[3].clip(0.0, 1.0) return fgr_rgb, alpha def _green_background(height: int, width: int) -> np.ndarray: return np.broadcast_to(GREEN_BACKGROUND_RGB, (height, width, 3)) def _composite( fgr_rgb: np.ndarray, alpha: np.ndarray, background: np.ndarray ) -> np.ndarray: return (fgr_rgb + (1.0 - alpha[..., None]) * background).clip(0.0, 1.0) def _binary_alpha_rgb(alpha: np.ndarray) -> np.ndarray: mask = (alpha >= 0.5).astype(np.float32) return np.repeat(mask[..., None], 3, axis=2) def _straight_rgba(fgr_rgb: np.ndarray, alpha: np.ndarray) -> np.ndarray: straight = np.zeros_like(fgr_rgb) valid = alpha > 1e-4 straight[valid] = (fgr_rgb[valid] / alpha[valid][:, None]).clip(0.0, 1.0) rgba = np.dstack([straight, alpha]) return (rgba * 255.0).round().clip(0, 255).astype(np.uint8) def _to_pil_rgb(image: np.ndarray) -> Image.Image: return Image.fromarray((image * 255.0).round().clip(0, 255).astype(np.uint8)) def _save_png(image: Image.Image) -> str: path = tempfile.NamedTemporaryFile(delete=False, suffix=".png").name image.save(path) return path def _save_alpha(alpha: np.ndarray) -> str: path = tempfile.NamedTemporaryFile(delete=False, suffix=".npy").name np.save(path, alpha.astype(np.float32)) return path # ----------------------------------------------------------------------------- # Gradio handler @spaces.GPU(duration=120) def predict(image: Image.Image): if image is None: return None, None, None, None, None image_pil = image.convert("RGB") image_rgb = np.array(image_pil) image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR) height, width = image_rgb.shape[:2] model = _get_matting_model() fgr_rgb, alpha = _estimate_matting(image_bgr, model) bg = _green_background(height, width) composite = _composite(fgr_rgb, alpha, bg) composite_pil = _to_pil_rgb(composite) alpha_pil = _to_pil_rgb(_binary_alpha_rgb(alpha)) rgba_pil = Image.fromarray(_straight_rgba(fgr_rgb, alpha)) alpha_path = _save_alpha(alpha) rgba_path = _save_png(rgba_pil) return (composite_pil, alpha_pil), alpha_pil, rgba_pil, alpha_path, rgba_path # ----------------------------------------------------------------------------- # UI EXAMPLES = sorted( os.path.join(ASSETS_DIR, "images", n) for n in os.listdir(os.path.join(ASSETS_DIR, "images")) if n.lower().endswith((".jpg", ".jpeg", ".png")) ) DEFAULT_EXAMPLE = ( EXAMPLES[2] if len(EXAMPLES) >= 3 else (EXAMPLES[0] if EXAMPLES else None) ) CUSTOM_CSS = """ :root, body, .gradio-container, button, input, select, textarea, .gradio-container *:not(code):not(pre) { font-family: "Helvetica Neue", Helvetica, Arial, sans-serif !important; -webkit-font-smoothing: antialiased; -moz-osx-font-smoothing: grayscale; } #title { text-align: center; font-size: 44px; font-weight: 700; letter-spacing: 0; margin: 28px 0 4px; background: linear-gradient(90deg, #1d4ed8 0%, #6d28d9 50%, #be185d 100%); -webkit-background-clip: text; -webkit-text-fill-color: transparent; background-clip: text; } #subtitle { text-align: center; font-size: 12px; color: #64748b; letter-spacing: 0; margin: 0 0 14px; text-transform: uppercase; font-weight: 600; } #tagline { text-align: center; font-size: 15px; color: #475569; max-width: 720px; margin: 4px auto 22px; line-height: 1.55; font-weight: 400; } #badges { display: flex; justify-content: center; flex-wrap: wrap; gap: 8px; margin: 0 0 32px; } .pill { display: inline-flex; align-items: center; gap: 6px; padding: 7px 14px; border-radius: 999px; background: #f8fafc; color: #0f172a !important; font-size: 13px; font-weight: 550; letter-spacing: 0; text-decoration: none !important; border: 1px solid #e2e8f0; transition: background 150ms ease, transform 150ms ease, border-color 150ms ease; } .pill:hover { background: #0f172a; color: #f8fafc !important; border-color: #0f172a; transform: translateY(-1px); } .pill svg { width: 14px; height: 14px; } #matting-slider .image-slider { border-radius: 8px; overflow: hidden; } """ HEADER_HTML = """