# Hugging Face Space scrape header removed (author: bingyic, commit 5f0588f verified).
"""TIPS Feature Explorer (GPU) β€” Hugging Face Space demo with ZeroGPU."""
import colorsys

import gradio as gr
import matplotlib
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import spaces
import torch
import torch.nn.functional as F
from PIL import Image, ImageDraw, ImageFont
from fast_pytorch_kmeans import KMeans as TorchKMeans
from sklearn.decomposition import PCA
from torchvision import transforms
from transformers import AutoModel
# ── Constants ───────────────────────────────────────────────────────────────
# Default square inference resolution in pixels.
DEFAULT_IMAGE_SIZE = 896
# ViT patch size: the spatial feature grid is (resolution // PATCH_SIZE) per side.
PATCH_SIZE = 14
# User-selectable resolutions; all are exact multiples of PATCH_SIZE.
RESOLUTIONS = [224, 336, 448, 672, 896, 1120, 1372, 1792]
# Fixed resolution used for zero-shot segmentation feature extraction.
ZEROSEG_IMAGE_SIZE = 1372
# Maximum token length passed to the text tokenizer.
MAX_LEN = 64
# HF model repos (DPT repos include the backbone)
VARIANTS = {
    "TIPS v2 β€” B/14": "google/tipsv2-b14-dpt",
    "TIPS v2 β€” L/14": "google/tipsv2-l14-dpt",
    "TIPS v2 β€” SO400m/14": "google/tipsv2-so400m14-dpt",
    "TIPS v2 β€” g/14": "google/tipsv2-g14-dpt",
}
# Variant loaded on first use (see _init_model).
DEFAULT_VARIANT = "TIPS v2 β€” L/14"
def _device():
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ── Pascal Context (59 classes) ─────────────────────────────────────────────
# TCL prompt templates (from the Scenic zero-shot seg evaluator).
# Each class name is formatted into every template; the resulting text
# embeddings are later averaged (prompt ensembling).
TCL_PROMPTS = [
    "itap of a {}.",
    "a bad photo of a {}.",
    "a origami {}.",
    "a photo of the large {}.",
    "a {} in a video game.",
    "art of the {}.",
    "a photo of the small {}.",
    "a photo of many {}.",
    "a photo of {}s.",
]
# Class names (excluding background at index 0) from segmentation_dataset_info.py
PASCAL_CONTEXT_CLASSES = (
    "aeroplane", "bag", "bed", "bedclothes", "bench", "bicycle", "bird",
    "boat", "book", "bottle", "building", "bus", "cabinet", "car", "cat",
    "ceiling", "chair", "cloth", "computer", "cow", "cup", "curtain",
    "dog", "door", "fence", "floor", "flower", "food", "grass", "ground",
    "horse", "keyboard", "light", "motorbike", "mountain", "mouse",
    "person", "plate", "platform", "pottedplant", "road", "rock", "sheep",
    "shelves", "sidewalk", "sign", "sky", "snow", "sofa", "table", "track",
    "train", "tree", "truck", "tvmonitor", "wall", "water", "window",
    "wood",
)
# ADE20K class names (150 classes). Index i here corresponds to palette row
# i + 1 (row 0 is reserved for background — see vis_segmentation_dpt).
ADE20K_CLASSES = (
    'wall', 'building', 'sky', 'floor', 'tree',
    'ceiling', 'road', 'bed', 'windowpane', 'grass',
    'cabinet', 'sidewalk', 'person', 'earth', 'door',
    'table', 'mountain', 'plant', 'curtain', 'chair',
    'car', 'water', 'painting', 'sofa', 'shelf',
    'house', 'sea', 'mirror', 'rug', 'field',
    'armchair', 'seat', 'fence', 'desk', 'rock',
    'wardrobe', 'lamp', 'bathtub', 'railing', 'cushion',
    'base', 'box', 'column', 'signboard', 'chest_of_drawers',
    'counter', 'sand', 'sink', 'skyscraper', 'fireplace',
    'refrigerator', 'grandstand', 'path', 'stairs', 'runway',
    'case', 'pool_table', 'pillow', 'screen_door', 'stairway',
    'river', 'bridge', 'bookcase', 'blind', 'coffee_table',
    'toilet', 'flower', 'book', 'hill', 'bench',
    'countertop', 'stove', 'palm', 'kitchen_island', 'computer',
    'swivel_chair', 'boat', 'bar', 'arcade_machine', 'hovel',
    'bus', 'towel', 'light', 'truck', 'tower',
    'chandelier', 'awning', 'streetlight', 'booth', 'television',
    'airplane', 'dirt_track', 'apparel', 'pole', 'land',
    'bannister', 'escalator', 'ottoman', 'bottle', 'buffet',
    'poster', 'stage', 'van', 'ship', 'fountain',
    'conveyer_belt', 'canopy', 'washer', 'plaything',
    'swimming_pool',
    'stool', 'barrel', 'basket', 'waterfall', 'tent',
    'bag', 'minibike', 'cradle', 'oven', 'ball',
    'food', 'step', 'tank', 'trade_name', 'microwave',
    'pot', 'animal', 'bicycle', 'lake', 'dishwasher',
    'screen', 'blanket', 'sculpture', 'hood', 'sconce',
    'vase', 'traffic_light', 'tray', 'ashcan', 'fan',
    'pier', 'crt_screen', 'plate', 'monitor',
    'bulletin_board',
    'shower', 'radiator', 'glass', 'clock', 'flag',
)
NUM_ADE20K_CLASSES = 150
# Deterministic palette: row 0 stays black (background); rows 1..150 get
# distinct colors generated procedurally below.
ADE20K_PALETTE = np.zeros((NUM_ADE20K_CLASSES + 1, 3), dtype=np.uint8)
for i in range(1, NUM_ADE20K_CLASSES + 1):
    # Golden-ratio hue stepping spreads hues evenly around the color wheel;
    # saturation/value cycle through a few discrete levels so neighboring
    # indices remain visually distinguishable.
    hue = (i * 0.618033988749895) % 1.0
    saturation = 0.65 + 0.35 * ((i * 7) % 5) / 4.0
    value = 0.70 + 0.30 * ((i * 11) % 3) / 2.0
    r, g, b = colorsys.hsv_to_rgb(hue, saturation, value)
    ADE20K_PALETTE[i] = [int(r * 255), int(g * 255), int(b * 255)]
# ── Model state (one model loaded at a time) ───────────────────────────────
# Mutable module-level registry for the currently loaded variant.
_model = {
    "name": None,         # display name of the loaded variant (key of VARIANTS)
    "vision": None,       # backbone vision encoder
    "text": None,         # backbone text encoder
    "tokenizer": None,    # backbone tokenizer
    "temperature": None,  # contrastive temperature from the backbone config
    "ade20k_embs": None,  # cached text embeddings (see _ensure_ade20k_embs)
    "dpt": None,          # DPT head model (depth / normals / segmentation)
}
def load_variant(name):
    """Load a DPT model variant from HuggingFace (includes the backbone).

    No-op when the requested variant is already loaded. Resets the cached
    text embeddings because they are variant-specific.

    Args:
        name: display name of the variant; must be a key of VARIANTS.
    """
    # NOTE: no `global` needed — _model is mutated in place, never rebound.
    if _model["name"] == name:
        return
    dpt = AutoModel.from_pretrained(VARIANTS[name], trust_remote_code=True)
    dpt.eval()
    dpt._get_backbone()  # trigger backbone download
    backbone = dpt._backbone
    _model.update(
        name=name,
        dpt=dpt,
        vision=backbone.vision_encoder,
        text=backbone.text_encoder,
        tokenizer=backbone._load_tokenizer(),
        temperature=backbone.config.temperature,
        ade20k_embs=None,  # stale for the new variant; recomputed lazily
    )
    print(f"Loaded {name}")
def _move_models_to_device():
    """Move models to the current device (GPU inside @spaces.GPU, else CPU)."""
    dev = _device()
    for key in ("vision", "text", "dpt"):
        module = _model[key]
        if module is not None:
            module.to(dev)
def _ensure_ade20k_embs():
    """Pre-compute Pascal Context text embeddings if not yet done (must run on GPU).

    NOTE(review): despite the "ade20k_embs" key name, these embeddings are
    built from PASCAL_CONTEXT_CLASSES — presumably a historical name; confirm
    before renaming the cache key.
    """
    if _model["ade20k_embs"] is not None:
        return
    dev = _device()
    model_t = _model["text"]
    tokenizer = _model["tokenizer"]
    all_embs = []
    # Prompt ensembling: embed every class under every template, then average
    # across templates before L2-normalizing.
    for template in TCL_PROMPTS:
        prompts = [template.format(c) for c in PASCAL_CONTEXT_CLASSES]
        ids, paddings = tokenizer.tokenize(prompts, max_len=MAX_LEN)
        with torch.no_grad():
            embs = model_t(torch.from_numpy(ids).to(dev), torch.from_numpy(paddings).to(dev))
        all_embs.append(embs.cpu().numpy())
    _model["ade20k_embs"] = l2_normalize(np.mean(all_embs, axis=0))
    print("Pascal Context text embeddings computed.")
def _init_model():
    """Ensure a model is loaded, moved on-device, and its text embeddings cached."""
    variant = _model["name"] or DEFAULT_VARIANT
    load_variant(variant)
    _move_models_to_device()
    _ensure_ade20k_embs()
# ── Preprocessing & helpers ─────────────────────────────────────────────────
def preprocess(img, size=DEFAULT_IMAGE_SIZE):
    """Resize a PIL image to (size, size) and convert it to a float tensor in [0, 1]."""
    pipeline = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
    ])
    return pipeline(img)
def l2_normalize(x, axis=-1):
    """L2-normalize *x* along *axis*; norms are floored at 1e-3 to avoid div-by-zero."""
    norms = np.linalg.norm(x, ord=2, axis=axis, keepdims=True)
    return x / np.clip(norms, 1e-3, None)
def upsample(arr, h, w, mode="bilinear"):
    """Upsample (H, W, C) or (H, W) numpy array to (h, w, ...)."""
    t = torch.from_numpy(arr).float()
    if t.ndim == 2:
        t = t[..., None]  # promote (H, W) to (H, W, 1)
    t = t.permute(2, 0, 1)[None]  # to (1, C, H, W) as F.interpolate expects
    extra = {"align_corners": False} if mode == "bilinear" else {}
    out = F.interpolate(t, size=(h, w), mode=mode, **extra)
    return out.squeeze(0).permute(1, 2, 0).numpy()
def to_uint8(x):
    """Scale [0, 1] floats to [0, 255], clamp out-of-range values, cast to uint8."""
    scaled = np.clip(x * 255, 0, 255)
    return scaled.astype(np.uint8)
# ── Feature extraction (GPU-accelerated) ────────────────────────────────────
@torch.no_grad()
def extract_features(image_np, resolution=DEFAULT_IMAGE_SIZE):
    """Return spatial features (sp, sp, D) as numpy. sp = resolution // 14.

    Args:
        image_np: HxWx3 RGB image as a numpy array (converted via PIL).
        resolution: square size the image is resized to before encoding.
    """
    dev = _device()
    img = Image.fromarray(image_np).convert("RGB")
    tensor = preprocess(img, resolution).unsqueeze(0).to(dev)
    # The vision encoder returns a 3-tuple; only the last element (patch
    # tokens) is used here.
    _, _, patch_tokens = _model["vision"](tensor)
    sp = resolution // PATCH_SIZE
    return patch_tokens.cpu().reshape(sp, sp, -1).numpy()
@torch.no_grad()
def extract_features_value_attention(image_np, resolution=ZEROSEG_IMAGE_SIZE):
    """Return spatial features (sp, sp, D) using Value Attention on GPU.

    This follows the Colab reference implementation: run all blocks except the
    last normally, then for the last block extract V from QKV and manually
    apply out_proj, layer scale, residual, norm2, MLP + layer scale, second
    residual, and final norm. (This "value attention" trick is commonly used
    for zero-shot segmentation; statement order below is load-bearing.)

    Args:
        image_np: HxWx3 RGB image as a numpy array.
        resolution: square size the image is resized to before encoding.
    """
    dev = _device()
    model_image = _model["vision"]
    img = Image.fromarray(image_np).convert("RGB")
    tensor = preprocess(img, resolution).unsqueeze(0).to(dev)
    # Prepare tokens (patch embed + CLS + register + pos encoding)
    x = model_image.prepare_tokens_with_masks(tensor)
    # Run all blocks except the last one
    for blk in model_image.blocks[:-1]:
        x = blk(x)
    # Last block: manually extract V and apply full pipeline
    blk = model_image.blocks[-1]
    # Assumes a single register token when the attribute is absent — TODO
    # confirm against the backbone config.
    num_reg = getattr(model_image, "num_register_tokens", 1)
    # Compute QKV from the last block
    b_dim, n_dim, c_dim = x.shape
    num_heads = blk.attn.num_heads
    qkv = blk.attn.qkv(blk.norm1(x))
    qkv = qkv.reshape(b_dim, n_dim, 3, num_heads, c_dim // num_heads)
    qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, H, N, D_head)
    # Extract V, reshape, and apply out_proj
    v = qkv[2]  # (B, H, N, D_head)
    v_out = v.transpose(1, 2).reshape(b_dim, n_dim, c_dim)
    v_out = blk.attn.proj(v_out)  # out_proj projection
    v_out = blk.ls1(v_out)  # layer scale
    x_val = v_out + x  # residual connection
    y_val = blk.norm2(x_val)  # second norm
    y_val = blk.ls2(blk.mlp(y_val))  # MLP + layer scale
    x_val = x_val + y_val  # second residual
    x_val = model_image.norm(x_val)  # final norm
    # Extract patch tokens (skip CLS + register tokens)
    patch_tokens = x_val[:, 1 + num_reg:, :]
    sp = resolution // PATCH_SIZE
    spatial = patch_tokens.cpu().reshape(sp, sp, -1).numpy()
    return spatial
# ── PCA Visualisations ──────────────────────────────────────────────────────
def vis_pca(spatial, target_resolution):
    """Project spatial features onto 3 PCA components and map them to RGB.

    Note: ``target_resolution`` is currently unused — upsampling is disabled
    in favor of CSS pixelated rendering of the raw patch grid.
    """
    H, W, C = spatial.shape
    flat = spatial.reshape(-1, C)
    projected = PCA(n_components=3, whiten=True).fit_transform(flat)
    rgb = projected.reshape(H, W, 3)
    # Scale by 2 and squash through a sigmoid to get vibrant colors.
    rgb = 1.0 / (1.0 + np.exp(rgb * -2.0))
    # return to_uint8(upsample(rgb, target_resolution, target_resolution))
    return to_uint8(rgb)
def vis_depth(spatial, h, w):
    """Visualize the 1st PCA component of the features with the inferno colormap.

    Args:
        spatial: (H, W, D) feature map.
        h, w: output size in pixels.

    Returns:
        (h, w, 3) uint8 pseudo-depth image.
    """
    feat = spatial.reshape(-1, spatial.shape[-1])
    H, W = spatial.shape[0], spatial.shape[1]
    depth = PCA(n_components=1).fit_transform(feat).reshape(H, W)
    # Min-max normalize to [0, 1] before colormapping.
    depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8)
    # matplotlib.cm.get_cmap() was removed in matplotlib 3.9; use the
    # colormap registry instead.
    colored = matplotlib.colormaps["inferno"](depth)[:, :, :3].astype(np.float32)
    return to_uint8(upsample(colored, h, w))
def vis_kmeans(spatial, h, w, n_clusters=6):
    """K-means clustering of spatial features.

    Args:
        spatial: (H, W, D) feature map.
        h, w: output (original image) size in pixels.
        n_clusters: number of k-means clusters.

    Returns:
        (h, w, 3) uint8 cluster map, one tab20 color per cluster.
    """
    H, W = spatial.shape[:2]
    feat = torch.from_numpy(spatial.reshape(-1, spatial.shape[-1])).to(_device())
    km = TorchKMeans(n_clusters=n_clusters, max_iter=20)
    km.fit(feat)
    # Compute negative distances as scores, bilinear upsample, then argmax
    # (upsampling soft scores rather than hard labels smooths cluster borders).
    dists = -torch.cdist(feat, km.centroids)  # (H*W, k)
    scores = dists.cpu().numpy().reshape(H, W, n_clusters)
    scores_up = upsample(scores, h, w, mode="bilinear")
    labels = scores_up.argmax(axis=-1)
    palette = plt.cm.tab20(np.linspace(0, 1, n_clusters))[:, :3]
    seg = palette[labels].astype(np.float32)
    return to_uint8(seg)
# ── Zero-shot Segmentation ──────────────────────────────────────────────────
def vis_custom_semseg(spatial, orig_image, classes, class_embs):
    """Zero-shot semantic segmentation with user-defined classes.

    Args:
        spatial: (S_h, S_w, D) patch-feature map.
        orig_image: original HxWx3 uint8 image (used for blending and sizing).
        classes: list of class-name strings, one per row of class_embs.
        class_embs: (n_classes, D) L2-normalized text embeddings.

    Returns:
        Tuple of (overlay image with legend as ndarray, raw mask ndarray,
        detected-classes string, not-detected string).
    """
    h, w = orig_image.shape[:2]
    S_h, S_w = spatial.shape[:2]
    n = len(classes)
    # Cosine similarity between every patch feature and every class embedding.
    feat = l2_normalize(spatial.reshape(-1, spatial.shape[-1]))
    sim = feat @ class_embs.T
    sim_map = sim.reshape(S_h, S_w, n)
    # Upsample similarity scores (not hard labels) for smooth boundaries.
    sim_up = upsample(sim_map, h, w, mode="bilinear")
    labels = sim_up.argmax(axis=-1)
    # Dynamic palette
    palette = (plt.cm.tab20(np.linspace(0, 1, max(n, 2)))[:n, :3] * 255).astype(np.uint8)
    seg_rgb = palette[labels].astype(np.float32) / 255.0
    mask_img = to_uint8(seg_rgb)
    # 10% image + 90% mask keeps the segmentation dominant yet recognizable.
    blend = 0.1 * orig_image.astype(np.float32) / 255.0 + 0.9 * seg_rgb
    blend_img = Image.fromarray(to_uint8(blend))
    # Legend
    unique_ids, counts = np.unique(labels, return_counts=True)
    order = np.argsort(-counts)  # sort by pixel area, descending
    unique_ids, counts = unique_ids[order], counts[order]
    total = counts.sum()
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 60)
    except OSError:
        # Fall back when the DejaVu font is not installed on the host.
        font = ImageFont.load_default()
    n_legend = min(len(unique_ids), 10)  # cap the legend at 10 rows
    row_h = 80
    swatch_w = 60
    pad = 12
    legend_w = 450
    legend_h = max(h, n_legend * row_h + pad * 2)
    canvas = Image.new("RGB", (w + legend_w, legend_h), (255, 255, 255))
    canvas.paste(blend_img, (0, 0))
    draw = ImageDraw.Draw(canvas)
    for i in range(n_legend):
        cid = unique_ids[i]
        color = tuple(palette[cid].tolist())
        y_top = pad + i * row_h
        draw.rectangle(
            [w + pad, y_top, w + pad + swatch_w, y_top + swatch_w],
            fill=color, outline=(0, 0, 0),
        )
        draw.text(
            (w + pad + swatch_w + 8, y_top + 6),
            classes[cid], fill="black", font=font,
        )
    overlay_out = np.array(canvas)
    # Split classes into "detected" (>= 2% of pixels) vs. minor/absent.
    detected_parts, minor_parts = [], []
    for i, cid in enumerate(unique_ids):
        pct = counts[i] / total * 100
        if pct >= 2:
            detected_parts.append(f"{classes[cid]} ({pct:.1f}%)")
        else:
            minor_parts.append(f"{classes[cid]} ({pct:.1f}%)")
    absent = [
        f"{classes[i]} (0.0%)" for i in range(n)
        if i not in set(unique_ids.tolist())
    ]
    detected_str = ", ".join(detected_parts)
    undetected_str = ", ".join(minor_parts + absent)
    return overlay_out, mask_img, detected_str, undetected_str
# ── DPT Depth Inference ─────────────────────────────────────────────────────
def vis_depth_dpt(depth_map, h, w):
    """Colour a depth map with the turbo colormap.

    Args:
        depth_map: numpy depth prediction, squeezable to 2-D.
        h, w: output size in pixels.

    Returns:
        (h, w, 3) uint8 colorized depth image (a numpy array, not a PIL Image).
    """
    d = depth_map.squeeze()
    # Min-max normalize to [0, 1] before colormapping.
    d = (d - d.min()) / (d.max() - d.min() + 1e-8)
    # matplotlib.cm.get_cmap() was removed in matplotlib 3.9; use the registry.
    colored = matplotlib.colormaps["turbo"](d)[:, :, :3].astype(np.float32)
    return to_uint8(upsample(colored, h, w))
def vis_normals_dpt(normals_map, h, w):
    """Map surface normals from [-1, 1] to [0, 1] RGB and resize to (h, w)."""
    arr = normals_map.cpu().numpy()  # incoming tensor is (3, H, W)
    arr = np.transpose((arr + 1.0) * 0.5, (1, 2, 0))  # (H, W, 3) in [0, 1]
    return to_uint8(upsample(arr, h, w))
def vis_segmentation_dpt(seg_map, orig_image):
    """Colour a segmentation map with the ADE20K colormap + legend.

    Args:
        seg_map: (150, H, W) class-logit tensor from the DPT head.
        orig_image: original HxWx3 uint8 image for blending and sizing.

    Returns:
        Numpy RGB array: blended segmentation with a legend strip on the right.
    """
    h, w = orig_image.shape[:2]
    logits = seg_map.cpu().numpy().transpose(1, 2, 0)  # (H, W, 150)
    # Upsample logits (not hard labels) so class boundaries stay smooth.
    logits_up = upsample(logits, h, w, mode="bilinear")
    pred = logits_up.argmax(axis=-1)  # (h, w)
    # Palette row 0 is background, so class c maps to palette row c + 1.
    seg_rgb = ADE20K_PALETTE[pred.astype(np.int32) + 1].astype(np.float32) / 255.0
    blend = 0.15 * orig_image.astype(np.float32) / 255.0 + 0.85 * seg_rgb
    blend_img = Image.fromarray(to_uint8(blend))
    # Legend: classes with >= 2% area, sorted by area descending
    unique_ids, counts = np.unique(pred, return_counts=True)
    total_pixels = counts.sum()
    order = np.argsort(-counts)
    unique_ids, counts = unique_ids[order], counts[order]
    # Filter to only classes occupying >= 2% of the image
    pcts = counts / total_pixels * 100
    mask = pcts >= 2.0
    unique_ids, counts, pcts = unique_ids[mask], counts[mask], pcts[mask]
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 36)
    except OSError:
        # Fall back when the DejaVu font is not installed on the host.
        font = ImageFont.load_default()
    n_legend = min(len(unique_ids), 10)  # cap the legend at 10 rows
    row_h, swatch_w, pad, legend_w = 50, 40, 10, 450
    legend_h = max(h, n_legend * row_h + pad * 2)
    canvas = Image.new("RGB", (w + legend_w, legend_h), (255, 255, 255))
    canvas.paste(blend_img, (0, 0))
    draw = ImageDraw.Draw(canvas)
    for i in range(n_legend):
        cid = unique_ids[i]
        color = tuple(ADE20K_PALETTE[cid + 1].tolist())
        name = ADE20K_CLASSES[cid] if cid < len(ADE20K_CLASSES) else f"class_{cid}"
        y_top = pad + i * row_h
        draw.rectangle([w + pad, y_top, w + pad + swatch_w, y_top + swatch_w], fill=color, outline=(0, 0, 0))
        draw.text((w + pad + swatch_w + 8, y_top + 4), name, fill="black", font=font)
    return np.array(canvas)
# ── Gradio callbacks ────────────────────────────────────────────────────────
@spaces.GPU
def on_variant_change(variant_name):
    """Switch model variant and clear all variant-dependent outputs."""
    load_variant(variant_name)
    _move_models_to_device()
    _ensure_ade20k_embs()
    # _ensure_voc_embs()  # VOC tab hidden for now
    # Clear PCA & zero-shot outputs (depth/seg tabs unaffected by variant)
    return (None, None, None,  # pca_out, depth_out, kmeans_out
            None,  # pca_state
            None, None, "", "")  # custom outputs
# --- PCA tab callbacks ---
@spaces.GPU
def on_pca_extract(image, resolution, pca_state):
    """Extract features and render the PCA / 1st-component / k-means views."""
    if image is None:
        return None, None, None, None
    _init_model()
    resolution = int(resolution)
    spatial = extract_features(image, resolution)
    h, w = image.shape[:2]
    pca = vis_pca(spatial, resolution)
    depth = vis_depth(spatial, h, w)
    kmeans = vis_kmeans(spatial, h, w)
    # Cache features keyed by variant + resolution so on_recluster can skip
    # re-extraction when nothing changed.
    state = {"spatial": spatial, "orig_image": image, "variant": _model["name"], "resolution": resolution}
    return pca, depth, kmeans, state
@spaces.GPU
def on_recluster(image, resolution, n_clusters, pca_state):
    """Re-run k-means, reusing cached features when variant/resolution match."""
    if image is None:
        gr.Warning("Upload an image first.")
        return None, pca_state
    _init_model()
    resolution = int(resolution)
    # Reuse cached features only if they were computed with the same variant
    # and resolution; otherwise extract fresh ones and refresh the cache.
    if (pca_state is not None
            and pca_state.get("variant") == _model["name"]
            and pca_state.get("resolution") == resolution):
        spatial = pca_state["spatial"]
    else:
        spatial = extract_features(image, resolution)
        pca_state = {"spatial": spatial, "orig_image": image, "variant": _model["name"], "resolution": resolution}
    h, w = image.shape[:2]
    return vis_kmeans(spatial, h, w, int(n_clusters)), pca_state
# --- Zero-shot Segmentation tab callbacks ---
@spaces.GPU
def on_zeroseg_custom(image, resolution, class_names_str):
    """Zero-shot segmentation over user-typed, comma-separated class names."""
    if image is None or not class_names_str or not class_names_str.strip():
        gr.Warning("Upload an image and enter at least one class name.")
        return None, None, "", ""
    _init_model()
    resolution = int(resolution)
    classes = [c.strip() for c in class_names_str.split(",") if c.strip()]
    if not classes:
        return None, None, "", ""
    # Encode custom classes with TCL prompt templates
    dev = _device()
    all_embs = []
    for template in TCL_PROMPTS:
        prompts = [template.format(c) for c in classes]
        ids, paddings = _model["tokenizer"].tokenize(prompts, max_len=MAX_LEN)
        with torch.no_grad():
            embs = _model["text"](torch.from_numpy(ids).to(dev), torch.from_numpy(paddings).to(dev))
        all_embs.append(embs.cpu().numpy())
    # Average over templates (prompt ensembling), then L2-normalize.
    class_embs = l2_normalize(np.mean(all_embs, axis=0))
    # Value-attention (MaskCLIP-style) features are used for text matching.
    spatial = extract_features_value_attention(image, resolution)
    overlay, mask, detected, undetected = vis_custom_semseg(spatial, image, classes, class_embs)
    return overlay, mask, detected, undetected
# --- Depth Feature Visualization tab callbacks ---
@spaces.GPU
def on_depth_normals_predict(image, dpt_variant, resolution):  # noqa: ARG001
    """Run DPT depth and normals prediction.

    Note: dpt_variant is accepted for UI wiring but unused (hence the noqa);
    the globally selected model is used instead.
    """
    if image is None:
        return None, None
    _init_model()
    dev = _device()
    h, w = image.shape[:2]
    img = Image.fromarray(image).convert("RGB")
    tensor = preprocess(img, int(resolution)).unsqueeze(0).to(dev)
    depth_map = _model["dpt"].predict_depth(tensor)
    normals_map = _model["dpt"].predict_normals(tensor)
    return vis_depth_dpt(depth_map[0, 0].cpu().numpy(), h, w), vis_normals_dpt(normals_map[0], h, w)
@spaces.GPU
def on_segmentation_predict(image, dpt_variant, resolution):  # noqa: ARG001
    """Run DPT segmentation prediction.

    Note: dpt_variant is accepted for UI wiring but unused (hence the noqa);
    the globally selected model is used instead.
    """
    if image is None:
        return None
    _init_model()
    dev = _device()
    h, w = image.shape[:2]
    img = Image.fromarray(image).convert("RGB")
    tensor = preprocess(img, int(resolution)).unsqueeze(0).to(dev)
    seg_map = _model["dpt"].predict_segmentation(tensor)
    return vis_segmentation_dpt(seg_map[0], image)
# ── UI ──────────────────────────────────────────────────────────────────────
custom_css = """
#pca_output_image img {
image-rendering: pixelated;
object-fit: contain;
}
"""
head = """
<!-- Google tag (gtag.js) -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-P13E18K71N"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag(){dataLayer.push(arguments);}
gtag('js', new Date());
gtag('config', 'G-P13E18K71N');
</script>
"""
# with gr.Blocks(head=head, title="TIPSv2 Feature Explorer") as demo:
with gr.Blocks(head=head, title="TIPSv2 Feature Explorer", css=custom_css) as demo:
    gr.Markdown(
        "## TIPSv2 Feature Explorer\n"
        "Explore TIPSv2 representations here! For more information, see: "
        "https://gdm-tipsv2.github.io/"
    )
    # Global controls shared by every tab: model variant and inference resolution.
    with gr.Row():
        variant_dd = gr.Dropdown(
            choices=list(VARIANTS.keys()),
            value=DEFAULT_VARIANT,
            label="Model variant",
        )
        resolution_dd = gr.Dropdown(
            choices=RESOLUTIONS,
            value=DEFAULT_IMAGE_SIZE,
            label="Resolution (higher = better quality, slower)",
        )
    # ── PCA / Feature Visualization Tab ─────────────────────────────────
    with gr.Tab("🎨 PCA & Feature Visualization"):
        # Caches extracted features so re-clustering can skip re-extraction.
        pca_state = gr.State(None)
        with gr.Row():
            with gr.Column():
                pca_input = gr.Image(type="numpy", label="Input image")
                pca_btn = gr.Button("Extract Features", variant="primary")
            with gr.Column():
                with gr.Tabs():
                    with gr.Tab("PCA"):
                        pca_out = gr.Image(label="PCA (3 components β†’ RGB)", height=448, elem_id="pca_output_image")
                        # pca_out = gr.Image(label="PCA (3 components β†’ RGB)")
                    with gr.Tab("PCA (1st component)"):
                        depth_out = gr.Image(label="1st PCA component")
                    with gr.Tab("K-means Clustering"):
                        n_clusters = gr.Slider(2, 20, value=6, step=1, label="Clusters")
                        recluster_btn = gr.Button("Re-cluster")
                        kmeans_out = gr.Image(label="K-means clusters")
        gr.Markdown("πŸ‘‡ **Click the examples below to explore!**")
        gr.Examples(
            examples=[
                ["examples/pca/hike.jpeg"],
                ["examples/pca/cph.jpeg"],
                ["examples/pca/angus.jpeg"],
                ["examples/pca/dadaocheng.jpeg"],
            ],
            inputs=[pca_input],
        )
    # ── Zero-shot Segmentation Tab ──────────────────────────────────────
    with gr.Tab("✏️ Zero-shot Segmentation"):
        gr.Markdown(
            "Define your own classes for zero-shot segmentation. "
            "Enter class names separated by commas."
        )
        with gr.Row():
            with gr.Column():
                custom_input = gr.Image(type="numpy", label="Input image", height=448)
                custom_classes = gr.Textbox(
                    label="Class names (comma-separated)",
                    value="class1, class2, class3",
                    placeholder="e.g. cat, dog, sky, grass",
                )
                custom_btn = gr.Button("Segment", variant="primary")
            with gr.Column():
                with gr.Tabs():
                    with gr.Tab("Overlay"):
                        custom_overlay = gr.Image(label="Segmentation overlay", height=448)
                    with gr.Tab("Mask"):
                        custom_mask = gr.Image(label="Segmentation mask", height=448)
                custom_detected = gr.Textbox(label="Detected classes (sorted by area)", lines=2)
                custom_undetected = gr.Textbox(label="Not detected", lines=2)
        gr.Markdown("πŸ‘‡ **Click the examples below to explore!**")
        gr.Examples(
            examples=[
                ["examples/zeroseg_voc/voc_2008_000891.jpg", "dog, cage, cloth, dog bowl"],
                ["examples/zeroseg/pascal_context_00000_image.png", "bike, tree, fence, soccer, floor, chair, cushion"],
                ["examples/zeroseg/pascal_context_00007_image.png", "dog, table, chair, carpet, shoes"],
                ["examples/zeroseg/pascal_context_00049_image.png", "bus, snow, mountain, house, road"],
            ],
            inputs=[custom_input, custom_classes],
        )
    # ── Pascal Context & VOC tabs (hidden for now) ──────────────────────
    # with gr.Tab("πŸ—ΊοΈ Zero-shot Segmentation (Pascal Context)"):
    #     gr.Markdown(
    #         "Zero-shot semantic segmentation using **Value Attention** features "
    #         "(MaskCLIP style) with 9 TCL prompt templates ensembled over "
    #         "**Pascal Context 59 classes**. No segmentation head β€” purely from "
    #         "vision–language alignment."
    #     )
    #
    #     with gr.Row():
    #         with gr.Column():
    #             zs_input = gr.Image(type="numpy", label="Input image", height=448)
    #             zs_btn = gr.Button("Segment", variant="primary")
    #
    #         with gr.Column():
    #             with gr.Tabs():
    #                 with gr.Tab("Overlay"):
    #                     zs_overlay = gr.Image(label="Segmentation overlay", height=448)
    #                 with gr.Tab("Mask"):
    #                     zs_mask = gr.Image(label="Segmentation mask (raw)", height=448)
    #             zs_detected = gr.Textbox(label="Detected classes (sorted by area)", lines=3)
    #             zs_undetected = gr.Textbox(label="Not detected", lines=2)
    #
    #     gr.Examples(
    #         examples=[
    #             ["examples/zeroseg/bus.png"],
    #             ["examples/zeroseg/birds.png"],
    #             ["examples/zeroseg/bicycle.png"],
    #             ["examples/zeroseg/baby.png"],
    #             ["examples/zeroseg/dog.png"],
    #             ["examples/zeroseg/sleep.png"],
    #             ["examples/zeroseg/pascal_context_00007.png"],
    #             ["examples/zeroseg/pascal_context_00029.png"],
    #             ["examples/zeroseg/pascal_context_00068.png"],
    #         ],
    #         inputs=[zs_input],
    #     )
    # ── Zero-shot Segmentation (Pascal VOC) Tab ─────────────────────────
    # with gr.Tab("🏷️ Zero-shot Segmentation (Pascal VOC)"):
    #     gr.Markdown(
    #         "Zero-shot semantic segmentation using **Value Attention** features "
    #         "(MaskCLIP style) with 9 TCL prompt templates ensembled over "
    #         "**Pascal VOC 20 classes**. No segmentation head β€” purely from "
    #         "vision–language alignment."
    #     )
    #
    #     with gr.Row():
    #         with gr.Column():
    #             voc_input = gr.Image(type="numpy", label="Input image", height=448)
    #             voc_btn = gr.Button("Segment", variant="primary")
    #
    #         with gr.Column():
    #             with gr.Tabs():
    #                 with gr.Tab("Overlay"):
    #                     voc_overlay = gr.Image(label="Segmentation overlay", height=448)
    #                 with gr.Tab("Mask"):
    #                     voc_mask = gr.Image(label="Segmentation mask (raw)", height=448)
    #             voc_detected = gr.Textbox(label="Detected classes (sorted by area)", lines=3)
    #             voc_undetected = gr.Textbox(label="Not detected", lines=2)
    #
    #     gr.Examples(
    #         examples=[
    #             ["examples/zeroseg_voc/voc_2008_000012.jpg"],
    #             ["examples/zeroseg_voc/voc_2008_000044.jpg"],
    #             ["examples/zeroseg_voc/voc_2008_000159.jpg"],
    #             ["examples/zeroseg_voc/voc_2008_000167.jpg"],
    #             ["examples/zeroseg_voc/voc_2008_000712.jpg"],
    #             ["examples/zeroseg_voc/voc_2008_000768.jpg"],
    #             ["examples/zeroseg_voc/voc_2008_000891.jpg"],
    #             ["examples/zeroseg_voc/voc_2008_001365.jpg"],
    #         ],
    #         inputs=[voc_input],
    #     )
    # ── Zero-shot Segmentation (Custom) Tab ──────────────────────────────
    # ── Depth Feature Visualization Tab ─────────────────────────────────
    with gr.Tab("πŸ”οΈ Depth/Normals Visualization"):
        gr.Markdown(
            "Monocular depth and surface normals estimation using a **DPT "
            "(Dense Prediction Transformer)** head on top of a **frozen** "
            "TIPS v2 vision encoder. Trained on the **NYU Depth V2** dataset."
        )
        with gr.Row():
            with gr.Column():
                depth_input = gr.Image(type="numpy", label="Input image", height=448)
                depth_btn = gr.Button("Predict Depth & Normals", variant="primary")
            with gr.Column():
                dpt_depth_out = gr.Image(label="DPT Depth Map", height=448)
            with gr.Column():
                dpt_normals_out = gr.Image(label="DPT Surface Normals", height=448)
        gr.Markdown("πŸ‘‡ **Click the examples below to explore!**")
        gr.Examples(
            examples=[
                ["examples/nyuv2/bedroom_00280.jpg"],
                ["examples/nyuv2/kitchen_00249.jpg"],
                ["examples/nyuv2/living_room_01260.jpg"],
                ["examples/nyuv2/office_kitchen_00413.jpg"],
                ["examples/nyuv2/study_room_00272.jpg"],
            ],
            inputs=[depth_input],
        )
    # ── Supervised Segmentation Tab ──────────────────────────────────────
    with gr.Tab("🎭 Supervised Segmentation"):
        gr.Markdown(
            "Semantic segmentation using a **DPT (Dense Prediction "
            "Transformer)** head on top of a **frozen** TIPS v2 vision "
            "encoder. Trained on ADE20K (150 classes)."
        )
        with gr.Row():
            with gr.Column():
                seg_input = gr.Image(type="numpy", label="Input image", height=448)
                seg_btn = gr.Button("Segment", variant="primary")
            with gr.Column():
                seg_out = gr.Image(label="DPT Segmentation (ADE20K)", height=448)
        gr.Markdown("πŸ‘‡ **Click the examples below to explore!**")
        gr.Examples(
            examples=[
                ["examples/depth/ade20k_00003.png"],
                ["examples/depth/ade20k_00007.png"],
                ["examples/depth/ade20k_00014.png"],
                ["examples/depth/ade20k_00022.png"],
            ],
            inputs=[seg_input],
        )
    # ── Wiring ──────────────────────────────────────────────────────────
    # Changing the model variant clears every variant-dependent output.
    variant_dd.change(
        fn=on_variant_change,
        inputs=[variant_dd],
        outputs=[
            pca_out, depth_out, kmeans_out,  # PCA outputs
            pca_state,  # PCA state
            custom_overlay, custom_mask,  # Custom ZS outputs
            custom_detected, custom_undetected,  # Custom ZS text
        ],
    )
    pca_btn.click(
        fn=on_pca_extract,
        inputs=[pca_input, resolution_dd, pca_state],
        outputs=[pca_out, depth_out, kmeans_out, pca_state],
    )
    recluster_btn.click(
        fn=on_recluster,
        inputs=[pca_input, resolution_dd, n_clusters, pca_state],
        outputs=[kmeans_out, pca_state],
    )
    # zs_btn.click(
    #     fn=on_zeroseg,
    #     inputs=[zs_input, resolution_dd],
    #     outputs=[zs_overlay, zs_mask, zs_detected, zs_undetected],
    # )
    depth_btn.click(
        fn=on_depth_normals_predict,
        inputs=[depth_input, variant_dd, resolution_dd],
        outputs=[dpt_depth_out, dpt_normals_out],
    )
    seg_btn.click(
        fn=on_segmentation_predict,
        inputs=[seg_input, variant_dd, resolution_dd],
        outputs=[seg_out],
    )
    # voc_btn.click(
    #     fn=on_zeroseg_voc,
    #     inputs=[voc_input, resolution_dd],
    #     outputs=[voc_overlay, voc_mask, voc_detected, voc_undetected],
    # )
    custom_btn.click(
        fn=on_zeroseg_custom,
        inputs=[custom_input, resolution_dd, custom_classes],
        outputs=[custom_overlay, custom_mask, custom_detected, custom_undetected],
    )
if __name__ == "__main__":
    demo.launch()