Spaces: Running on Zero
| """TIPS Feature Explorer (GPU) β Hugging Face Space demo with ZeroGPU.""" | |
| import colorsys | |
| import gradio as gr | |
| import matplotlib.cm as cm | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import spaces | |
| import torch | |
| import torch.nn.functional as F | |
| from PIL import Image, ImageDraw, ImageFont | |
| from fast_pytorch_kmeans import KMeans as TorchKMeans | |
| from sklearn.decomposition import PCA | |
| from torchvision import transforms | |
| from transformers import AutoModel | |
# ββ Constants ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
DEFAULT_IMAGE_SIZE = 896   # default square input resolution
PATCH_SIZE = 14            # ViT patch size; every resolution below is a multiple
RESOLUTIONS = [224, 336, 448, 672, 896, 1120, 1372, 1792]
ZEROSEG_IMAGE_SIZE = 1372  # resolution used for zero-shot segmentation
MAX_LEN = 64               # tokenizer max sequence length

# HF model repos (DPT repos include the backbone)
VARIANTS = {
    "TIPS v2 β B/14": "google/tipsv2-b14-dpt",
    "TIPS v2 β L/14": "google/tipsv2-l14-dpt",
    "TIPS v2 β SO400m/14": "google/tipsv2-so400m14-dpt",
    "TIPS v2 β g/14": "google/tipsv2-g14-dpt",
}
DEFAULT_VARIANT = "TIPS v2 β L/14"
| def _device(): | |
| return torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
# ββ Pascal Context (59 classes) βββββββββββββββββββββββββββββββββββββββββββββ
# TCL prompt templates (from the Scenic zero-shot seg evaluator).
# Each template is filled with a class name and the resulting text embeddings
# are averaged over all templates (prompt ensembling).
TCL_PROMPTS = [
    "itap of a {}.",
    "a bad photo of a {}.",
    "a origami {}.",
    "a photo of the large {}.",
    "a {} in a video game.",
    "art of the {}.",
    "a photo of the small {}.",
    "a photo of many {}.",
    "a photo of {}s.",
]

# Class names (excluding background at index 0) from segmentation_dataset_info.py
PASCAL_CONTEXT_CLASSES = (
    "aeroplane", "bag", "bed", "bedclothes", "bench", "bicycle", "bird",
    "boat", "book", "bottle", "building", "bus", "cabinet", "car", "cat",
    "ceiling", "chair", "cloth", "computer", "cow", "cup", "curtain",
    "dog", "door", "fence", "floor", "flower", "food", "grass", "ground",
    "horse", "keyboard", "light", "motorbike", "mountain", "mouse",
    "person", "plate", "platform", "pottedplant", "road", "rock", "sheep",
    "shelves", "sidewalk", "sign", "sky", "snow", "sofa", "table", "track",
    "train", "tree", "truck", "tvmonitor", "wall", "water", "window",
    "wood",
)
# ADE20K semantic classes (150, background excluded from the tuple).
ADE20K_CLASSES = (
    'wall', 'building', 'sky', 'floor', 'tree',
    'ceiling', 'road', 'bed', 'windowpane', 'grass',
    'cabinet', 'sidewalk', 'person', 'earth', 'door',
    'table', 'mountain', 'plant', 'curtain', 'chair',
    'car', 'water', 'painting', 'sofa', 'shelf',
    'house', 'sea', 'mirror', 'rug', 'field',
    'armchair', 'seat', 'fence', 'desk', 'rock',
    'wardrobe', 'lamp', 'bathtub', 'railing', 'cushion',
    'base', 'box', 'column', 'signboard', 'chest_of_drawers',
    'counter', 'sand', 'sink', 'skyscraper', 'fireplace',
    'refrigerator', 'grandstand', 'path', 'stairs', 'runway',
    'case', 'pool_table', 'pillow', 'screen_door', 'stairway',
    'river', 'bridge', 'bookcase', 'blind', 'coffee_table',
    'toilet', 'flower', 'book', 'hill', 'bench',
    'countertop', 'stove', 'palm', 'kitchen_island', 'computer',
    'swivel_chair', 'boat', 'bar', 'arcade_machine', 'hovel',
    'bus', 'towel', 'light', 'truck', 'tower',
    'chandelier', 'awning', 'streetlight', 'booth', 'television',
    'airplane', 'dirt_track', 'apparel', 'pole', 'land',
    'bannister', 'escalator', 'ottoman', 'bottle', 'buffet',
    'poster', 'stage', 'van', 'ship', 'fountain',
    'conveyer_belt', 'canopy', 'washer', 'plaything',
    'swimming_pool',
    'stool', 'barrel', 'basket', 'waterfall', 'tent',
    'bag', 'minibike', 'cradle', 'oven', 'ball',
    'food', 'step', 'tank', 'trade_name', 'microwave',
    'pot', 'animal', 'bicycle', 'lake', 'dishwasher',
    'screen', 'blanket', 'sculpture', 'hood', 'sconce',
    'vase', 'traffic_light', 'tray', 'ashcan', 'fan',
    'pier', 'crt_screen', 'plate', 'monitor',
    'bulletin_board',
    'shower', 'radiator', 'glass', 'clock', 'flag',
)
NUM_ADE20K_CLASSES = 150

# Deterministic colour palette: row 0 is background (black); each class gets a
# distinct hue via the golden-ratio conjugate, with small deterministic
# saturation/value jitter so neighbouring ids stay distinguishable.
ADE20K_PALETTE = np.zeros((NUM_ADE20K_CLASSES + 1, 3), dtype=np.uint8)
for _cls in range(1, NUM_ADE20K_CLASSES + 1):
    _hue = (_cls * 0.618033988749895) % 1.0
    _sat = 0.65 + 0.35 * ((_cls * 7) % 5) / 4.0
    _val = 0.70 + 0.30 * ((_cls * 11) % 3) / 2.0
    _r, _g, _b = colorsys.hsv_to_rgb(_hue, _sat, _val)
    ADE20K_PALETTE[_cls] = [int(_r * 255), int(_g * 255), int(_b * 255)]
| # ββ Model state (one model loaded at a time) βββββββββββββββββββββββββββββββ | |
| _model = { | |
| "name": None, | |
| "vision": None, | |
| "text": None, | |
| "tokenizer": None, | |
| "temperature": None, | |
| "ade20k_embs": None, | |
| "dpt": None, | |
| } | |
def load_variant(name):
    """Load a DPT model variant from HuggingFace (includes the backbone).

    No-op when `name` is already the resident variant. Resets the cached
    text embeddings, since they are variant-specific.
    """
    global _model
    if _model["name"] == name:
        return  # already resident
    dpt = AutoModel.from_pretrained(VARIANTS[name], trust_remote_code=True)
    dpt.eval()
    dpt._get_backbone()  # trigger backbone download
    backbone = dpt._backbone
    new_state = {
        "name": name,
        "dpt": dpt,
        "vision": backbone.vision_encoder,
        "text": backbone.text_encoder,
        "tokenizer": backbone._load_tokenizer(),
        "temperature": backbone.config.temperature,
        "ade20k_embs": None,  # invalidated: embeddings depend on the variant
    }
    _model.update(new_state)
    print(f"Loaded {name}")
def _move_models_to_device():
    """Move models to the current device (GPU inside @spaces.GPU, else CPU)."""
    dev = _device()
    for slot in ("vision", "text", "dpt"):
        module = _model[slot]
        if module is not None:
            module.to(dev)
def _ensure_ade20k_embs():
    """Pre-compute Pascal Context text embeddings if not yet done (must run on GPU).

    Note: despite the cache key name ("ade20k_embs"), these are Pascal
    Context class embeddings, ensembled over the TCL prompt templates.
    """
    if _model["ade20k_embs"] is not None:
        return  # cache warm
    dev = _device()
    text_encoder = _model["text"]
    tokenizer = _model["tokenizer"]
    per_template = []
    for tmpl in TCL_PROMPTS:
        prompts = [tmpl.format(cls) for cls in PASCAL_CONTEXT_CLASSES]
        ids, paddings = tokenizer.tokenize(prompts, max_len=MAX_LEN)
        with torch.no_grad():
            embs = text_encoder(
                torch.from_numpy(ids).to(dev), torch.from_numpy(paddings).to(dev)
            )
        per_template.append(embs.cpu().numpy())
    # Average the per-template embeddings, then re-normalize.
    _model["ade20k_embs"] = l2_normalize(np.mean(per_template, axis=0))
    print("Pascal Context text embeddings computed.")
def _init_model():
    """Load model + move to GPU + compute text embeddings.

    NOTE(review): the callbacks calling this are not wrapped in @spaces.GPU
    here β confirm ZeroGPU allocation happens as intended.
    """
    active = _model["name"] or DEFAULT_VARIANT
    load_variant(active)
    _move_models_to_device()
    _ensure_ade20k_embs()
# ββ Preprocessing & helpers βββββββββββββββββββββββββββββββββββββββββββββββββ
def preprocess(img, size=DEFAULT_IMAGE_SIZE):
    """Resize a PIL image to (size, size) and convert it to a CHW float tensor."""
    pipeline = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
    ])
    return pipeline(img)
def l2_normalize(x, axis=-1):
    """L2-normalize `x` along `axis`; norms are clipped at 1e-3 to avoid div-by-zero."""
    norms = np.linalg.norm(x, ord=2, axis=axis, keepdims=True)
    return x / norms.clip(min=1e-3)
def upsample(arr, h, w, mode="bilinear"):
    """Upsample (H, W, C) or (H, W) numpy array to (h, w, ...)."""
    tensor = torch.from_numpy(arr).float()
    if tensor.ndim == 2:
        tensor = tensor[..., None]  # promote to (H, W, 1)
    batched = tensor.permute(2, 0, 1)[None]  # (1, C, H, W) for interpolate
    extra = {"align_corners": False} if mode == "bilinear" else {}
    resized = F.interpolate(batched, size=(h, w), mode=mode, **extra)
    return resized.squeeze(0).permute(1, 2, 0).numpy()
def to_uint8(x):
    """Scale [0, 1] floats to [0, 255] and cast to uint8 (values clipped)."""
    scaled = np.clip(x * 255, 0, 255)
    return scaled.astype(np.uint8)
# ββ Feature extraction (GPU-accelerated) ββββββββββββββββββββββββββββββββββββ
def extract_features(image_np, resolution=DEFAULT_IMAGE_SIZE):
    """Return spatial features (sp, sp, D) as numpy. sp = resolution // 14.

    Args:
        image_np: HxWx3 uint8 RGB image (numpy array).
        resolution: square input resolution; must be a multiple of PATCH_SIZE.

    Returns:
        (sp, sp, D) float numpy array of patch features.
    """
    dev = _device()
    img = Image.fromarray(image_np).convert("RGB")
    tensor = preprocess(img, resolution).unsqueeze(0).to(dev)
    # no_grad: pure inference β without it the output carries an autograd
    # graph (wasted memory) and .numpy() below raises on a grad-tracking tensor.
    with torch.no_grad():
        _, _, patch_tokens = _model["vision"](tensor)
    sp = resolution // PATCH_SIZE
    return patch_tokens.cpu().reshape(sp, sp, -1).numpy()
def extract_features_value_attention(image_np, resolution=ZEROSEG_IMAGE_SIZE):
    """Return spatial features (sp, sp, D) using Value Attention on GPU.

    This follows the Colab reference implementation: run all blocks except the
    last normally, then for the last block extract V from QKV and manually
    apply out_proj, layer scale, residual, norm2, MLP + layer scale, second
    residual, and final norm.

    Args:
        image_np: HxWx3 uint8 RGB image (numpy array).
        resolution: square input resolution; must be a multiple of PATCH_SIZE.

    Returns:
        (sp, sp, D) numpy array of patch features, sp = resolution // PATCH_SIZE.
    """
    dev = _device()
    model_image = _model["vision"]
    img = Image.fromarray(image_np).convert("RGB")
    tensor = preprocess(img, resolution).unsqueeze(0).to(dev)
    # no_grad: pure inference β the manual per-layer math below would otherwise
    # build an autograd graph, and .numpy() at the end would raise.
    with torch.no_grad():
        # Prepare tokens (patch embed + CLS + register + pos encoding)
        x = model_image.prepare_tokens_with_masks(tensor)
        # Run all blocks except the last one
        for blk in model_image.blocks[:-1]:
            x = blk(x)
        # Last block: manually extract V and apply full pipeline
        blk = model_image.blocks[-1]
        num_reg = getattr(model_image, "num_register_tokens", 1)
        # Compute QKV from the last block
        b_dim, n_dim, c_dim = x.shape
        num_heads = blk.attn.num_heads
        qkv = blk.attn.qkv(blk.norm1(x))
        qkv = qkv.reshape(b_dim, n_dim, 3, num_heads, c_dim // num_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, H, N, D_head)
        # Extract V, reshape, and apply out_proj
        v = qkv[2]  # (B, H, N, D_head)
        v_out = v.transpose(1, 2).reshape(b_dim, n_dim, c_dim)
        v_out = blk.attn.proj(v_out)      # out_proj projection
        v_out = blk.ls1(v_out)            # layer scale
        x_val = v_out + x                 # residual connection
        y_val = blk.norm2(x_val)          # second norm
        y_val = blk.ls2(blk.mlp(y_val))   # MLP + layer scale
        x_val = x_val + y_val             # second residual
        x_val = model_image.norm(x_val)   # final norm
        # Extract patch tokens (skip CLS + register tokens)
        patch_tokens = x_val[:, 1 + num_reg:, :]
    sp = resolution // PATCH_SIZE
    spatial = patch_tokens.cpu().reshape(sp, sp, -1).numpy()
    return spatial
# ββ PCA Visualisations ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
def vis_pca(spatial, target_resolution):
    """PCA of spatial features β RGB image at patch-grid resolution.

    `target_resolution` is currently unused: the upsampling step is disabled
    and the raw patch-grid image is returned (the UI scales it client-side
    with pixelated rendering).
    """
    grid_h, grid_w, dim = spatial.shape
    flat = spatial.reshape(-1, dim)
    components = PCA(n_components=3, whiten=True).fit_transform(flat)
    rgb = components.reshape(grid_h, grid_w, 3)
    # Sigmoid with a gain of 2.0 squashes the components into vibrant colors.
    rgb = 1.0 / (1.0 + np.exp(-2.0 * rgb))
    return to_uint8(rgb)
def vis_depth(spatial, h, w):
    """1st PCA component of the features, rendered with the inferno colormap.

    Args:
        spatial: (S_h, S_w, D) patch features.
        h, w: output image size.

    Returns:
        (h, w, 3) uint8 array.
    """
    feat = spatial.reshape(-1, spatial.shape[-1])
    H, W = spatial.shape[0], spatial.shape[1]
    depth = PCA(n_components=1).fit_transform(feat).reshape(H, W)
    # Min-max normalize to [0, 1]; epsilon guards a constant component.
    depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8)
    # plt.get_cmap: cm.get_cmap was deprecated in matplotlib 3.7 and removed in 3.9.
    colored = plt.get_cmap("inferno")(depth)[:, :, :3].astype(np.float32)
    return to_uint8(upsample(colored, h, w))
def vis_kmeans(spatial, h, w, n_clusters=6):
    """K-means clustering of spatial features, rendered with the tab20 palette."""
    grid_h, grid_w = spatial.shape[:2]
    flat = torch.from_numpy(spatial.reshape(-1, spatial.shape[-1])).to(_device())
    clusterer = TorchKMeans(n_clusters=n_clusters, max_iter=20)
    clusterer.fit(flat)
    # Negative centroid distance = per-cluster score; upsample the soft scores
    # first, then argmax, for smoother cluster boundaries.
    scores = -torch.cdist(flat, clusterer.centroids)  # (H*W, k)
    score_maps = scores.cpu().numpy().reshape(grid_h, grid_w, n_clusters)
    labels = upsample(score_maps, h, w, mode="bilinear").argmax(axis=-1)
    palette = plt.cm.tab20(np.linspace(0, 1, n_clusters))[:, :3]
    return to_uint8(palette[labels].astype(np.float32))
# ββ Zero-shot Segmentation ββββββββββββββββββββββββββββββββββββββββββββββββββ
def vis_custom_semseg(spatial, orig_image, classes, class_embs):
    """Zero-shot semantic segmentation with user-defined classes.

    Args:
        spatial: (S_h, S_w, D) patch features.
        orig_image: HxWx3 uint8 image used for blending and sizing.
        classes: list of class-name strings.
        class_embs: (n, D) L2-normalized text embeddings, one row per class.

    Returns:
        (overlay_with_legend, mask_image, detected_str, undetected_str)
    """
    h, w = orig_image.shape[:2]
    grid_h, grid_w = spatial.shape[:2]
    n = len(classes)
    feats = l2_normalize(spatial.reshape(-1, spatial.shape[-1]))
    # Patch x class cosine similarity, upsampled to image size before argmax.
    sim_maps = (feats @ class_embs.T).reshape(grid_h, grid_w, n)
    labels = upsample(sim_maps, h, w, mode="bilinear").argmax(axis=-1)
    # Dynamic palette (tab20 needs at least 2 entries to spread).
    palette = (plt.cm.tab20(np.linspace(0, 1, max(n, 2)))[:n, :3] * 255).astype(np.uint8)
    seg_rgb = palette[labels].astype(np.float32) / 255.0
    mask_img = to_uint8(seg_rgb)
    blended = 0.1 * orig_image.astype(np.float32) / 255.0 + 0.9 * seg_rgb
    blend_img = Image.fromarray(to_uint8(blended))
    # Legend entries, sorted by on-screen area (largest first).
    unique_ids, counts = np.unique(labels, return_counts=True)
    by_area = np.argsort(-counts)
    unique_ids, counts = unique_ids[by_area], counts[by_area]
    total = counts.sum()
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 60)
    except OSError:
        font = ImageFont.load_default()
    n_legend = min(len(unique_ids), 10)
    row_h, swatch_w, pad, legend_w = 80, 60, 12, 450
    legend_h = max(h, n_legend * row_h + pad * 2)
    canvas = Image.new("RGB", (w + legend_w, legend_h), (255, 255, 255))
    canvas.paste(blend_img, (0, 0))
    draw = ImageDraw.Draw(canvas)
    for row in range(n_legend):
        cid = unique_ids[row]
        y_top = pad + row * row_h
        draw.rectangle(
            [w + pad, y_top, w + pad + swatch_w, y_top + swatch_w],
            fill=tuple(palette[cid].tolist()), outline=(0, 0, 0),
        )
        draw.text(
            (w + pad + swatch_w + 8, y_top + 6),
            classes[cid], fill="black", font=font,
        )
    overlay_out = np.array(canvas)
    # Classes covering >= 2% of pixels count as detected; the rest are minor.
    detected_parts, minor_parts = [], []
    for idx, cid in enumerate(unique_ids):
        pct = counts[idx] / total * 100
        entry = f"{classes[cid]} ({pct:.1f}%)"
        (detected_parts if pct >= 2 else minor_parts).append(entry)
    present = set(unique_ids.tolist())
    absent = [f"{classes[i]} (0.0%)" for i in range(n) if i not in present]
    detected_str = ", ".join(detected_parts)
    undetected_str = ", ".join(minor_parts + absent)
    return overlay_out, mask_img, detected_str, undetected_str
# ββ DPT Depth Inference βββββββββββββββββββββββββββββββββββββββββββββββββββββ
def vis_depth_dpt(depth_map, h, w):
    """Colour a depth map with the turbo colormap β (h, w, 3) uint8 array."""
    d = depth_map.squeeze()
    # Min-max normalize to [0, 1]; epsilon guards a constant map.
    d = (d - d.min()) / (d.max() - d.min() + 1e-8)
    # plt.get_cmap: cm.get_cmap was deprecated in matplotlib 3.7 and removed in 3.9.
    colored = plt.get_cmap("turbo")(d)[:, :, :3].astype(np.float32)
    return to_uint8(upsample(colored, h, w))
def vis_normals_dpt(normals_map, h, w):
    """Map surface normals from [-1, 1] to [0, 1] RGB and resize to (h, w).

    `normals_map` is a (3, H, W) torch tensor.
    """
    normals = normals_map.cpu().numpy().transpose(1, 2, 0)  # (H, W, 3)
    normals = (normals + 1.0) / 2.0  # [-1, 1] -> [0, 1]
    return to_uint8(upsample(normals, h, w))
def vis_segmentation_dpt(seg_map, orig_image):
    """Colour a segmentation map with the ADE20K colormap + legend.

    Args:
        seg_map: (150, H, W) torch tensor of per-class logits.
        orig_image: HxWx3 uint8 image used for blending and sizing.

    Returns:
        numpy array of the blended overlay with a class legend on the right.
    """
    h, w = orig_image.shape[:2]
    class_logits = seg_map.cpu().numpy().transpose(1, 2, 0)  # (H, W, 150)
    pred = upsample(class_logits, h, w, mode="bilinear").argmax(axis=-1)  # (h, w)
    # Palette row 0 is background, so class ids are shifted by one.
    seg_rgb = ADE20K_PALETTE[pred.astype(np.int32) + 1].astype(np.float32) / 255.0
    blended = 0.15 * orig_image.astype(np.float32) / 255.0 + 0.85 * seg_rgb
    blend_img = Image.fromarray(to_uint8(blended))
    # Legend: classes covering >= 2% of the image, sorted by area descending.
    unique_ids, counts = np.unique(pred, return_counts=True)
    total_pixels = counts.sum()
    by_area = np.argsort(-counts)
    unique_ids, counts = unique_ids[by_area], counts[by_area]
    pcts = counts / total_pixels * 100
    keep = pcts >= 2.0
    unique_ids, counts, pcts = unique_ids[keep], counts[keep], pcts[keep]
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 36)
    except OSError:
        font = ImageFont.load_default()
    n_legend = min(len(unique_ids), 10)
    row_h, swatch_w, pad, legend_w = 50, 40, 10, 450
    legend_h = max(h, n_legend * row_h + pad * 2)
    canvas = Image.new("RGB", (w + legend_w, legend_h), (255, 255, 255))
    canvas.paste(blend_img, (0, 0))
    draw = ImageDraw.Draw(canvas)
    for row in range(n_legend):
        cid = unique_ids[row]
        color = tuple(ADE20K_PALETTE[cid + 1].tolist())
        name = ADE20K_CLASSES[cid] if cid < len(ADE20K_CLASSES) else f"class_{cid}"
        y_top = pad + row * row_h
        draw.rectangle([w + pad, y_top, w + pad + swatch_w, y_top + swatch_w], fill=color, outline=(0, 0, 0))
        draw.text((w + pad + swatch_w + 8, y_top + 4), name, fill="black", font=font)
    return np.array(canvas)
# ββ Gradio callbacks ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
def on_variant_change(variant_name):
    """Swap the loaded model variant and clear variant-dependent outputs."""
    load_variant(variant_name)
    _move_models_to_device()
    _ensure_ade20k_embs()
    # _ensure_voc_embs()  # VOC tab hidden for now
    # Only the PCA and zero-shot outputs depend on the variant; the depth and
    # segmentation tabs keep their results.
    return (
        None, None, None,  # pca_out, depth_out, kmeans_out
        None,              # pca_state
        None, None,        # custom overlay / mask
        "", "",            # custom detected / undetected text
    )
# --- PCA tab callbacks ---
def on_pca_extract(image, resolution, pca_state):
    """Extract features and render all three PCA-tab visualizations."""
    if image is None:
        return None, None, None, None
    _init_model()
    resolution = int(resolution)
    spatial = extract_features(image, resolution)
    h, w = image.shape[:2]
    # Cache the features so "Re-cluster" can skip re-extraction.
    state = {
        "spatial": spatial,
        "orig_image": image,
        "variant": _model["name"],
        "resolution": resolution,
    }
    return (
        vis_pca(spatial, resolution),
        vis_depth(spatial, h, w),
        vis_kmeans(spatial, h, w),
        state,
    )
def on_recluster(image, resolution, n_clusters, pca_state):
    """Re-run k-means, reusing cached features only when they are not stale.

    Fix: the cache was previously keyed only on (variant, resolution), so
    uploading a NEW image at the same settings silently re-clustered the OLD
    image's features. The cached input image must also match.
    """
    if image is None:
        gr.Warning("Upload an image first.")
        return None, pca_state
    _init_model()
    resolution = int(resolution)
    cache_valid = (
        pca_state is not None
        and pca_state.get("variant") == _model["name"]
        and pca_state.get("resolution") == resolution
        and np.array_equal(pca_state.get("orig_image"), image)
    )
    if cache_valid:
        spatial = pca_state["spatial"]
    else:
        spatial = extract_features(image, resolution)
        pca_state = {"spatial": spatial, "orig_image": image, "variant": _model["name"], "resolution": resolution}
    h, w = image.shape[:2]
    return vis_kmeans(spatial, h, w, int(n_clusters)), pca_state
# --- Zero-shot Segmentation tab callbacks ---
def on_zeroseg_custom(image, resolution, class_names_str):
    """Segment `image` zero-shot against a comma-separated list of class names."""
    if image is None or not class_names_str or not class_names_str.strip():
        gr.Warning("Upload an image and enter at least one class name.")
        return None, None, "", ""
    _init_model()
    resolution = int(resolution)
    classes = [name.strip() for name in class_names_str.split(",") if name.strip()]
    if not classes:
        return None, None, "", ""
    # Encode the custom classes, ensembling over the TCL prompt templates.
    dev = _device()
    per_template = []
    for template in TCL_PROMPTS:
        prompts = [template.format(name) for name in classes]
        ids, paddings = _model["tokenizer"].tokenize(prompts, max_len=MAX_LEN)
        with torch.no_grad():
            embs = _model["text"](
                torch.from_numpy(ids).to(dev), torch.from_numpy(paddings).to(dev)
            )
        per_template.append(embs.cpu().numpy())
    class_embs = l2_normalize(np.mean(per_template, axis=0))
    spatial = extract_features_value_attention(image, resolution)
    overlay, mask, detected, undetected = vis_custom_semseg(spatial, image, classes, class_embs)
    return overlay, mask, detected, undetected
# --- Depth Feature Visualization tab callbacks ---
def on_depth_normals_predict(image, dpt_variant, resolution):  # noqa: ARG001
    """Run DPT depth and normals prediction.

    `dpt_variant` is unused (the globally selected variant is already loaded);
    it is kept for Gradio wiring compatibility.
    """
    if image is None:
        return None, None
    _init_model()
    dev = _device()
    h, w = image.shape[:2]
    img = Image.fromarray(image).convert("RGB")
    tensor = preprocess(img, int(resolution)).unsqueeze(0).to(dev)
    # no_grad: inference only β the downstream .cpu().numpy() calls require
    # grad-free tensors, and skipping autograd saves memory.
    with torch.no_grad():
        depth_map = _model["dpt"].predict_depth(tensor)
        normals_map = _model["dpt"].predict_normals(tensor)
    return (
        vis_depth_dpt(depth_map[0, 0].cpu().numpy(), h, w),
        vis_normals_dpt(normals_map[0], h, w),
    )
def on_segmentation_predict(image, dpt_variant, resolution):  # noqa: ARG001
    """Run DPT segmentation prediction.

    `dpt_variant` is unused (the globally selected variant is already loaded);
    it is kept for Gradio wiring compatibility.
    """
    if image is None:
        return None
    _init_model()
    dev = _device()
    h, w = image.shape[:2]
    img = Image.fromarray(image).convert("RGB")
    tensor = preprocess(img, int(resolution)).unsqueeze(0).to(dev)
    # no_grad: inference only β vis_segmentation_dpt calls .cpu().numpy(),
    # which requires a grad-free tensor.
    with torch.no_grad():
        seg_map = _model["dpt"].predict_segmentation(tensor)
    return vis_segmentation_dpt(seg_map[0], image)
# ββ UI ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# Keep the PCA output crisp: nearest-neighbour scaling instead of blurry
# browser interpolation when the small patch-grid image is enlarged.
custom_css = """
#pca_output_image img {
image-rendering: pixelated;
object-fit: contain;
}
"""

# Google Analytics tag injected into the page <head>.
head = """
<!-- Google tag (gtag.js) -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-P13E18K71N"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag(){dataLayer.push(arguments);}
gtag('js', new Date());
gtag('config', 'G-P13E18K71N');
</script>
"""
# ββ UI layout & wiring ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# (Previously commented-out Pascal Context / Pascal VOC tabs and their wiring
# have been removed; restore from version control if they are revived.)
with gr.Blocks(head=head, title="TIPSv2 Feature Explorer", css=custom_css) as demo:
    gr.Markdown(
        "## TIPSv2 Feature Explorer\n"
        "Explore TIPSv2 representations here! For more information, see: "
        "https://gdm-tipsv2.github.io/"
    )

    # Global controls shared by every tab.
    with gr.Row():
        variant_dd = gr.Dropdown(
            choices=list(VARIANTS.keys()),
            value=DEFAULT_VARIANT,
            label="Model variant",
        )
        resolution_dd = gr.Dropdown(
            choices=RESOLUTIONS,
            value=DEFAULT_IMAGE_SIZE,
            label="Resolution (higher = better quality, slower)",
        )

    # ββ PCA / Feature Visualization Tab βββββββββββββββββββββββββββββββββ
    with gr.Tab("π¨ PCA & Feature Visualization"):
        pca_state = gr.State(None)  # cached features for "Re-cluster"
        with gr.Row():
            with gr.Column():
                pca_input = gr.Image(type="numpy", label="Input image")
                pca_btn = gr.Button("Extract Features", variant="primary")
            with gr.Column():
                with gr.Tabs():
                    with gr.Tab("PCA"):
                        pca_out = gr.Image(label="PCA (3 components β RGB)", height=448, elem_id="pca_output_image")
                    with gr.Tab("PCA (1st component)"):
                        depth_out = gr.Image(label="1st PCA component")
                    with gr.Tab("K-means Clustering"):
                        n_clusters = gr.Slider(2, 20, value=6, step=1, label="Clusters")
                        recluster_btn = gr.Button("Re-cluster")
                        kmeans_out = gr.Image(label="K-means clusters")
        gr.Markdown("π **Click the examples below to explore!**")
        gr.Examples(
            examples=[
                ["examples/pca/hike.jpeg"],
                ["examples/pca/cph.jpeg"],
                ["examples/pca/angus.jpeg"],
                ["examples/pca/dadaocheng.jpeg"],
            ],
            inputs=[pca_input],
        )

    # ββ Zero-shot Segmentation Tab ββββββββββββββββββββββββββββββββββββββ
    with gr.Tab("βοΈ Zero-shot Segmentation"):
        gr.Markdown(
            "Define your own classes for zero-shot segmentation. "
            "Enter class names separated by commas."
        )
        with gr.Row():
            with gr.Column():
                custom_input = gr.Image(type="numpy", label="Input image", height=448)
                custom_classes = gr.Textbox(
                    label="Class names (comma-separated)",
                    value="class1, class2, class3",
                    placeholder="e.g. cat, dog, sky, grass",
                )
                custom_btn = gr.Button("Segment", variant="primary")
            with gr.Column():
                with gr.Tabs():
                    with gr.Tab("Overlay"):
                        custom_overlay = gr.Image(label="Segmentation overlay", height=448)
                    with gr.Tab("Mask"):
                        custom_mask = gr.Image(label="Segmentation mask", height=448)
                custom_detected = gr.Textbox(label="Detected classes (sorted by area)", lines=2)
                custom_undetected = gr.Textbox(label="Not detected", lines=2)
        gr.Markdown("π **Click the examples below to explore!**")
        gr.Examples(
            examples=[
                ["examples/zeroseg_voc/voc_2008_000891.jpg", "dog, cage, cloth, dog bowl"],
                ["examples/zeroseg/pascal_context_00000_image.png", "bike, tree, fence, soccer, floor, chair, cushion"],
                ["examples/zeroseg/pascal_context_00007_image.png", "dog, table, chair, carpet, shoes"],
                ["examples/zeroseg/pascal_context_00049_image.png", "bus, snow, mountain, house, road"],
            ],
            inputs=[custom_input, custom_classes],
        )

    # ββ Depth Feature Visualization Tab βββββββββββββββββββββββββββββββββ
    with gr.Tab("ποΈ Depth/Normals Visualization"):
        gr.Markdown(
            "Monocular depth and surface normals estimation using a **DPT "
            "(Dense Prediction Transformer)** head on top of a **frozen** "
            "TIPS v2 vision encoder. Trained on the **NYU Depth V2** dataset."
        )
        with gr.Row():
            with gr.Column():
                depth_input = gr.Image(type="numpy", label="Input image", height=448)
                depth_btn = gr.Button("Predict Depth & Normals", variant="primary")
            with gr.Column():
                dpt_depth_out = gr.Image(label="DPT Depth Map", height=448)
            with gr.Column():
                dpt_normals_out = gr.Image(label="DPT Surface Normals", height=448)
        gr.Markdown("π **Click the examples below to explore!**")
        gr.Examples(
            examples=[
                ["examples/nyuv2/bedroom_00280.jpg"],
                ["examples/nyuv2/kitchen_00249.jpg"],
                ["examples/nyuv2/living_room_01260.jpg"],
                ["examples/nyuv2/office_kitchen_00413.jpg"],
                ["examples/nyuv2/study_room_00272.jpg"],
            ],
            inputs=[depth_input],
        )

    # ββ Supervised Segmentation Tab ββββββββββββββββββββββββββββββββββββββ
    with gr.Tab("π Supervised Segmentation"):
        gr.Markdown(
            "Semantic segmentation using a **DPT (Dense Prediction "
            "Transformer)** head on top of a **frozen** TIPS v2 vision "
            "encoder. Trained on ADE20K (150 classes)."
        )
        with gr.Row():
            with gr.Column():
                seg_input = gr.Image(type="numpy", label="Input image", height=448)
                seg_btn = gr.Button("Segment", variant="primary")
            with gr.Column():
                seg_out = gr.Image(label="DPT Segmentation (ADE20K)", height=448)
        gr.Markdown("π **Click the examples below to explore!**")
        gr.Examples(
            examples=[
                ["examples/depth/ade20k_00003.png"],
                ["examples/depth/ade20k_00007.png"],
                ["examples/depth/ade20k_00014.png"],
                ["examples/depth/ade20k_00022.png"],
            ],
            inputs=[seg_input],
        )

    # ββ Wiring ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
    variant_dd.change(
        fn=on_variant_change,
        inputs=[variant_dd],
        outputs=[
            pca_out, depth_out, kmeans_out,      # PCA outputs
            pca_state,                           # PCA state
            custom_overlay, custom_mask,         # Custom ZS outputs
            custom_detected, custom_undetected,  # Custom ZS text
        ],
    )
    pca_btn.click(
        fn=on_pca_extract,
        inputs=[pca_input, resolution_dd, pca_state],
        outputs=[pca_out, depth_out, kmeans_out, pca_state],
    )
    recluster_btn.click(
        fn=on_recluster,
        inputs=[pca_input, resolution_dd, n_clusters, pca_state],
        outputs=[kmeans_out, pca_state],
    )
    depth_btn.click(
        fn=on_depth_normals_predict,
        inputs=[depth_input, variant_dd, resolution_dd],
        outputs=[dpt_depth_out, dpt_normals_out],
    )
    seg_btn.click(
        fn=on_segmentation_predict,
        inputs=[seg_input, variant_dd, resolution_dd],
        outputs=[seg_out],
    )
    custom_btn.click(
        fn=on_zeroseg_custom,
        inputs=[custom_input, resolution_dd, custom_classes],
        outputs=[custom_overlay, custom_mask, custom_detected, custom_undetected],
    )


if __name__ == "__main__":
    demo.launch()