Use AutoModel for model loading, remove 2200+ LOC of dead code, add DPT seg legend, add smaller resolutions

#2
by gberton - opened
.ruff_cache/.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # Automatically created by ruff.
2
+ *
.ruff_cache/0.15.9/8006769214093067198 ADDED
Binary file (61 Bytes). View file
 
.ruff_cache/CACHEDIR.TAG ADDED
@@ -0,0 +1 @@
 
 
1
+ Signature: 8a477f597d28d172789f06886806bc55
app.py CHANGED
@@ -1,9 +1,6 @@
1
  """TIPS Feature Explorer (GPU) β€” Hugging Face Space demo with ZeroGPU."""
2
 
3
  import colorsys
4
- import io
5
- import os
6
- import urllib.request
7
 
8
  import gradio as gr
9
  import matplotlib.cm as cm
@@ -16,97 +13,35 @@ from PIL import Image, ImageDraw, ImageFont
16
  from fast_pytorch_kmeans import KMeans as TorchKMeans
17
  from sklearn.decomposition import PCA
18
  from torchvision import transforms
19
-
20
- import dpt_head
21
- import image_encoder
22
- import text_encoder as text_encoder_mod
23
 
24
  # ── Constants ───────────────────────────────────────────────────────────────
25
 
26
  DEFAULT_IMAGE_SIZE = 896
27
- MODEL_IMAGE_SIZE = 448
28
  PATCH_SIZE = 14
29
- RESOLUTIONS = [896, 1120, 1372, 1792]
30
 
31
  ZEROSEG_IMAGE_SIZE = 1372
32
- ZEROSEG_SPATIAL = ZEROSEG_IMAGE_SIZE // PATCH_SIZE # 96
33
- DEPTH_IMAGE_SIZE = 1036 # must be divisible by PATCH_SIZE=14 → 74×14
34
- DEPTH_SPATIAL = DEPTH_IMAGE_SIZE // PATCH_SIZE # 74
35
- VOCAB_SIZE = 32000
36
  MAX_LEN = 64
37
- CKPT_DIR = "checkpoints"
38
- GCS = "https://storage.googleapis.com/tips_data"
39
-
40
- # Per-variant DPT config: embed_dim, block_indices, checkpoint URLs
41
- DPT_CONFIGS = {
42
- "TIPS v2 β€” B/14": dict(
43
- embed_dim=768, block_indices=[2, 5, 8, 11],
44
- depth_url=f"{GCS}/v2_0/checkpoints/scenic/tips_v2_b14_depth_dpt.zip",
45
- normals_url=f"{GCS}/v2_0/checkpoints/scenic/tips_v2_b14_normals_dpt.zip",
46
- seg_url=f"{GCS}/v2_0/checkpoints/scenic/tips_v2_b14_segmentation_dpt.zip",
47
- ),
48
- "TIPS v2 β€” L/14": dict(
49
- embed_dim=1024, block_indices=[5, 11, 17, 23],
50
- depth_url=f"{GCS}/v2_0/checkpoints/scenic/tips_v2_l14_depth_dpt.zip",
51
- normals_url=f"{GCS}/v2_0/checkpoints/scenic/tips_v2_l14_normals_dpt.zip",
52
- seg_url=f"{GCS}/v2_0/checkpoints/scenic/tips_v2_l14_segmentation_dpt.zip",
53
- ),
54
- "TIPS v2 β€” SO400m/14": dict(
55
- embed_dim=1152, block_indices=[6, 13, 20, 26],
56
- depth_url=f"{GCS}/v2_0/checkpoints/scenic/tips_v2_so400m14_depth_dpt.zip",
57
- normals_url=f"{GCS}/v2_0/checkpoints/scenic/tips_v2_so400m14_normals_dpt.zip",
58
- seg_url=f"{GCS}/v2_0/checkpoints/scenic/tips_v2_so400m14_segmentation_dpt.zip",
59
- ),
60
- "TIPS v2 β€” g/14": dict(
61
- embed_dim=1536, block_indices=[9, 19, 29, 39],
62
- depth_url=f"{GCS}/v2_0/checkpoints/scenic/tips_v2_g14_depth_dpt.zip",
63
- normals_url=f"{GCS}/v2_0/checkpoints/scenic/tips_v2_g14_normals_dpt.zip",
64
- seg_url=f"{GCS}/v2_0/checkpoints/scenic/tips_v2_g14_segmentation_dpt.zip",
65
- ),
66
- }
67
- DPT_VARIANT_CHOICES = list(DPT_CONFIGS.keys())
68
- DEFAULT_DPT_VARIANT = "TIPS v2 β€” L/14"
69
-
70
-
71
- def _device():
72
- """Resolve device dynamically β€” GPU is only available inside @spaces.GPU."""
73
- return torch.device("cuda" if torch.cuda.is_available() else "cpu")
74
-
75
- # ── Model variants ──────────────────────────────────────────────────────────
76
 
 
77
  VARIANTS = {
78
- "TIPS v2 β€” B/14": dict(
79
- vision_url=f"{GCS}/v2_0/checkpoints/pytorch/tips_v2_oss_b14_vision.npz",
80
- text_url=f"{GCS}/v2_0/checkpoints/pytorch/tips_v2_oss_b14_text.npz",
81
- vision_fn="vit_base",
82
- text_cfg=dict(hidden_size=768, mlp_dim=3072, num_heads=12, num_layers=12),
83
- ffn="mlp",
84
- ),
85
- "TIPS v2 β€” L/14": dict(
86
- vision_url=f"{GCS}/v2_0/checkpoints/pytorch/tips_v2_oss_l14_vision.npz",
87
- text_url=f"{GCS}/v2_0/checkpoints/pytorch/tips_v2_oss_l14_text.npz",
88
- vision_fn="vit_large",
89
- text_cfg=dict(hidden_size=1024, mlp_dim=4096, num_heads=16, num_layers=12),
90
- ffn="mlp",
91
- ),
92
- "TIPS v2 β€” SO400m/14": dict(
93
- vision_url=f"{GCS}/v2_0/checkpoints/pytorch/tips_v2_oss_so14_vision.npz",
94
- text_url=f"{GCS}/v2_0/checkpoints/pytorch/tips_v2_oss_so14_text.npz",
95
- vision_fn="vit_so400m",
96
- text_cfg=dict(hidden_size=1152, mlp_dim=4304, num_heads=16, num_layers=27),
97
- ffn="mlp",
98
- ),
99
- "TIPS v2 β€” g/14": dict(
100
- vision_url=f"{GCS}/v2_0/checkpoints/pytorch/tips_v2_oss_g14_vision.npz",
101
- text_url=f"{GCS}/v2_0/checkpoints/pytorch/tips_v2_oss_g14_text.npz",
102
- vision_fn="vit_giant2",
103
- text_cfg=dict(hidden_size=1536, mlp_dim=6144, num_heads=24, num_layers=12),
104
- ffn="swiglu",
105
- ),
106
  }
107
-
108
  DEFAULT_VARIANT = "TIPS v2 β€” L/14"
109
 
 
 
 
110
  # ── Pascal Context (59 classes) ─────────────────────────────────────────────
111
 
112
  # TCL prompt templates (from the Scenic zero-shot seg evaluator).
@@ -135,57 +70,6 @@ PASCAL_CONTEXT_CLASSES = (
135
  "wood",
136
  )
137
 
138
- # ── Pascal VOC (20 foreground classes) ──────────────────────────────────────
139
-
140
- PASCAL_VOC_CLASSES = (
141
- "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat",
142
- "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person",
143
- "pottedplant", "sheep", "sofa", "train", "tvmonitor",
144
- )
145
-
146
- PASCAL_VOC_PALETTE = np.array([
147
- [128, 0, 0], # aeroplane
148
- [0, 128, 0], # bicycle
149
- [128, 128, 0], # bird
150
- [0, 0, 128], # boat
151
- [128, 0, 128], # bottle
152
- [0, 128, 128], # bus
153
- [128, 128, 128], # car
154
- [64, 0, 0], # cat
155
- [192, 0, 0], # chair
156
- [64, 128, 0], # cow
157
- [192, 128, 0], # diningtable
158
- [64, 0, 128], # dog
159
- [192, 0, 128], # horse
160
- [64, 128, 128], # motorbike
161
- [192, 128, 128], # person
162
- [0, 64, 0], # pottedplant
163
- [128, 64, 0], # sheep
164
- [0, 192, 0], # sofa
165
- [128, 192, 0], # train
166
- [0, 64, 128], # tvmonitor
167
- ], dtype=np.uint8)
168
-
169
- # Colors from segmentation_dataset_info.py (matching class order above,
170
- # i.e. index 0 = aeroplane, etc.).
171
- PASCAL_CONTEXT_PALETTE = np.array([
172
- [128, 0, 0], [214, 35, 42], [142, 28, 102], [39, 158, 136],
173
- [195, 112, 211], [0, 128, 0], [128, 128, 0], [0, 0, 128],
174
- [127, 34, 91], [128, 0, 128], [83, 137, 118], [0, 128, 128],
175
- [165, 86, 86], [128, 128, 128], [64, 0, 0], [106, 30, 114],
176
- [192, 0, 0], [226, 154, 154], [67, 11, 127], [64, 128, 0],
177
- [14, 242, 18], [155, 9, 121], [64, 0, 128], [131, 76, 67],
178
- [229, 106, 184], [37, 131, 150], [160, 150, 59], [154, 176, 215],
179
- [255, 255, 222], [106, 160, 142], [192, 0, 128], [214, 35, 42],
180
- [141, 90, 178], [64, 128, 128], [229, 106, 184], [116, 116, 116],
181
- [192, 128, 128], [0, 182, 198], [21, 106, 168], [0, 64, 0],
182
- [6, 151, 48], [214, 35, 42], [128, 64, 0], [131, 76, 67],
183
- [229, 106, 184], [116, 116, 116], [0, 182, 198], [0, 182, 198],
184
- [0, 192, 0], [255, 117, 39], [6, 151, 48], [128, 192, 0],
185
- [141, 90, 178], [131, 76, 6], [0, 64, 128], [116, 116, 116],
186
- [178, 182, 50], [0, 182, 198], [21, 106, 168],
187
- ], dtype=np.uint8)
188
-
189
  ADE20K_CLASSES = (
190
  'wall', 'building', 'sky', 'floor', 'tree',
191
  'ceiling', 'road', 'bed', 'windowpane', 'grass',
@@ -238,144 +122,49 @@ _model = {
238
  "text": None,
239
  "tokenizer": None,
240
  "temperature": None,
241
- "ade20k_embs": None, # (59, D) pre-computed text embeddings for Pascal Context
242
- "voc_embs": None, # (20, D) pre-computed text embeddings for Pascal VOC
243
  }
244
 
245
- # DPT depth head — keyed per variant
246
  _dpt = {
247
- "variant": None, # currently loaded DPT variant name
248
- "model": None, # DPTDepthHead on CPU
249
- "normals_model": None, # DPTNormalsHead on CPU
250
- "segmentation_model": None, # DPTSegmentationHead on CPU
251
- "vision": None, # vision encoder for current DPT variant
252
  }
253
 
254
-
255
- def _download(url):
256
- """Download a file to CKPT_DIR if not already present. Return local path."""
257
- fname = url.rsplit("/", 1)[-1]
258
- path = os.path.join(CKPT_DIR, fname)
259
- if not os.path.exists(path):
260
- print(f"Downloading {fname} ...")
261
- urllib.request.urlretrieve(url, path)
262
- return path
263
-
264
-
265
  def load_variant(name):
266
- """Download (if needed) and load a model variant.
267
-
268
- Models are kept on CPU for storage. They are moved to GPU dynamically
269
- inside @spaces.GPU-decorated callbacks via _move_models_to_device().
270
- """
271
  global _model
272
  if _model["name"] == name:
273
  return
274
- os.makedirs(CKPT_DIR, exist_ok=True)
275
-
276
- cfg = VARIANTS[name]
277
-
278
- # -- vision encoder (load on CPU) --
279
- vis_path = _download(cfg["vision_url"])
280
- weights_v = {k: torch.tensor(v) for k, v in np.load(vis_path, allow_pickle=False).items()}
281
- build_vision = getattr(image_encoder, cfg["vision_fn"])
282
- model_v = build_vision(
283
- img_size=MODEL_IMAGE_SIZE, patch_size=PATCH_SIZE, ffn_layer=cfg["ffn"],
284
- block_chunks=0, init_values=1.0,
285
- interpolate_antialias=True, interpolate_offset=0.0,
286
- )
287
- model_v.load_state_dict(weights_v)
288
- model_v.eval()
289
-
290
- # -- text encoder (load on CPU) --
291
- txt_path = _download(cfg["text_url"])
292
- with open(txt_path, "rb") as f:
293
- weights_t = {k: torch.from_numpy(v) for k, v in np.load(io.BytesIO(f.read()), allow_pickle=False).items()}
294
- temperature = weights_t.pop("temperature")
295
- model_t = text_encoder_mod.TextEncoder(cfg["text_cfg"], vocab_size=VOCAB_SIZE)
296
- model_t.load_state_dict(weights_t)
297
- model_t.eval()
298
-
299
- # -- tokenizer (shared across variants) --
300
- tok_path = _download(f"{GCS}/v1_0/checkpoints/tokenizer.model")
301
- tokenizer = text_encoder_mod.Tokenizer(tok_path)
302
-
303
  _model.update(
304
- name=name, vision=model_v, text=model_t,
305
- tokenizer=tokenizer, temperature=temperature,
306
- ade20k_embs=None, # computed lazily on GPU
 
 
 
 
 
307
  )
308
- print(f"Loaded {name} (on CPU, will move to GPU on demand)")
309
-
310
 
311
  def _load_dpt(variant_name=None):
312
- """Download and build DPT heads + vision encoder for the given variant."""
313
  global _dpt
314
  if variant_name is None:
315
- variant_name = DEFAULT_DPT_VARIANT
316
- cfg = DPT_CONFIGS[variant_name]
317
- embed_dim = cfg["embed_dim"]
318
-
319
- # Skip reload if same variant is already loaded
320
  if _dpt["variant"] == variant_name and _dpt["model"] is not None:
321
  return
322
-
323
- os.makedirs(CKPT_DIR, exist_ok=True)
324
-
325
- # Load DPT depth head
326
- zip_path = _download(cfg["depth_url"])
327
- dpt_model = dpt_head.DPTDepthHead(
328
- input_embed_dim=embed_dim, channels=256,
329
- post_process_channels=(128, 256, 512, 1024),
330
- readout_type="project", num_depth_bins=256,
331
- min_depth=1e-3, max_depth=10.0,
332
- )
333
- dpt_head.load_dpt_weights(dpt_model, zip_path)
334
- dpt_model.eval()
335
- _dpt["model"] = dpt_model
336
-
337
- # Load DPT normals head
338
- normals_zip = _download(cfg["normals_url"])
339
- normals_model = dpt_head.DPTNormalsHead(
340
- input_embed_dim=embed_dim, channels=256,
341
- post_process_channels=(128, 256, 512, 1024),
342
- readout_type="project",
343
- )
344
- dpt_head.load_normals_weights(normals_model, normals_zip)
345
- normals_model.eval()
346
- _dpt["normals_model"] = normals_model
347
-
348
- # Load DPT segmentation head
349
- seg_zip = _download(cfg["seg_url"])
350
- seg_model = dpt_head.DPTSegmentationHead(
351
- input_embed_dim=embed_dim, channels=256,
352
- post_process_channels=(128, 256, 512, 1024),
353
- readout_type="project", num_classes=150,
354
- )
355
- dpt_head.load_segmentation_weights(seg_model, seg_zip)
356
- seg_model.eval()
357
- _dpt["segmentation_model"] = seg_model
358
-
359
- # Vision encoder — reuse if the main model matches
360
- var_cfg = VARIANTS[variant_name]
361
  if _model["name"] == variant_name and _model["vision"] is not None:
362
- vision = _model["vision"]
363
- else:
364
- vis_path = _download(var_cfg["vision_url"])
365
- weights_v = {k: torch.tensor(v) for k, v in np.load(vis_path, allow_pickle=False).items()}
366
- build_fn = getattr(image_encoder, var_cfg["vision_fn"])
367
- vision = build_fn(
368
- img_size=MODEL_IMAGE_SIZE, patch_size=PATCH_SIZE,
369
- ffn_layer=var_cfg["ffn"], block_chunks=0, init_values=1.0,
370
- interpolate_antialias=True, interpolate_offset=0.0,
371
- )
372
- vision.load_state_dict(weights_v)
373
- vision.eval()
374
- _dpt["vision"] = vision
375
- _dpt["variant"] = variant_name
376
-
377
- print(f"Loaded DPT heads + {variant_name} vision encoder (on CPU)")
378
-
379
 
380
  def _move_models_to_device():
381
  """Move models to the current device (GPU inside @spaces.GPU, else CPU)."""
@@ -385,7 +174,6 @@ def _move_models_to_device():
385
  if _model["text"] is not None:
386
  _model["text"].to(dev)
387
 
388
-
389
  def _ensure_ade20k_embs():
390
  """Pre-compute Pascal Context text embeddings if not yet done (must run on GPU)."""
391
  if _model["ade20k_embs"] is not None:
@@ -403,32 +191,11 @@ def _ensure_ade20k_embs():
403
  _model["ade20k_embs"] = l2_normalize(np.mean(all_embs, axis=0))
404
  print("Pascal Context text embeddings computed.")
405
 
406
-
407
- def _ensure_voc_embs():
408
- """Pre-compute Pascal VOC text embeddings if not yet done (must run on GPU)."""
409
- if _model["voc_embs"] is not None:
410
- return
411
- dev = _device()
412
- model_t = _model["text"]
413
- tokenizer = _model["tokenizer"]
414
- all_embs = []
415
- for template in TCL_PROMPTS:
416
- prompts = [template.format(c) for c in PASCAL_VOC_CLASSES]
417
- ids, paddings = tokenizer.tokenize(prompts, max_len=MAX_LEN)
418
- with torch.no_grad():
419
- embs = model_t(torch.from_numpy(ids).to(dev), torch.from_numpy(paddings).to(dev))
420
- all_embs.append(embs.cpu().numpy())
421
- _model["voc_embs"] = l2_normalize(np.mean(all_embs, axis=0))
422
- print("Pascal VOC text embeddings computed.")
423
-
424
-
425
  def _init_model():
426
  """Load model + move to GPU + compute text embeddings."""
427
  load_variant(_model["name"] or DEFAULT_VARIANT)
428
  _move_models_to_device()
429
  _ensure_ade20k_embs()
430
- _ensure_voc_embs()
431
-
432
 
433
  # ── Preprocessing & helpers ─────────────────────────────────────────────────
434
 
@@ -438,16 +205,9 @@ def preprocess(img, size=DEFAULT_IMAGE_SIZE):
438
  transforms.ToTensor(),
439
  ])(img)
440
 
441
- preprocess_zeroseg = transforms.Compose([
442
- transforms.Resize((ZEROSEG_IMAGE_SIZE, ZEROSEG_IMAGE_SIZE)),
443
- transforms.ToTensor(),
444
- ])
445
-
446
-
447
  def l2_normalize(x, axis=-1):
448
  return x / np.linalg.norm(x, ord=2, axis=axis, keepdims=True).clip(min=1e-3)
449
 
450
-
451
  def upsample(arr, h, w, mode="bilinear"):
452
  """Upsample (H, W, C) or (H, W) numpy array to (h, w, ...)."""
453
  t = torch.from_numpy(arr).float()
@@ -458,11 +218,9 @@ def upsample(arr, h, w, mode="bilinear"):
458
  up = F.interpolate(t, size=(h, w), mode=mode, **kwargs)
459
  return up[0].permute(1, 2, 0).numpy()
460
 
461
-
462
  def to_uint8(x):
463
  return (x * 255).clip(0, 255).astype(np.uint8)
464
 
465
-
466
  # ── Feature extraction (GPU-accelerated) ────────────────────────────────────
467
 
468
  @torch.no_grad()
@@ -475,7 +233,6 @@ def extract_features(image_np, resolution=DEFAULT_IMAGE_SIZE):
475
  sp = resolution // PATCH_SIZE
476
  return patch_tokens.cpu().reshape(sp, sp, -1).numpy()
477
 
478
-
479
  @torch.no_grad()
480
  def extract_features_value_attention(image_np, resolution=ZEROSEG_IMAGE_SIZE):
481
  """Return spatial features (sp, sp, D) using Value Attention on GPU.
@@ -527,7 +284,6 @@ def extract_features_value_attention(image_np, resolution=ZEROSEG_IMAGE_SIZE):
527
  spatial = patch_tokens.cpu().reshape(sp, sp, -1).numpy()
528
  return spatial
529
 
530
-
531
  # ── PCA Visualisations ──────────────────────────────────────────────────────
532
 
533
  def vis_pca(spatial, h, w):
@@ -540,9 +296,8 @@ def vis_pca(spatial, h, w):
540
  rgb = 1 / (1 + np.exp(-2.0 * rgb))
541
  return to_uint8(upsample(rgb, h, w))
542
 
543
-
544
  def vis_depth(spatial, h, w):
545
- """1st PCA component as pseudo-depth (inferno colormap)."""
546
  feat = spatial.reshape(-1, spatial.shape[-1])
547
  H, W = spatial.shape[0], spatial.shape[1]
548
  depth = PCA(n_components=1).fit_transform(feat).reshape(H, W)
@@ -550,7 +305,6 @@ def vis_depth(spatial, h, w):
550
  colored = cm.get_cmap("inferno")(depth)[:, :, :3].astype(np.float32)
551
  return to_uint8(upsample(colored, h, w))
552
 
553
-
554
  def vis_kmeans(spatial, h, w, n_clusters=6):
555
  """K-means clustering of spatial features."""
556
  H, W = spatial.shape[:2]
@@ -566,186 +320,8 @@ def vis_kmeans(spatial, h, w, n_clusters=6):
566
  seg = palette[labels].astype(np.float32)
567
  return to_uint8(seg)
568
 
569
-
570
  # ── Zero-shot Segmentation ──────────────────────────────────────────────────
571
 
572
- def vis_pascal_context_semseg(spatial, orig_image):
573
- """Zero-shot semantic segmentation with Pascal Context 59 classes.
574
-
575
- Uses value-attention features and TCL prompt templates (9-template
576
- ensemble) following the Scenic zero-shot seg evaluator.
577
-
578
- For each spatial position, pick the Pascal Context class whose text
579
- embedding has the highest cosine similarity with the image feature.
580
- Returns (labelled image, raw mask, detected string, undetected string).
581
- """
582
- h, w = orig_image.shape[:2]
583
- S_h, S_w = spatial.shape[:2]
584
- feat = l2_normalize(spatial.reshape(-1, spatial.shape[-1])) # (N, D)
585
- sim = feat @ _model["ade20k_embs"].T # (N, 59)
586
- sim_map = sim.reshape(S_h, S_w, -1)
587
-
588
- # Bilinear upsample similarities then argmax for smooth boundaries
589
- sim_up = upsample(sim_map, h, w, mode="bilinear")
590
- labels = sim_up.argmax(axis=-1) # (h, w)
591
-
592
- # --- raw segmentation mask (no blend) ---
593
- seg_rgb = PASCAL_CONTEXT_PALETTE[labels].astype(np.float32) / 255.0
594
- mask_img = to_uint8(seg_rgb)
595
-
596
- # --- blended overlay with legend ---
597
- blend = 0.1 * orig_image.astype(np.float32) / 255.0 + 0.9 * seg_rgb
598
- blend_img = Image.fromarray(to_uint8(blend))
599
-
600
- # count pixels per class, sorted by area (descending)
601
- unique_ids, counts = np.unique(labels, return_counts=True)
602
- order = np.argsort(-counts)
603
- unique_ids, counts = unique_ids[order], counts[order]
604
- total = counts.sum()
605
-
606
- # build a legend panel on the right side
607
- try:
608
- font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 60)
609
- except OSError:
610
- font = ImageFont.load_default()
611
-
612
- # show top 5 classes by area
613
- n_legend = min(len(unique_ids), 5)
614
- legend_ids = [(unique_ids[i], counts[i]) for i in range(n_legend)]
615
- row_h = 80 # height per legend row
616
- swatch_w = 60 # color swatch width
617
- pad = 12 # padding
618
- legend_w = 450 # legend panel width
619
-
620
- legend_h = max(h, n_legend * row_h + pad * 2)
621
- canvas = Image.new("RGB", (w + legend_w, legend_h), (255, 255, 255))
622
- canvas.paste(blend_img, (0, 0))
623
- draw = ImageDraw.Draw(canvas)
624
-
625
- for i, (cid, cnt) in enumerate(legend_ids):
626
- pct = cnt / total * 100
627
- color = tuple(PASCAL_CONTEXT_PALETTE[cid].tolist())
628
- name = PASCAL_CONTEXT_CLASSES[cid]
629
-
630
- y_top = pad + i * row_h
631
- # draw color swatch
632
- draw.rectangle(
633
- [w + pad, y_top, w + pad + swatch_w, y_top + swatch_w],
634
- fill=color, outline=(0, 0, 0),
635
- )
636
- # draw class name + percentage
637
- draw.text(
638
- (w + pad + swatch_w + 8, y_top + 6),
639
- f"{name}",
640
- fill="black", font=font,
641
- )
642
-
643
- overlay_out = np.array(canvas)
644
-
645
- # format detected (>=2%) / undetected (<2% or absent) strings
646
- detected_parts, minor_parts = [], []
647
- for i, cid in enumerate(unique_ids):
648
- pct = counts[i] / total * 100
649
- name = PASCAL_CONTEXT_CLASSES[cid]
650
- if pct >= 2:
651
- detected_parts.append(f"{name} ({pct:.1f}%)")
652
- else:
653
- minor_parts.append(f"{name} ({pct:.1f}%)")
654
- absent = [
655
- f"{PASCAL_CONTEXT_CLASSES[i]} (0.0%)"
656
- for i in range(len(PASCAL_CONTEXT_CLASSES))
657
- if i not in set(unique_ids.tolist())
658
- ]
659
- detected_str = ", ".join(detected_parts)
660
- undetected_str = ", ".join(minor_parts + absent)
661
- return overlay_out, mask_img, detected_str, undetected_str
662
-
663
-
664
- def vis_pascal_voc_semseg(spatial, orig_image):
665
- """Zero-shot semantic segmentation with Pascal VOC 20 classes.
666
-
667
- Same approach as Pascal Context but with VOC classes and palette.
668
- Returns (labelled image, raw mask, detected string, undetected string).
669
- """
670
- h, w = orig_image.shape[:2]
671
- S_h, S_w = spatial.shape[:2]
672
- feat = l2_normalize(spatial.reshape(-1, spatial.shape[-1])) # (N, D)
673
- sim = feat @ _model["voc_embs"].T # (N, 20)
674
- sim_map = sim.reshape(S_h, S_w, -1)
675
-
676
- # Bilinear upsample similarities then argmax for smooth boundaries
677
- sim_up = upsample(sim_map, h, w, mode="bilinear")
678
- labels = sim_up.argmax(axis=-1) # (h, w)
679
-
680
- # --- raw segmentation mask (no blend) ---
681
- seg_rgb = PASCAL_VOC_PALETTE[labels].astype(np.float32) / 255.0
682
- mask_img = to_uint8(seg_rgb)
683
-
684
- # --- blended overlay with legend ---
685
- blend = 0.1 * orig_image.astype(np.float32) / 255.0 + 0.9 * seg_rgb
686
- blend_img = Image.fromarray(to_uint8(blend))
687
-
688
- # count pixels per class, sorted by area (descending)
689
- unique_ids, counts = np.unique(labels, return_counts=True)
690
- order = np.argsort(-counts)
691
- unique_ids, counts = unique_ids[order], counts[order]
692
- total = counts.sum()
693
-
694
- # build a legend panel on the right side
695
- try:
696
- font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 60)
697
- except OSError:
698
- font = ImageFont.load_default()
699
-
700
- n_legend = min(len(unique_ids), 5)
701
- legend_ids = [(unique_ids[i], counts[i]) for i in range(n_legend)]
702
- row_h = 80
703
- swatch_w = 60
704
- pad = 12
705
- legend_w = 450
706
-
707
- legend_h = max(h, n_legend * row_h + pad * 2)
708
- canvas = Image.new("RGB", (w + legend_w, legend_h), (255, 255, 255))
709
- canvas.paste(blend_img, (0, 0))
710
- draw = ImageDraw.Draw(canvas)
711
-
712
- for i, (cid, cnt) in enumerate(legend_ids):
713
- pct = cnt / total * 100
714
- color = tuple(PASCAL_VOC_PALETTE[cid].tolist())
715
- name = PASCAL_VOC_CLASSES[cid]
716
-
717
- y_top = pad + i * row_h
718
- draw.rectangle(
719
- [w + pad, y_top, w + pad + swatch_w, y_top + swatch_w],
720
- fill=color, outline=(0, 0, 0),
721
- )
722
- draw.text(
723
- (w + pad + swatch_w + 8, y_top + 6),
724
- f"{name}",
725
- fill="black", font=font,
726
- )
727
-
728
- overlay_out = np.array(canvas)
729
-
730
- # format detected (>=2%) / undetected (<2% or absent) strings
731
- detected_parts, minor_parts = [], []
732
- for i, cid in enumerate(unique_ids):
733
- pct = counts[i] / total * 100
734
- name = PASCAL_VOC_CLASSES[cid]
735
- if pct >= 2:
736
- detected_parts.append(f"{name} ({pct:.1f}%)")
737
- else:
738
- minor_parts.append(f"{name} ({pct:.1f}%)")
739
- absent = [
740
- f"{PASCAL_VOC_CLASSES[i]} (0.0%)"
741
- for i in range(len(PASCAL_VOC_CLASSES))
742
- if i not in set(unique_ids.tolist())
743
- ]
744
- detected_str = ", ".join(detected_parts)
745
- undetected_str = ", ".join(minor_parts + absent)
746
- return overlay_out, mask_img, detected_str, undetected_str
747
-
748
-
749
  def vis_custom_semseg(spatial, orig_image, classes, class_embs):
750
  """Zero-shot semantic segmentation with user-defined classes."""
751
  h, w = orig_image.shape[:2]
@@ -820,15 +396,8 @@ def vis_custom_semseg(spatial, orig_image, classes, class_embs):
820
  undetected_str = ", ".join(minor_parts + absent)
821
  return overlay_out, mask_img, detected_str, undetected_str
822
 
823
-
824
  # ── DPT Depth Inference ─────────────────────────────────────────────────────
825
 
826
- preprocess_depth = transforms.Compose([
827
- transforms.Resize((DEPTH_IMAGE_SIZE, DEPTH_IMAGE_SIZE)),
828
- transforms.ToTensor(),
829
- ])
830
-
831
-
832
  def vis_depth_dpt(depth_map, h, w):
833
  """Colour a depth map with the turbo colormap β†’ PIL Image."""
834
  d = depth_map.squeeze()
@@ -836,7 +405,6 @@ def vis_depth_dpt(depth_map, h, w):
836
  colored = cm.get_cmap("turbo")(d)[:, :, :3].astype(np.float32)
837
  return to_uint8(upsample(colored, h, w))
838
 
839
-
840
  def vis_normals_dpt(normals_map, h, w):
841
  """Map normals from [-1, 1] to [0, 1] and resize to original size."""
842
  # normals_map shape is (3, H, W)
@@ -845,16 +413,43 @@ def vis_normals_dpt(normals_map, h, w):
845
  n = np.transpose(n, (1, 2, 0)) # (H, W, 3)
846
  return to_uint8(upsample(n, h, w))
847
 
848
-
849
- def vis_segmentation_dpt(seg_map, h, w):
850
- """Colour a segmentation map with the ADE20K colormap."""
851
- # seg_map shape is (150, H, W) — bilinear upsample logits then argmax
852
  logits = seg_map.cpu().numpy().transpose(1, 2, 0) # (H, W, 150)
853
  logits_up = upsample(logits, h, w, mode="bilinear")
854
  pred = logits_up.argmax(axis=-1) # (h, w)
855
- colored = ADE20K_PALETTE[pred.astype(np.int32) + 1] # (h, w, 3)
856
- return colored
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
857
 
 
858
 
859
  # ── Gradio callbacks ────────────────────────────────────────────────────────
860
 
@@ -869,7 +464,6 @@ def on_variant_change(variant_name):
869
  None, # pca_state
870
  None, None, "", "") # custom outputs
871
 
872
-
873
  # --- PCA tab callbacks ---
874
 
875
  @spaces.GPU
@@ -886,7 +480,6 @@ def on_pca_extract(image, resolution, pca_state):
886
  state = {"spatial": spatial, "orig_image": image, "variant": _model["name"], "resolution": resolution}
887
  return pca, depth, kmeans, state
888
 
889
-
890
  @spaces.GPU
891
  def on_recluster(image, resolution, n_clusters, pca_state):
892
  if image is None:
@@ -904,29 +497,8 @@ def on_recluster(image, resolution, n_clusters, pca_state):
904
  h, w = image.shape[:2]
905
  return vis_kmeans(spatial, h, w, int(n_clusters)), pca_state
906
 
907
-
908
  # --- Zero-shot Segmentation tab callbacks ---
909
 
910
- @spaces.GPU
911
- def on_zeroseg(image, resolution):
912
- if image is None:
913
- return None, None, "", ""
914
- _init_model()
915
- spatial = extract_features_value_attention(image, int(resolution))
916
- blend, mask, detected, undetected = vis_pascal_context_semseg(spatial, image)
917
- return blend, mask, detected, undetected
918
-
919
-
920
- @spaces.GPU
921
- def on_zeroseg_voc(image, resolution):
922
- if image is None:
923
- return None, None, "", ""
924
- _init_model()
925
- spatial = extract_features_value_attention(image, int(resolution))
926
- blend, mask, detected, undetected = vis_pascal_voc_semseg(spatial, image)
927
- return blend, mask, detected, undetected
928
-
929
-
930
  @spaces.GPU
931
  def on_zeroseg_custom(image, resolution, class_names_str):
932
  if image is None or not class_names_str or not class_names_str.strip():
@@ -953,75 +525,41 @@ def on_zeroseg_custom(image, resolution, class_names_str):
953
  overlay, mask, detected, undetected = vis_custom_semseg(spatial, image, classes, class_embs)
954
  return overlay, mask, detected, undetected
955
 
956
-
957
  # --- Depth Feature Visualization tab callbacks ---
958
 
959
  @spaces.GPU
960
  def on_depth_normals_predict(image, dpt_variant, resolution):
961
- """Run DPT depth and normals prediction on an input image."""
962
  if image is None:
963
  return None, None
964
  _load_dpt(dpt_variant)
965
  dev = _device()
966
- block_indices = DPT_CONFIGS[dpt_variant]["block_indices"]
967
-
968
- # Move DPT models to GPU
969
- _dpt["model"].to(dev)
970
- _dpt["normals_model"].to(dev)
971
- _dpt["vision"].to(dev)
972
 
973
  h, w = image.shape[:2]
974
  img = Image.fromarray(image).convert("RGB")
975
  tensor = preprocess(img, int(resolution)).unsqueeze(0).to(dev)
976
 
977
- with torch.no_grad():
978
- intermediate = _dpt["vision"].get_intermediate_layers(
979
- tensor, n=block_indices,
980
- reshape=True, return_class_token=True, norm=True,
981
- )
982
- dpt_inputs = [(cls_tok, patch_feat)
983
- for patch_feat, cls_tok in intermediate]
984
-
985
- depth_map = _dpt["model"](dpt_inputs, image_size=(h, w))
986
- normals_map = _dpt["normals_model"](dpt_inputs, image_size=(h, w))
987
-
988
- depth_np = depth_map[0, 0].cpu().numpy()
989
- normals_np = normals_map[0]
990
-
991
- return vis_depth_dpt(depth_np, h, w), vis_normals_dpt(normals_np, h, w)
992
 
 
993
 
994
  @spaces.GPU
995
  def on_segmentation_predict(image, dpt_variant, resolution):
996
- """Run DPT segmentation prediction on an input image."""
997
  if image is None:
998
  return None
999
  _load_dpt(dpt_variant)
1000
  dev = _device()
1001
- block_indices = DPT_CONFIGS[dpt_variant]["block_indices"]
1002
-
1003
- # Move DPT models to GPU
1004
- _dpt["segmentation_model"].to(dev)
1005
- _dpt["vision"].to(dev)
1006
 
1007
  h, w = image.shape[:2]
1008
  img = Image.fromarray(image).convert("RGB")
1009
  tensor = preprocess(img, int(resolution)).unsqueeze(0).to(dev)
1010
 
1011
- with torch.no_grad():
1012
- intermediate = _dpt["vision"].get_intermediate_layers(
1013
- tensor, n=block_indices,
1014
- reshape=True, return_class_token=True, norm=True,
1015
- )
1016
- dpt_inputs = [(cls_tok, patch_feat)
1017
- for patch_feat, cls_tok in intermediate]
1018
-
1019
- seg_map = _dpt["segmentation_model"](dpt_inputs, image_size=(h, w))
1020
-
1021
- seg_np = seg_map[0]
1022
-
1023
- return vis_segmentation_dpt(seg_np, h, w)
1024
-
1025
 
1026
  # ── UI ──────────────────────────────────────────────────────────────────────
1027
 
@@ -1071,7 +609,7 @@ with gr.Blocks(head=head, title="TIPSv2 Feature Explorer") as demo:
1071
  with gr.Tab("PCA"):
1072
  pca_out = gr.Image(label="PCA (3 components → RGB)")
1073
  with gr.Tab("PCA (1st component)"):
1074
- depth_out = gr.Image(label="Depth (1st PCA component)")
1075
  with gr.Tab("K-means Clustering"):
1076
  n_clusters = gr.Slider(2, 20, value=6, step=1, label="Clusters")
1077
  recluster_btn = gr.Button("Re-cluster")
 
1
  """TIPS Feature Explorer (GPU) — Hugging Face Space demo with ZeroGPU."""
2
 
3
  import colorsys
 
 
 
4
 
5
  import gradio as gr
6
  import matplotlib.cm as cm
 
13
  from fast_pytorch_kmeans import KMeans as TorchKMeans
14
  from sklearn.decomposition import PCA
15
  from torchvision import transforms
16
+ from transformers import AutoModel
 
 
 
17
 
18
  # ── Constants ───────────────────────────────────────────────────────────────
19
 
20
  DEFAULT_IMAGE_SIZE = 896
 
21
  PATCH_SIZE = 14
22
+ RESOLUTIONS = [224, 336, 448, 672, 896, 1120, 1372, 1792]
23
 
24
  ZEROSEG_IMAGE_SIZE = 1372
 
 
 
 
25
  MAX_LEN = 64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
+ # HF model repos
28
  VARIANTS = {
29
+ "TIPS v2 — B/14": "google/tipsv2-b14",
30
+ "TIPS v2 — L/14": "google/tipsv2-l14",
31
+ "TIPS v2 — SO400m/14": "google/tipsv2-so400m14",
32
+ "TIPS v2 — g/14": "google/tipsv2-g14",
33
+ }
34
+ DPT_VARIANTS = {
35
+ "TIPS v2 — B/14": "google/tipsv2-b14-dpt",
36
+ "TIPS v2 — L/14": "google/tipsv2-l14-dpt",
37
+ "TIPS v2 — SO400m/14": "google/tipsv2-so400m14-dpt",
38
+ "TIPS v2 — g/14": "google/tipsv2-g14-dpt",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  }
 
40
  DEFAULT_VARIANT = "TIPS v2 — L/14"
41
 
42
+ def _device():
43
+ return torch.device("cuda" if torch.cuda.is_available() else "cpu")
44
+
45
  # ── Pascal Context (59 classes) ─────────────────────────────────────────────
46
 
47
  # TCL prompt templates (from the Scenic zero-shot seg evaluator).
 
70
  "wood",
71
  )
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  ADE20K_CLASSES = (
74
  'wall', 'building', 'sky', 'floor', 'tree',
75
  'ceiling', 'road', 'bed', 'windowpane', 'grass',
 
122
  "text": None,
123
  "tokenizer": None,
124
  "temperature": None,
125
+ "ade20k_embs": None,
126
+ "_hf_model": None,
127
  }
128
 
 
129
  _dpt = {
130
+ "variant": None,
131
+ "model": None,
132
+ "_hf_dpt": None,
 
 
133
  }
134
 
 
 
 
 
 
 
 
 
 
 
 
135
  def load_variant(name):
136
+ """Load a model variant from HuggingFace."""
 
 
 
 
137
  global _model
138
  if _model["name"] == name:
139
  return
140
+ hf_model = AutoModel.from_pretrained(VARIANTS[name], trust_remote_code=True)
141
+ hf_model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  _model.update(
143
+ name=name,
144
+ vision=hf_model.vision_encoder,
145
+ text=hf_model.text_encoder,
146
+ tokenizer=hf_model._load_tokenizer(),
147
+ temperature=hf_model.config.temperature,
148
+ ade20k_embs=None,
149
+ voc_embs=None,
150
+ _hf_model=hf_model,
151
  )
152
+ print(f"Loaded {name}")
 
153
 
154
  def _load_dpt(variant_name=None):
155
+ """Load DPT heads from HuggingFace."""
156
  global _dpt
157
  if variant_name is None:
158
+ variant_name = DEFAULT_VARIANT
 
 
 
 
159
  if _dpt["variant"] == variant_name and _dpt["model"] is not None:
160
  return
161
+ hf_dpt = AutoModel.from_pretrained(DPT_VARIANTS[variant_name], trust_remote_code=True)
162
+ hf_dpt.eval()
163
+ # Reuse backbone from main model if variants match to save memory
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  if _model["name"] == variant_name and _model["vision"] is not None:
165
+ hf_dpt._backbone = _model["_hf_model"]
166
+ _dpt.update(variant=variant_name, model=hf_dpt, _hf_dpt=hf_dpt)
167
+ print(f"Loaded DPT heads for {variant_name}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
 
169
  def _move_models_to_device():
170
  """Move models to the current device (GPU inside @spaces.GPU, else CPU)."""
 
174
  if _model["text"] is not None:
175
  _model["text"].to(dev)
176
 
 
177
  def _ensure_ade20k_embs():
178
  """Pre-compute Pascal Context text embeddings if not yet done (must run on GPU)."""
179
  if _model["ade20k_embs"] is not None:
 
191
  _model["ade20k_embs"] = l2_normalize(np.mean(all_embs, axis=0))
192
  print("Pascal Context text embeddings computed.")
193
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
  def _init_model():
195
  """Load model + move to GPU + compute text embeddings."""
196
  load_variant(_model["name"] or DEFAULT_VARIANT)
197
  _move_models_to_device()
198
  _ensure_ade20k_embs()
 
 
199
 
200
  # ── Preprocessing & helpers ─────────────────────────────────────────────────
201
 
 
205
  transforms.ToTensor(),
206
  ])(img)
207
 
 
 
 
 
 
 
208
  def l2_normalize(x, axis=-1):
209
  return x / np.linalg.norm(x, ord=2, axis=axis, keepdims=True).clip(min=1e-3)
210
 
 
211
  def upsample(arr, h, w, mode="bilinear"):
212
  """Upsample (H, W, C) or (H, W) numpy array to (h, w, ...)."""
213
  t = torch.from_numpy(arr).float()
 
218
  up = F.interpolate(t, size=(h, w), mode=mode, **kwargs)
219
  return up[0].permute(1, 2, 0).numpy()
220
 
 
221
  def to_uint8(x):
222
  return (x * 255).clip(0, 255).astype(np.uint8)
223
 
 
224
  # ── Feature extraction (GPU-accelerated) ────────────────────────────────────
225
 
226
  @torch.no_grad()
 
233
  sp = resolution // PATCH_SIZE
234
  return patch_tokens.cpu().reshape(sp, sp, -1).numpy()
235
 
 
236
  @torch.no_grad()
237
  def extract_features_value_attention(image_np, resolution=ZEROSEG_IMAGE_SIZE):
238
  """Return spatial features (sp, sp, D) using Value Attention on GPU.
 
284
  spatial = patch_tokens.cpu().reshape(sp, sp, -1).numpy()
285
  return spatial
286
 
 
287
  # ── PCA Visualisations ──────────────────────────────────────────────────────
288
 
289
  def vis_pca(spatial, h, w):
 
296
  rgb = 1 / (1 + np.exp(-2.0 * rgb))
297
  return to_uint8(upsample(rgb, h, w))
298
 
 
299
  def vis_depth(spatial, h, w):
300
+ """1st PCA component visualized with inferno colormap."""
301
  feat = spatial.reshape(-1, spatial.shape[-1])
302
  H, W = spatial.shape[0], spatial.shape[1]
303
  depth = PCA(n_components=1).fit_transform(feat).reshape(H, W)
 
305
  colored = cm.get_cmap("inferno")(depth)[:, :, :3].astype(np.float32)
306
  return to_uint8(upsample(colored, h, w))
307
 
 
308
  def vis_kmeans(spatial, h, w, n_clusters=6):
309
  """K-means clustering of spatial features."""
310
  H, W = spatial.shape[:2]
 
320
  seg = palette[labels].astype(np.float32)
321
  return to_uint8(seg)
322
 
 
323
  # ── Zero-shot Segmentation ──────────────────────────────────────────────────
324
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
325
  def vis_custom_semseg(spatial, orig_image, classes, class_embs):
326
  """Zero-shot semantic segmentation with user-defined classes."""
327
  h, w = orig_image.shape[:2]
 
396
  undetected_str = ", ".join(minor_parts + absent)
397
  return overlay_out, mask_img, detected_str, undetected_str
398
 
 
399
  # ── DPT Depth Inference ─────────────────────────────────────────────────────
400
 
 
 
 
 
 
 
401
  def vis_depth_dpt(depth_map, h, w):
402
  """Colour a depth map with the turbo colormap → PIL Image."""
403
  d = depth_map.squeeze()
 
405
  colored = cm.get_cmap("turbo")(d)[:, :, :3].astype(np.float32)
406
  return to_uint8(upsample(colored, h, w))
407
 
 
408
  def vis_normals_dpt(normals_map, h, w):
409
  """Map normals from [-1, 1] to [0, 1] and resize to original size."""
410
  # normals_map shape is (3, H, W)
 
413
  n = np.transpose(n, (1, 2, 0)) # (H, W, 3)
414
  return to_uint8(upsample(n, h, w))
415
 
416
+ def vis_segmentation_dpt(seg_map, orig_image):
417
+ """Colour a segmentation map with the ADE20K colormap + legend."""
418
+ h, w = orig_image.shape[:2]
 
419
  logits = seg_map.cpu().numpy().transpose(1, 2, 0) # (H, W, 150)
420
  logits_up = upsample(logits, h, w, mode="bilinear")
421
  pred = logits_up.argmax(axis=-1) # (h, w)
422
+ seg_rgb = ADE20K_PALETTE[pred.astype(np.int32) + 1].astype(np.float32) / 255.0
423
+
424
+ blend = 0.15 * orig_image.astype(np.float32) / 255.0 + 0.85 * seg_rgb
425
+ blend_img = Image.fromarray(to_uint8(blend))
426
+
427
+ # Legend: top-10 classes by area
428
+ unique_ids, counts = np.unique(pred, return_counts=True)
429
+ order = np.argsort(-counts)
430
+ unique_ids, counts = unique_ids[order], counts[order]
431
+
432
+ try:
433
+ font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 60)
434
+ except OSError:
435
+ font = ImageFont.load_default()
436
+
437
+ n_legend = min(len(unique_ids), 10)
438
+ row_h, swatch_w, pad, legend_w = 80, 60, 12, 450
439
+ legend_h = max(h, n_legend * row_h + pad * 2)
440
+ canvas = Image.new("RGB", (w + legend_w, legend_h), (255, 255, 255))
441
+ canvas.paste(blend_img, (0, 0))
442
+ draw = ImageDraw.Draw(canvas)
443
+
444
+ for i in range(n_legend):
445
+ cid = unique_ids[i]
446
+ color = tuple(ADE20K_PALETTE[cid + 1].tolist())
447
+ name = ADE20K_CLASSES[cid] if cid < len(ADE20K_CLASSES) else f"class_{cid}"
448
+ y_top = pad + i * row_h
449
+ draw.rectangle([w + pad, y_top, w + pad + swatch_w, y_top + swatch_w], fill=color, outline=(0, 0, 0))
450
+ draw.text((w + pad + swatch_w + 8, y_top + 6), name, fill="black", font=font)
451
 
452
+ return np.array(canvas)
453
 
454
  # ── Gradio callbacks ────────────────────────────────────────────────────────
455
 
 
464
  None, # pca_state
465
  None, None, "", "") # custom outputs
466
 
 
467
  # --- PCA tab callbacks ---
468
 
469
  @spaces.GPU
 
480
  state = {"spatial": spatial, "orig_image": image, "variant": _model["name"], "resolution": resolution}
481
  return pca, depth, kmeans, state
482
 
 
483
  @spaces.GPU
484
  def on_recluster(image, resolution, n_clusters, pca_state):
485
  if image is None:
 
497
  h, w = image.shape[:2]
498
  return vis_kmeans(spatial, h, w, int(n_clusters)), pca_state
499
 
 
500
  # --- Zero-shot Segmentation tab callbacks ---
501
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
502
  @spaces.GPU
503
  def on_zeroseg_custom(image, resolution, class_names_str):
504
  if image is None or not class_names_str or not class_names_str.strip():
 
525
  overlay, mask, detected, undetected = vis_custom_semseg(spatial, image, classes, class_embs)
526
  return overlay, mask, detected, undetected
527
 
 
528
  # --- Depth Feature Visualization tab callbacks ---
529
 
530
  @spaces.GPU
531
  def on_depth_normals_predict(image, dpt_variant, resolution):
532
+ """Run DPT depth and normals prediction."""
533
  if image is None:
534
  return None, None
535
  _load_dpt(dpt_variant)
536
  dev = _device()
537
+ dpt = _dpt["model"].to(dev)
 
 
 
 
 
538
 
539
  h, w = image.shape[:2]
540
  img = Image.fromarray(image).convert("RGB")
541
  tensor = preprocess(img, int(resolution)).unsqueeze(0).to(dev)
542
 
543
+ depth_map = dpt.predict_depth(tensor)
544
+ normals_map = dpt.predict_normals(tensor)
 
 
 
 
 
 
 
 
 
 
 
 
 
545
 
546
+ return vis_depth_dpt(depth_map[0, 0].cpu().numpy(), h, w), vis_normals_dpt(normals_map[0], h, w)
547
 
548
  @spaces.GPU
549
  def on_segmentation_predict(image, dpt_variant, resolution):
550
+ """Run DPT segmentation prediction."""
551
  if image is None:
552
  return None
553
  _load_dpt(dpt_variant)
554
  dev = _device()
555
+ dpt = _dpt["model"].to(dev)
 
 
 
 
556
 
557
  h, w = image.shape[:2]
558
  img = Image.fromarray(image).convert("RGB")
559
  tensor = preprocess(img, int(resolution)).unsqueeze(0).to(dev)
560
 
561
+ seg_map = dpt.predict_segmentation(tensor)
562
+ return vis_segmentation_dpt(seg_map[0], image)
 
 
 
 
 
 
 
 
 
 
 
 
563
 
564
  # ── UI ──────────────────────────────────────────────────────────────────────
565
 
 
609
  with gr.Tab("PCA"):
610
  pca_out = gr.Image(label="PCA (3 components → RGB)")
611
  with gr.Tab("PCA (1st component)"):
612
+ depth_out = gr.Image(label="1st PCA component")
613
  with gr.Tab("K-means Clustering"):
614
  n_clusters = gr.Slider(2, 20, value=6, step=1, label="Clusters")
615
  recluster_btn = gr.Button("Re-cluster")