"""Gradio Space for Sapiens2 body-part segmentation.

Takes an image and produces a 29-class semantic segmentation. The result is
rendered as an AnnotatedImage so that hovering over a predicted body part
highlights it together with its class name.
"""

import sys
import os

# This file's directory goes to the front of sys.path before the remaining
# imports run (presumably so the bundled `sapiens` package resolves).
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

import tempfile
from typing import List, Tuple

import cv2
import gradio as gr
import numpy as np
import spaces
import torch
import torch.nn.functional as F
from PIL import Image
from huggingface_hub import hf_hub_download

from sapiens.dense.models import SegEstimator, init_model  # registers SegEstimator
from sapiens.dense.src.datasets.seg.seg_utils import DOME_CLASSES_29

# Reference the name so the import above is not flagged as unused.
_ = SegEstimator

# -----------------------------------------------------------------------------
# Config

ASSETS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
CONFIGS_DIR = os.path.join(ASSETS_DIR, "configs")

# One entry per model size. Repo id, checkpoint filename and config filename
# all follow the same pattern and differ only by the lower-cased size tag.
SEG_MODELS = {
    size: {
        "repo": f"facebook/sapiens2-seg-{size.lower()}",
        "filename": f"sapiens2_{size.lower()}_seg.safetensors",
        "config": os.path.join(
            CONFIGS_DIR,
            f"sapiens2_{size.lower()}_seg_shutterstock_goliath-1024x768.py",
        ),
    }
    for size in ("0.4B", "0.8B", "1B", "5B")
}
DEFAULT_SIZE = "1B"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# All 29 named classes (Background is class 0).
_CLASS_LABELS = {cid: meta["name"].replace("_", " ") for cid, meta in DOME_CLASSES_29.items()} _CLASS_COLORS_RGB = {cid: meta["color"] for cid, meta in DOME_CLASSES_29.items()} _CLASS_COLORS_HEX = { _CLASS_LABELS[cid]: "#{:02x}{:02x}{:02x}".format(*meta["color"]) for cid, meta in DOME_CLASSES_29.items() } # ----------------------------------------------------------------------------- # Model cache _seg_model_cache: dict = {} def _get_seg_model(size: str): if size not in _seg_model_cache: spec = SEG_MODELS[size] ckpt = hf_hub_download(repo_id=spec["repo"], filename=spec["filename"]) model = init_model(spec["config"], ckpt, device=DEVICE) _seg_model_cache[size] = model return _seg_model_cache[size] print("[startup] pre-loading all seg sizes ...") for _size in SEG_MODELS: _get_seg_model(_size) print("[startup] ready.") # ----------------------------------------------------------------------------- # Inference (mirrors sapiens/dense/tools/vis/vis_seg.py) def _segment(image_bgr: np.ndarray, model) -> np.ndarray: h0, w0 = image_bgr.shape[:2] data = model.pipeline(dict(img=image_bgr)) # resize + pad data = model.data_preprocessor(data) # normalize + batch inputs = data["inputs"] # already (B, 3, H, W) with torch.no_grad(): logits = model(inputs) # (1, 29, H, W) logits = F.interpolate(logits, size=(h0, w0), mode="bilinear", align_corners=False) return logits.argmax(dim=1).squeeze(0).cpu().numpy().astype(np.int32) # (H, W) def _label_map_to_annotations(label_map: np.ndarray) -> List[Tuple[np.ndarray, str]]: """Convert (H, W) class-id map → AnnotatedImage's [(bool_mask, label), ...] list. Includes all 29 named classes (Background as well) so the legend is complete. 
""" annotations: List[Tuple[np.ndarray, str]] = [] for cid in np.unique(label_map): cid = int(cid) if cid not in _CLASS_LABELS: continue mask = (label_map == cid) if not mask.any(): continue annotations.append((mask, _CLASS_LABELS[cid])) return annotations def _label_map_to_overlay(image_bgr: np.ndarray, label_map: np.ndarray, opacity: float) -> np.ndarray: """Static color overlay (BGR) — colors from the DOME palette, alpha-blended.""" palette = np.zeros((256, 3), dtype=np.uint8) for cid, rgb in _CLASS_COLORS_RGB.items(): palette[cid] = rgb[::-1] # RGB → BGR for cv2 color_mask = palette[label_map] return cv2.addWeighted(image_bgr, 1.0 - opacity, color_mask, opacity, 0) # ----------------------------------------------------------------------------- # Gradio handler @spaces.GPU(duration=120) def predict(image: Image.Image, size: str): if image is None: return None, None, None image_pil = image.convert("RGB") image_rgb = np.array(image_pil) image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR) model = _get_seg_model(size) label_map = _segment(image_bgr, model) # (H, W) annotations = _label_map_to_annotations(label_map) annotated = (image_pil, annotations) overlay_bgr = _label_map_to_overlay(image_bgr, label_map, 0.5) overlay_rgb = cv2.cvtColor(overlay_bgr, cv2.COLOR_BGR2RGB) npy_path = tempfile.NamedTemporaryFile(delete=False, suffix=".npy").name np.save(npy_path, label_map.astype(np.uint8)) return annotated, Image.fromarray(overlay_rgb), npy_path # ----------------------------------------------------------------------------- # UI EXAMPLES = sorted( os.path.join(ASSETS_DIR, "images", n) for n in os.listdir(os.path.join(ASSETS_DIR, "images")) if n.lower().endswith((".jpg", ".jpeg", ".png")) ) CUSTOM_CSS = """ :root, body, .gradio-container, button, input, select, textarea, .gradio-container *:not(code):not(pre) { font-family: "Helvetica Neue", Helvetica, Arial, sans-serif !important; -webkit-font-smoothing: antialiased; -moz-osx-font-smoothing: grayscale; } #title { 
text-align: center; font-size: 44px; font-weight: 700; letter-spacing: -0.01em; margin: 28px 0 4px; background: linear-gradient(90deg, #1d4ed8 0%, #6d28d9 50%, #be185d 100%); -webkit-background-clip: text; -webkit-text-fill-color: transparent; background-clip: text; } #subtitle { text-align: center; font-size: 12px; color: #64748b; letter-spacing: 0.18em; margin: 0 0 14px; text-transform: uppercase; font-weight: 500; } #tagline { text-align: center; font-size: 15px; color: #475569; max-width: 680px; margin: 4px auto 22px; line-height: 1.55; font-weight: 400; } #badges { display: flex; justify-content: center; flex-wrap: wrap; gap: 8px; margin: 0 0 32px; } .pill { display: inline-flex; align-items: center; gap: 6px; padding: 7px 14px; border-radius: 999px; background: #f1f5f9; color: #0f172a !important; font-size: 13px; font-weight: 500; letter-spacing: 0.01em; text-decoration: none !important; border: 1px solid #e2e8f0; transition: background 150ms ease, transform 150ms ease, border-color 150ms ease; } .pill:hover { background: #0f172a; color: #f8fafc !important; border-color: #0f172a; transform: translateY(-1px); } .pill svg { width: 14px; height: 14px; } /* AnnotatedImage hover behavior: - Default: every mask at 55% (RGB shows through) - Hover legend item OR mask region → that mask pops to 75% with a coloured glow, every OTHER mask fades down to 10% (dim, not invisible). Gradio toggles `.active` on the hovered mask and `.inactive` on the rest, so we just over-style those. */ #seg-out .mask { opacity: 0.55 !important; transition: opacity 200ms ease, filter 200ms ease; } #seg-out .mask.active { opacity: 0.78 !important; filter: brightness(1.18) drop-shadow(0 0 10px rgba(255,255,255,0.55)) drop-shadow(0 4px 16px rgba(0,0,0,0.40)) !important; } #seg-out .mask.inactive { opacity: 0.10 !important; } #seg-out .legend-item { cursor: pointer; } /* Legend: vertical column on the right of the image instead of horizontal below. 
*/ #seg-out .container { flex-direction: row !important; align-items: stretch !important; gap: 12px; } #seg-out .image-container { flex: 1 1 auto; min-width: 0; } #seg-out .legend { flex: 0 0 180px; flex-direction: column !important; flex-wrap: nowrap !important; align-items: stretch; justify-content: flex-start; gap: 4px; padding: 8px 4px; max-height: 640px; overflow-y: auto; border-left: 1px solid var(--border-color-primary, #e2e8f0); } #seg-out .legend-item { font-size: 12px; font-weight: 500; padding: 4px 10px; width: 100%; cursor: pointer; border-radius: 6px; transition: transform 140ms ease, background 140ms ease; } #seg-out .legend-item:hover { transform: translateX(2px); } """ HEADER_HTML = """