Spaces:
Running on Zero
Running on Zero
| """Sapiens2 human-matting Gradio Space. | |
| Image -> soft alpha matte + pre-multiplied foreground. The primary output is | |
| an interactive slider comparing the predicted foreground on green against a | |
| thresholded black/white alpha mask. | |
| """ | |
| import os | |
| import sys | |
| import tempfile | |
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) | |
| import cv2 | |
| import gradio as gr | |
| import numpy as np | |
| import spaces | |
| import torch | |
| import torch.nn.functional as F | |
| from huggingface_hub import hf_hub_download | |
| from PIL import Image | |
| from sapiens.dense.models import MattingEstimator, init_model # registers MattingEstimator | |
| _ = MattingEstimator | |
| # ----------------------------------------------------------------------------- | |
| # Config | |
# Bundled resources live in an "assets" folder next to this file.
ASSETS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
CONFIGS_DIR = os.path.join(ASSETS_DIR, "configs")

# Where the checkpoint is fetched from (Hub repo + file name) and which local
# config file is handed to init_model() together with it.
MATTING_MODEL = dict(
    repo="facebook/sapiens2-matting-1b",
    filename="sapiens2_1b_matting.safetensors",
    config=os.path.join(
        CONFIGS_DIR, "sapiens2_1b_matting_gss_p3m_metasim-1024x768.py"
    ),
)

# Backdrop colour used for the composite preview: RGB (0, 177, 64) scaled to
# the 0..1 range, stored as float32.
GREEN_BACKGROUND_RGB = np.asarray(
    [0.0, 177.0 / 255.0, 64.0 / 255.0], dtype=np.float32
)
| # ----------------------------------------------------------------------------- | |
| # Model cache | |
# Process-wide cache: the loaded model and the device it was built for.
_matting_model = None
_matting_model_device = None


def _get_matting_model():
    """Return the matting model, loading it on first use.

    The target device is "cuda" when ``torch.cuda.is_available()`` reports a
    GPU and "cpu" otherwise. The cached instance is reused as long as it was
    built for that same device; otherwise the checkpoint is fetched with
    ``hf_hub_download`` and a new model is created with ``init_model``.
    """
    global _matting_model, _matting_model_device

    target_device = "cuda" if torch.cuda.is_available() else "cpu"

    # Fast path: a model already exists for the device we would pick now.
    if _matting_model is not None and _matting_model_device == target_device:
        return _matting_model

    checkpoint_path = hf_hub_download(
        repo_id=MATTING_MODEL["repo"],
        filename=MATTING_MODEL["filename"],
    )
    _matting_model = init_model(
        MATTING_MODEL["config"], checkpoint_path, device=target_device
    )
    _matting_model_device = target_device
    return _matting_model
# Startup log line. Nothing is loaded here: the model is created lazily by
# _get_matting_model() the first time it is called.
print("[startup] Sapiens2-1B matting app ready; model loads on first GPU request.")
| # ----------------------------------------------------------------------------- | |
| # Inference helpers | |
def _estimate_matting(image_bgr: np.ndarray, model) -> tuple[np.ndarray, np.ndarray]:
    """Run the matting model on one image and return foreground and alpha.

    Args:
        image_bgr: Input image; only its first two dimensions (height, width)
            are read here, the array itself is handed to ``model.pipeline``.
        model: Object exposing ``pipeline``, ``data_preprocessor`` and a
            forward call, as returned by ``init_model``.

    Returns:
        ``(fgr_rgb, alpha)`` at the input resolution, both clipped to [0, 1]:
        ``fgr_rgb`` is H x W x 3 (channels 0-2 of the network output, i.e. the
        pre-multiplied foreground) and ``alpha`` is H x W (channel 3).
    """
    target_size = tuple(image_bgr.shape[:2])

    sample = model.data_preprocessor(model.pipeline(dict(img=image_bgr)))
    batch = sample["inputs"]

    with torch.no_grad():
        # 1 x 4 x H x W: [pre-multiplied fgr RGB, alpha]
        prediction = model(batch)
        # Bring the prediction back to the resolution of the input image.
        prediction = F.interpolate(
            prediction,
            size=target_size,
            mode="bilinear",
            align_corners=False,
        )

    chw = prediction.squeeze(0).float().cpu().numpy()
    fgr_rgb = np.clip(chw[:3], 0.0, 1.0).transpose(1, 2, 0)  # CHW -> HWC
    alpha = np.clip(chw[3], 0.0, 1.0)
    return fgr_rgb, alpha
def _green_background(height: int, width: int) -> np.ndarray:
    """Return a ``height`` x ``width`` x 3 image filled with the green backdrop.

    The result is produced by ``np.broadcast_to`` from
    ``GREEN_BACKGROUND_RGB``, so no per-pixel copy is made here.
    """
    canvas_shape = (height, width, 3)
    return np.broadcast_to(GREEN_BACKGROUND_RGB, canvas_shape)
def _composite(
    fgr_rgb: np.ndarray, alpha: np.ndarray, background: np.ndarray
) -> np.ndarray:
    """Lay ``fgr_rgb`` over ``background`` using ``alpha``.

    ``fgr_rgb`` is added as-is (it is not multiplied by alpha here), the
    background is weighted by ``1 - alpha``, and the sum is clipped to [0, 1].

    Args:
        fgr_rgb: H x W x 3 foreground colours.
        alpha: H x W matte; a trailing channel axis is added for broadcasting.
        background: Array broadcastable against ``fgr_rgb``.
    """
    background_weight = 1.0 - alpha[..., None]
    blended = fgr_rgb + background_weight * background
    return np.clip(blended, 0.0, 1.0)
def _binary_alpha_rgb(alpha: np.ndarray) -> np.ndarray:
    """Threshold ``alpha`` at 0.5 and return it as a 3-channel float32 image.

    Pixels with ``alpha >= 0.5`` become 1.0 and all others 0.0; the single
    channel is then repeated three times along a new last axis.
    """
    is_foreground = alpha >= 0.5
    single_channel = np.expand_dims(is_foreground.astype(np.float32), axis=-1)
    return np.repeat(single_channel, 3, axis=2)
def _straight_rgba(fgr_rgb: np.ndarray, alpha: np.ndarray) -> np.ndarray:
    """Convert a pre-multiplied foreground plus alpha into 8-bit straight RGBA.

    Colour is divided by alpha wherever ``alpha > 1e-4`` and clipped to
    [0, 1]; everywhere else the colour is left at zero so no division by a
    (near-)zero alpha takes place. Alpha is appended as the fourth channel
    and the result is scaled by 255, rounded and returned as ``uint8``.

    Args:
        fgr_rgb: H x W x 3 pre-multiplied colours.
        alpha: H x W matte.

    Returns:
        H x W x 4 ``uint8`` array (R, G, B, A).
    """
    divisible = alpha > 1e-4
    unmultiplied = np.divide(
        fgr_rgb,
        alpha[..., None],
        out=np.zeros_like(fgr_rgb),
        where=divisible[..., None],
    )
    unmultiplied = np.clip(unmultiplied, 0.0, 1.0)

    stacked = np.dstack([unmultiplied, alpha])
    return np.clip(np.round(stacked * 255.0), 0, 255).astype(np.uint8)
def _to_pil_rgb(image: np.ndarray) -> Image.Image:
    """Turn a float image in the 0..1 range into a PIL image.

    Values are scaled by 255, rounded, clipped to [0, 255] and cast to
    ``uint8`` before being passed to ``Image.fromarray``.
    """
    scaled = np.round(image * 255.0)
    pixels = np.clip(scaled, 0, 255).astype(np.uint8)
    return Image.fromarray(pixels)
def _save_png(image: Image.Image) -> str:
    """Write ``image`` to a fresh temporary ``.png`` file and return its path.

    The file is created with ``delete=False`` so that it outlives this call
    and can be served as a download; nothing in this function removes it.

    Args:
        image: PIL image to save.

    Returns:
        Filesystem path of the written PNG.
    """
    # Reserve a unique name, then close our handle right away so it is not
    # left for the garbage collector and cannot block the save below.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as handle:
        path = handle.name
    image.save(path)
    return path
def _save_alpha(alpha: np.ndarray) -> str:
    """Write ``alpha`` as float32 to a fresh temporary ``.npy`` file.

    The file is created with ``delete=False`` so that it outlives this call
    and can be served as a download; nothing in this function removes it.

    Args:
        alpha: Matte to store; it is cast to ``float32`` before saving.

    Returns:
        Filesystem path of the written ``.npy`` file.
    """
    # Reserve a unique name, then close our handle right away so it is not
    # left for the garbage collector and cannot block the save below.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".npy") as handle:
        path = handle.name
    np.save(path, alpha.astype(np.float32))
    return path
| # ----------------------------------------------------------------------------- | |
| # Gradio handler | |
@spaces.GPU
def predict(image: Image.Image):
    """Run matting on one image and build every output shown in the UI.

    The ``spaces.GPU`` decorator requests a GPU for the duration of the call
    when the app runs on a ZeroGPU Space; elsewhere it has no effect.

    Args:
        image: PIL image from the input component, or ``None`` when the
            component is empty.

    Returns:
        A 5-tuple in the order the Run button's ``outputs`` expect:
        ``((composite, binary_mask), binary_mask, rgba, alpha_path,
        rgba_path)`` where ``composite`` is the foreground over green,
        ``binary_mask`` is alpha thresholded at 0.5, ``rgba`` is the straight
        (un-premultiplied) cut-out, ``alpha_path`` is a ``.npy`` file with the
        raw float32 alpha and ``rgba_path`` is the cut-out saved as PNG.
        All five entries are ``None`` when ``image`` is ``None``.
    """
    if image is None:
        return None, None, None, None, None

    image_rgb = np.array(image.convert("RGB"))
    # The model pipeline is fed the image in BGR channel order.
    image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
    height, width = image_rgb.shape[:2]

    model = _get_matting_model()
    fgr_rgb, alpha = _estimate_matting(image_bgr, model)

    # Left side of the slider: predicted foreground over the green backdrop.
    background = _green_background(height, width)
    composite_pil = _to_pil_rgb(_composite(fgr_rgb, alpha, background))
    # Right side of the slider (and the accordion preview): black/white mask.
    alpha_pil = _to_pil_rgb(_binary_alpha_rgb(alpha))
    rgba_pil = Image.fromarray(_straight_rgba(fgr_rgb, alpha))

    # Downloadable artefacts.
    alpha_path = _save_alpha(alpha)
    rgba_path = _save_png(rgba_pil)

    return (composite_pil, alpha_pil), alpha_pil, rgba_pil, alpha_path, rgba_path
| # ----------------------------------------------------------------------------- | |
| # UI | |
# Every .jpg/.jpeg/.png file (case-insensitive) in assets/images, as full
# paths in sorted order.
EXAMPLES = sorted(
    [
        os.path.join(ASSETS_DIR, "images", file_name)
        for file_name in os.listdir(os.path.join(ASSETS_DIR, "images"))
        if file_name.lower().endswith((".jpg", ".jpeg", ".png"))
    ]
)

# Image preloaded into the input component: the third example when at least
# three exist, otherwise the first one, otherwise nothing.
if len(EXAMPLES) >= 3:
    DEFAULT_EXAMPLE = EXAMPLES[2]
elif EXAMPLES:
    DEFAULT_EXAMPLE = EXAMPLES[0]
else:
    DEFAULT_EXAMPLE = None
# Stylesheet passed to gr.Blocks(css=...). It sets the global font stack,
# styles the header elements emitted by HEADER_HTML (#title, #subtitle,
# #tagline, #badges and the .pill links) and rounds the corners of the
# comparison slider (#matting-slider).
CUSTOM_CSS = """
:root, body, .gradio-container, button, input, select, textarea,
.gradio-container *:not(code):not(pre) {
font-family: "Helvetica Neue", Helvetica, Arial, sans-serif !important;
-webkit-font-smoothing: antialiased;
-moz-osx-font-smoothing: grayscale;
}
#title {
text-align: center;
font-size: 44px;
font-weight: 700;
letter-spacing: 0;
margin: 28px 0 4px;
background: linear-gradient(90deg, #1d4ed8 0%, #6d28d9 50%, #be185d 100%);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
background-clip: text;
}
#subtitle {
text-align: center;
font-size: 12px;
color: #64748b;
letter-spacing: 0;
margin: 0 0 14px;
text-transform: uppercase;
font-weight: 600;
}
#tagline {
text-align: center;
font-size: 15px;
color: #475569;
max-width: 720px;
margin: 4px auto 22px;
line-height: 1.55;
font-weight: 400;
}
#badges {
display: flex;
justify-content: center;
flex-wrap: wrap;
gap: 8px;
margin: 0 0 32px;
}
.pill {
display: inline-flex;
align-items: center;
gap: 6px;
padding: 7px 14px;
border-radius: 999px;
background: #f8fafc;
color: #0f172a !important;
font-size: 13px;
font-weight: 550;
letter-spacing: 0;
text-decoration: none !important;
border: 1px solid #e2e8f0;
transition: background 150ms ease, transform 150ms ease, border-color 150ms ease;
}
.pill:hover {
background: #0f172a;
color: #f8fafc !important;
border-color: #0f172a;
transform: translateY(-1px);
}
.pill svg { width: 14px; height: 14px; }
#matting-slider .image-slider { border-radius: 8px; overflow: hidden; }
"""
# Static page header rendered through gr.HTML: title, subtitle, tagline and a
# row of pill-shaped links (Code, Model, Paper, Project). The element ids and
# the "pill" class are the hooks styled by CUSTOM_CSS.
HEADER_HTML = """
<div id="title">Sapiens2: Matting</div>
<div id="subtitle">ICLR 2026</div>
<div id="tagline">Soft human alpha matting and foreground extraction from a single image.</div>
<div id="badges">
<a class="pill" href="https://github.com/facebookresearch/sapiens2" target="_blank" rel="noopener">
<svg viewBox="0 0 24 24" fill="currentColor"><path d="M12 .3a12 12 0 0 0-3.8 23.4c.6.1.8-.3.8-.6v-2c-3.3.7-4-1.6-4-1.6-.6-1.4-1.4-1.8-1.4-1.8-1.1-.7.1-.7.1-.7 1.3.1 2 1.3 2 1.3 1.1 1.9 3 1.4 3.7 1 .1-.8.4-1.4.8-1.7-2.7-.3-5.5-1.3-5.5-5.9 0-1.3.5-2.4 1.3-3.2-.1-.4-.6-1.6.1-3.2 0 0 1-.3 3.3 1.2a11.5 11.5 0 0 1 6 0c2.3-1.5 3.3-1.2 3.3-1.2.7 1.6.2 2.8.1 3.2.8.8 1.3 1.9 1.3 3.2 0 4.6-2.8 5.6-5.5 5.9.4.4.8 1.1.8 2.2v3.3c0 .3.2.7.8.6A12 12 0 0 0 12 .3"/></svg>
Code
</a>
<a class="pill" href="https://huggingface.co/facebook/sapiens2-matting-1b" target="_blank" rel="noopener">
🤗 Model
</a>
<a class="pill" href="https://arxiv.org/pdf/2604.21681" target="_blank" rel="noopener">
<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M14 2H6a2 2 0 0 0-2 2v16a2 2 0 0 0 2 2h12a2 2 0 0 0 2-2V8z"/><polyline points="14 2 14 8 20 8"/><line x1="9" y1="13" x2="15" y2="13"/><line x1="9" y1="17" x2="15" y2="17"/></svg>
Paper
</a>
<a class="pill" href="https://rawalkhirodkar.github.io/sapiens2" target="_blank" rel="noopener">
<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><circle cx="12" cy="12" r="10"/><line x1="2" y1="12" x2="22" y2="12"/><path d="M12 2a15.3 15.3 0 0 1 4 10 15.3 15.3 0 0 1-4 10 15.3 15.3 0 0 1-4-10 15.3 15.3 0 0 1 4-10z"/></svg>
Project
</a>
</div>
"""
# Page layout: header, input image beside the comparison slider, Run button,
# example gallery, and a collapsed accordion holding the secondary outputs.
# NOTE(review): the nesting below was reconstructed from a listing without
# indentation -- confirm it matches the intended layout.
with gr.Blocks(title="Sapiens2 Matting", theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
    gr.HTML(HEADER_HTML)
    with gr.Row(equal_height=True):
        # Input image, preloaded with DEFAULT_EXAMPLE (may be None).
        inp = gr.Image(
            label="Input",
            type="pil",
            height=640,
            value=DEFAULT_EXAMPLE,
        )
        # Receives the (composite, binary mask) pair returned by predict().
        out_slider = gr.ImageSlider(
            label="Output",
            type="pil",
            height=640,
            max_height=640,
            slider_position=50,
            elem_id="matting-slider",
        )
    run = gr.Button("Run", variant="primary", size="lg")
    gr.Examples(examples=EXAMPLES, inputs=inp, examples_per_page=16)
    with gr.Accordion("Alpha + Foreground", open=False):
        with gr.Row(equal_height=True):
            alpha_img = gr.Image(label="Binary Alpha Mask", type="pil", height=360)
            rgba_img = gr.Image(label="Foreground PNG", type="pil", height=360)
        with gr.Row():
            alpha_file = gr.File(label="Raw alpha (.npy float32)")
            rgba_file = gr.File(label="Foreground with transparency (.png)")
    # The outputs list is positional: it must stay in the same order as the
    # 5-tuple returned by predict().
    run.click(
        predict,
        inputs=[inp],
        outputs=[out_slider, alpha_img, rgba_img, alpha_file, rgba_file],
    )
# Script entry point: only runs when this file is executed directly.
if __name__ == "__main__":
    # Allow TF32 for CUDA matmul and cuDNN when a CUDA device is available.
    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
    # Start the Gradio server without creating a public share link.
    demo.launch(share=False)