| """Sapiens2 body-part segmentation Gradio Space. | |
| Image → 29-class semantic segmentation. Renders an AnnotatedImage so the user | |
| can hover over each predicted body part to highlight it with its class name. | |
| """ | |
| import sys | |
| import os | |
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) | |
| import tempfile | |
| from typing import List, Tuple | |
| import cv2 | |
| import gradio as gr | |
| import numpy as np | |
| import spaces | |
| import torch | |
| import torch.nn.functional as F | |
| from PIL import Image | |
| from huggingface_hub import hf_hub_download | |
| from sapiens.dense.models import SegEstimator, init_model # registers SegEstimator | |
| from sapiens.dense.src.datasets.seg.seg_utils import DOME_CLASSES_29 | |
| _ = SegEstimator | |
# -----------------------------------------------------------------------------
# Config
ASSETS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
CONFIGS_DIR = os.path.join(ASSETS_DIR, "configs")

SEG_MODELS = {
    "0.4B": {
        "repo": "facebook/sapiens2-seg-0.4b",
        "filename": "sapiens2_0.4b_seg.safetensors",
        "config": os.path.join(CONFIGS_DIR, "sapiens2_0.4b_seg_shutterstock_goliath-1024x768.py"),
    },
    "0.8B": {
        "repo": "facebook/sapiens2-seg-0.8b",
        "filename": "sapiens2_0.8b_seg.safetensors",
        "config": os.path.join(CONFIGS_DIR, "sapiens2_0.8b_seg_shutterstock_goliath-1024x768.py"),
    },
    "1B": {
        "repo": "facebook/sapiens2-seg-1b",
        "filename": "sapiens2_1b_seg.safetensors",
        "config": os.path.join(CONFIGS_DIR, "sapiens2_1b_seg_shutterstock_goliath-1024x768.py"),
    },
    "5B": {
        "repo": "facebook/sapiens2-seg-5b",
        "filename": "sapiens2_5b_seg.safetensors",
        "config": os.path.join(CONFIGS_DIR, "sapiens2_5b_seg_shutterstock_goliath-1024x768.py"),
    },
}

DEFAULT_SIZE = "1B"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# All 29 named classes (Background is class 0).
_CLASS_LABELS = {cid: meta["name"].replace("_", " ") for cid, meta in DOME_CLASSES_29.items()}
_CLASS_COLORS_RGB = {cid: meta["color"] for cid, meta in DOME_CLASSES_29.items()}
_CLASS_COLORS_HEX = {
    _CLASS_LABELS[cid]: "#{:02x}{:02x}{:02x}".format(*meta["color"])
    for cid, meta in DOME_CLASSES_29.items()
}
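
# For reference: _CLASS_LABELS[0] == "Background", and _CLASS_COLORS_HEX maps each
# display name to its palette color as an "#rrggbb" string for gr.AnnotatedImage.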
# -----------------------------------------------------------------------------
# Model cache
_seg_model_cache: dict = {}


def _get_seg_model(size: str):
    """Download (once) and cache the checkpoint for the requested model size."""
    if size not in _seg_model_cache:
        spec = SEG_MODELS[size]
        ckpt = hf_hub_download(repo_id=spec["repo"], filename=spec["filename"])
        _seg_model_cache[size] = init_model(spec["config"], ckpt, device=DEVICE)
    return _seg_model_cache[size]


print("[startup] pre-loading all seg sizes ...")
for _size in SEG_MODELS:
    _get_seg_model(_size)
print("[startup] ready.")
# -----------------------------------------------------------------------------
# Inference (mirrors sapiens/dense/tools/vis/vis_seg.py)
def _segment(image_bgr: np.ndarray, model) -> np.ndarray:
    h0, w0 = image_bgr.shape[:2]
    data = model.pipeline(dict(img=image_bgr))  # resize + pad
    data = model.data_preprocessor(data)        # normalize + batch
    inputs = data["inputs"]                     # already (B, 3, H, W)
    with torch.no_grad():
        logits = model(inputs)                  # (1, 29, H, W)
    # Upsample logits back to the original resolution before taking the argmax.
    logits = F.interpolate(logits, size=(h0, w0), mode="bilinear", align_corners=False)
    return logits.argmax(dim=1).squeeze(0).cpu().numpy().astype(np.int32)  # (H, W)
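
# Standalone usage sketch (hypothetical image path; same call order as `predict`):
#   img_bgr = cv2.imread("person.jpg")                         # BGR, any resolution
#   labels = _segment(img_bgr, _get_seg_model(DEFAULT_SIZE))   # (H, W) ids in [0, 28]
#   present = {_CLASS_LABELS[int(i)] for i in np.unique(labels)}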
def _label_map_to_annotations(label_map: np.ndarray) -> List[Tuple[np.ndarray, str]]:
    """Convert an (H, W) class-id map into AnnotatedImage's [(bool_mask, label), ...] list.

    Only the classes actually present in the prediction are emitted, so the
    legend matches what is visible in the image.
    """
    annotations: List[Tuple[np.ndarray, str]] = []
    for cid in np.unique(label_map):
        cid = int(cid)
        if cid not in _CLASS_LABELS:
            continue
        annotations.append((label_map == cid, _CLASS_LABELS[cid]))
    return annotations
def _label_map_to_overlay(image_bgr: np.ndarray, label_map: np.ndarray, opacity: float) -> np.ndarray:
    """Static color overlay (BGR): colors from the DOME palette, alpha-blended."""
    palette = np.zeros((256, 3), dtype=np.uint8)
    for cid, rgb in _CLASS_COLORS_RGB.items():
        palette[cid] = rgb[::-1]  # RGB → BGR for cv2
    color_mask = palette[label_map]
    return cv2.addWeighted(image_bgr, 1.0 - opacity, color_mask, opacity, 0)
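
# The palette has 256 rows so `palette[label_map]` colorizes the whole map in a
# single vectorized lookup; any id without a named class simply stays black.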
# -----------------------------------------------------------------------------
# Gradio handler
@spaces.GPU  # ZeroGPU: attach a GPU for the duration of each call
def predict(image: Image.Image, size: str):
    if image is None:
        return None, None, None
    image_pil = image.convert("RGB")
    image_rgb = np.array(image_pil)
    image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)

    model = _get_seg_model(size)
    label_map = _segment(image_bgr, model)  # (H, W)

    annotations = _label_map_to_annotations(label_map)
    annotated = (image_pil, annotations)

    overlay_bgr = _label_map_to_overlay(image_bgr, label_map, 0.5)
    overlay_rgb = cv2.cvtColor(overlay_bgr, cv2.COLOR_BGR2RGB)

    npy_path = tempfile.NamedTemporaryFile(delete=False, suffix=".npy").name
    np.save(npy_path, label_map.astype(np.uint8))
    return annotated, Image.fromarray(overlay_rgb), npy_path
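
# Note: gr.AnnotatedImage consumes `(base_image, [(mask, label), ...])`, where each
# mask is an image-sized array (bounding-box tuples are also accepted). The
# per-class boolean masks are what the hover CSS below targets.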
# -----------------------------------------------------------------------------
# UI
EXAMPLES = sorted(
    os.path.join(ASSETS_DIR, "images", n)
    for n in os.listdir(os.path.join(ASSETS_DIR, "images"))
    if n.lower().endswith((".jpg", ".jpeg", ".png"))
)
| CUSTOM_CSS = """ | |
| :root, body, .gradio-container, button, input, select, textarea, | |
| .gradio-container *:not(code):not(pre) { | |
| font-family: "Helvetica Neue", Helvetica, Arial, sans-serif !important; | |
| -webkit-font-smoothing: antialiased; | |
| -moz-osx-font-smoothing: grayscale; | |
| } | |
| #title { text-align: center; font-size: 44px; font-weight: 700; | |
| letter-spacing: -0.01em; margin: 28px 0 4px; | |
| background: linear-gradient(90deg, #1d4ed8 0%, #6d28d9 50%, #be185d 100%); | |
| -webkit-background-clip: text; -webkit-text-fill-color: transparent; | |
| background-clip: text; } | |
| #subtitle { text-align: center; font-size: 12px; color: #64748b; | |
| letter-spacing: 0.18em; margin: 0 0 14px; text-transform: uppercase; | |
| font-weight: 500; } | |
| #tagline { text-align: center; font-size: 15px; color: #475569; | |
| max-width: 680px; margin: 4px auto 22px; line-height: 1.55; | |
| font-weight: 400; } | |
| #badges { display: flex; justify-content: center; flex-wrap: wrap; | |
| gap: 8px; margin: 0 0 32px; } | |
| .pill { display: inline-flex; align-items: center; gap: 6px; | |
| padding: 7px 14px; border-radius: 999px; | |
| background: #f1f5f9; color: #0f172a !important; | |
| font-size: 13px; font-weight: 500; letter-spacing: 0.01em; | |
| text-decoration: none !important; border: 1px solid #e2e8f0; | |
| transition: background 150ms ease, transform 150ms ease, border-color 150ms ease; } | |
| .pill:hover { background: #0f172a; color: #f8fafc !important; | |
| border-color: #0f172a; transform: translateY(-1px); } | |
| .pill svg { width: 14px; height: 14px; } | |
| /* AnnotatedImage hover behavior: | |
| - Default: every mask at 55% (RGB shows through) | |
| - Hover legend item OR mask region → that mask pops to 75% with a coloured glow, | |
| every OTHER mask fades down to 10% (dim, not invisible). | |
| Gradio toggles `.active` on the hovered mask and `.inactive` on the rest, so we | |
| just over-style those. */ | |
| #seg-out .mask { opacity: 0.55 !important; | |
| transition: opacity 200ms ease, filter 200ms ease; } | |
| #seg-out .mask.active { opacity: 0.78 !important; | |
| filter: brightness(1.18) | |
| drop-shadow(0 0 10px rgba(255,255,255,0.55)) | |
| drop-shadow(0 4px 16px rgba(0,0,0,0.40)) !important; } | |
| #seg-out .mask.inactive { opacity: 0.10 !important; } | |
| #seg-out .legend-item { cursor: pointer; } | |
| /* Legend: vertical column on the right of the image instead of horizontal below. */ | |
| #seg-out .container { flex-direction: row !important; align-items: stretch !important; gap: 12px; } | |
| #seg-out .image-container { flex: 1 1 auto; min-width: 0; } | |
| #seg-out .legend { | |
| flex: 0 0 180px; flex-direction: column !important; flex-wrap: nowrap !important; | |
| align-items: stretch; justify-content: flex-start; | |
| gap: 4px; padding: 8px 4px; | |
| max-height: 640px; overflow-y: auto; | |
| border-left: 1px solid var(--border-color-primary, #e2e8f0); | |
| } | |
| #seg-out .legend-item { font-size: 12px; font-weight: 500; | |
| padding: 4px 10px; width: 100%; cursor: pointer; | |
| border-radius: 6px; | |
| transition: transform 140ms ease, background 140ms ease; } | |
| #seg-out .legend-item:hover { transform: translateX(2px); } | |
| """ | |
| HEADER_HTML = """ | |
| <div id="title">Sapiens2: Segmentation</div> | |
| <div id="subtitle">ICLR 2026</div> | |
| <div id="badges"> | |
| <a class="pill" href="https://github.com/facebookresearch/sapiens2" target="_blank" rel="noopener"> | |
| <svg viewBox="0 0 24 24" fill="currentColor"><path d="M12 .3a12 12 0 0 0-3.8 23.4c.6.1.8-.3.8-.6v-2c-3.3.7-4-1.6-4-1.6-.6-1.4-1.4-1.8-1.4-1.8-1.1-.7.1-.7.1-.7 1.3.1 2 1.3 2 1.3 1.1 1.9 3 1.4 3.7 1 .1-.8.4-1.4.8-1.7-2.7-.3-5.5-1.3-5.5-5.9 0-1.3.5-2.4 1.3-3.2-.1-.4-.6-1.6.1-3.2 0 0 1-.3 3.3 1.2a11.5 11.5 0 0 1 6 0c2.3-1.5 3.3-1.2 3.3-1.2.7 1.6.2 2.8.1 3.2.8.8 1.3 1.9 1.3 3.2 0 4.6-2.8 5.6-5.5 5.9.4.4.8 1.1.8 2.2v3.3c0 .3.2.7.8.6A12 12 0 0 0 12 .3"/></svg> | |
| Code | |
| </a> | |
| <a class="pill" href="https://huggingface.co/facebook/sapiens2" target="_blank" rel="noopener"> | |
| 🤗 Models | |
| </a> | |
| <a class="pill" href="https://arxiv.org/pdf/2604.21681" target="_blank" rel="noopener"> | |
| <svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M14 2H6a2 2 0 0 0-2 2v16a2 2 0 0 0 2 2h12a2 2 0 0 0 2-2V8z"/><polyline points="14 2 14 8 20 8"/><line x1="9" y1="13" x2="15" y2="13"/><line x1="9" y1="17" x2="15" y2="17"/></svg> | |
| Paper | |
| </a> | |
| <a class="pill" href="https://rawalkhirodkar.github.io/sapiens2" target="_blank" rel="noopener"> | |
| <svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><circle cx="12" cy="12" r="10"/><line x1="2" y1="12" x2="22" y2="12"/><path d="M12 2a15.3 15.3 0 0 1 4 10 15.3 15.3 0 0 1-4 10 15.3 15.3 0 0 1-4-10 15.3 15.3 0 0 1 4-10z"/></svg> | |
| Project | |
| </a> | |
| </div> | |
| """ | |
with gr.Blocks(title="Sapiens2 Seg", theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
    gr.HTML(HEADER_HTML)
    with gr.Row(equal_height=True):
        inp = gr.Image(label="Input", type="pil", height=640)
        out_annot = gr.AnnotatedImage(
            label="Outputs (29 classes)",
            color_map=_CLASS_COLORS_HEX,
            height=640,
            show_legend=True,
            elem_id="seg-out",
        )
    with gr.Row():
        size = gr.Radio(
            choices=list(SEG_MODELS.keys()),
            value=DEFAULT_SIZE,
            label="Model",
            scale=4,
        )
        run = gr.Button("Run", variant="primary", size="lg", scale=1)
    gr.Examples(examples=EXAMPLES, inputs=inp, examples_per_page=16)
    with gr.Accordion("Original Res + Raw Labels", open=False):
        out_img = gr.Image(label="Color overlay (PNG)", type="pil")
        out_npy = gr.File(label="Raw labels (.npy uint8, class indices 0–28)")
    run.click(predict, inputs=[inp, size], outputs=[out_annot, out_img, out_npy])
if __name__ == "__main__":
    if torch.cuda.is_available():
        # TF32 speeds up matmuls/convs on Ampere+ GPUs with negligible accuracy loss.
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
    demo.launch(share=False)