Leacb4 commited on
Commit
8a7c966
Β·
verified Β·
1 Parent(s): 3e2b688

Upload evaluation/test_color_across_hierarchies.py with huggingface_hub

Browse files
evaluation/test_color_across_hierarchies.py ADDED
@@ -0,0 +1,567 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Color retrieval accuracy across different hierarchies β€” Baseline vs GAP-CLIP.
4
+
5
+ For each color, pairs it with every hierarchy category and measures how well
6
+ each model classifies the correct color and hierarchy via nearest-neighbor.
7
+
8
+ Three classification strategies are compared:
9
+ 1. Naive β€” bare label words ("dress", "shirt", ...) as label embeddings
10
+ 2. Ensembled β€” average of multiple prompt templates per label (standard CLIP trick)
11
+ 3. Structured β€” (GAP-CLIP only) color-marginalized label centroids in the
12
+ hierarchy subspace. For each hierarchy, embed "{c} {h}" for
13
+ ALL colors, extract the 64D hierarchy slice, and average.
14
+ This builds color-agnostic hierarchy prototypes that exploit
15
+ GAP-CLIP's learned subspace decomposition.
16
+
17
+ Run:
18
+ python3 -m evaluation.test_color_across_hierarchies # single color (red)
19
+ python3 -m evaluation.test_color_across_hierarchies --color blue
20
+ python3 -m evaluation.test_color_across_hierarchies --all-colors # full sweep + graph
21
+ """
22
+
23
+ from __future__ import annotations
24
+
25
+ import argparse
26
+ import os
27
+ import sys
28
+
29
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
30
+
31
+ from pathlib import Path
32
+ from typing import Dict, List, Tuple
33
+
34
+ import matplotlib.pyplot as plt
35
+ import matplotlib.ticker as mtick
36
+ import numpy as np
37
+ import torch
38
+ import torch.nn.functional as F
39
+ from transformers import CLIPModel as CLIPModelTransformers, CLIPProcessor
40
+
41
+ _PROJECT_ROOT = Path(__file__).resolve().parents[1]
42
+ if str(_PROJECT_ROOT) not in sys.path:
43
+ sys.path.insert(0, str(_PROJECT_ROOT))
44
+
45
+ import config
46
+ from evaluation.utils.model_loader import (
47
+ load_baseline_fashion_clip,
48
+ load_gap_clip,
49
+ get_text_embedding,
50
+ get_text_embeddings_batch,
51
+ )
52
+
53
+ # ── Constants ────────────────────────────────────────────────────────────────
54
+
55
+ COLORS = [
56
+ "beige", "black", "blue", "brown", "green",
57
+ "orange", "pink", "purple", "red", "white", "yellow",
58
+ ]
59
+
60
+ HIERARCHIES = [
61
+ "dress", "shirt", "pants", "skirt", "jacket",
62
+ "coat", "jeans", "sweater", "shorts", "top",
63
+ ]
64
+
65
+ # Templates used to build query sentences
66
+ QUERY_TEMPLATES = [
67
+ "{color} {hierarchy}",
68
+ "a {color} {hierarchy}",
69
+ "{color} {hierarchy} for women",
70
+ "casual {color} {hierarchy}",
71
+ "elegant {color} {hierarchy}",
72
+ ]
73
+
74
+ # Templates for label ensembling (strategy 2)
75
+ LABEL_TEMPLATES = [
76
+ "{}",
77
+ "a {}",
78
+ "a photo of a {}",
79
+ "a fashion {}",
80
+ "a piece of clothing: {}",
81
+ ]
82
+
83
+ FIGURES_DIR = _PROJECT_ROOT / "figures"
84
+
85
+ # ── Helpers ──────────────────────────────────────────────────────────────────
86
+
87
+
88
+ def classify_nearest(
89
+ query_emb: torch.Tensor,
90
+ label_embs: torch.Tensor,
91
+ labels: List[str],
92
+ ) -> Tuple[str, float]:
93
+ sims = F.cosine_similarity(query_emb.unsqueeze(0), label_embs, dim=1)
94
+ idx = sims.argmax().item()
95
+ return labels[idx], sims[idx].item()
96
+
97
+
98
+ # ── Label embedding builders ─────────────────────────────────────────────────
99
+
100
+
101
+ def build_naive_labels(model, processor, device, labels):
102
+ """Strategy 1: bare words."""
103
+ return get_text_embeddings_batch(model, processor, device, labels)
104
+
105
+
106
+ def build_ensembled_labels(model, processor, device, labels):
107
+ """Strategy 2: average of LABEL_TEMPLATES per label."""
108
+ out = []
109
+ for label in labels:
110
+ prompts = [t.format(label) for t in LABEL_TEMPLATES]
111
+ embs = get_text_embeddings_batch(model, processor, device, prompts)
112
+ out.append(F.normalize(embs.mean(dim=0), dim=-1))
113
+ return torch.stack(out)
114
+
115
+
116
+ def build_color_marginalized_labels(model, processor, device, hier_start, hier_end):
117
+ """Strategy 3 (GAP-CLIP only): for each hierarchy, embed '{c} {h}' for all
118
+ colors, extract the hierarchy subspace, average β†’ color-agnostic centroid."""
119
+ out = []
120
+ for h in HIERARCHIES:
121
+ all_embs = []
122
+ for c in COLORS:
123
+ for tmpl in QUERY_TEMPLATES:
124
+ query = tmpl.format(color=c, hierarchy=h)
125
+ emb = get_text_embedding(model, processor, device, query)
126
+ all_embs.append(emb[hier_start:hier_end])
127
+ stacked = torch.stack(all_embs)
128
+ centroid = F.normalize(stacked.mean(dim=0), dim=-1)
129
+ out.append(centroid)
130
+ return torch.stack(out)
131
+
132
+
133
+ # ── Per-model evaluation ─────────────────────────────────────────────────────
134
+
135
+
136
+ def evaluate_model(
137
+ model, processor, device, target_color, model_name,
138
+ color_dim=0, hier_start=0, hier_end=0,
139
+ ) -> Dict:
140
+ is_gap_clip = color_dim > 0
141
+
142
+ # Build all label embedding variants
143
+ naive_color_labels = build_naive_labels(model, processor, device, COLORS)
144
+ naive_hier_labels = build_naive_labels(model, processor, device, HIERARCHIES)
145
+ ens_color_labels = build_ensembled_labels(model, processor, device, COLORS)
146
+ ens_hier_labels = build_ensembled_labels(model, processor, device, HIERARCHIES)
147
+
148
+ if is_gap_clip:
149
+ naive_color_sub = F.normalize(naive_color_labels[:, :color_dim], dim=-1)
150
+ naive_hier_sub = F.normalize(naive_hier_labels[:, hier_start:hier_end], dim=-1)
151
+ ens_hier_sub = F.normalize(ens_hier_labels[:, hier_start:hier_end], dim=-1)
152
+ marg_hier_sub = build_color_marginalized_labels(
153
+ model, processor, device, hier_start, hier_end
154
+ )
155
+
156
+ rows: List[Dict] = []
157
+
158
+ for hierarchy in HIERARCHIES:
159
+ for template in QUERY_TEMPLATES:
160
+ query = template.format(color=target_color, hierarchy=hierarchy)
161
+ emb = get_text_embedding(model, processor, device, query)
162
+
163
+ # ── Strategy 1: naive 512D ──
164
+ pc_naive, _ = classify_nearest(emb, naive_color_labels, COLORS)
165
+ ph_naive, _ = classify_nearest(emb, naive_hier_labels, HIERARCHIES)
166
+
167
+ # ── Strategy 2: ensembled 512D ──
168
+ pc_ens, _ = classify_nearest(emb, ens_color_labels, COLORS)
169
+ ph_ens, _ = classify_nearest(emb, ens_hier_labels, HIERARCHIES)
170
+
171
+ row = {
172
+ "query": query,
173
+ "true_color": target_color,
174
+ "true_hierarchy": hierarchy,
175
+ "color_naive": pc_naive == target_color,
176
+ "hier_naive": ph_naive == hierarchy,
177
+ "color_ens": pc_ens == target_color,
178
+ "hier_ens": ph_ens == hierarchy,
179
+ }
180
+
181
+ if is_gap_clip:
182
+ # ── Naive subspace ──
183
+ c_sub = F.normalize(emb[:color_dim].unsqueeze(0), dim=-1).squeeze(0)
184
+ h_sub = F.normalize(emb[hier_start:hier_end].unsqueeze(0), dim=-1).squeeze(0)
185
+
186
+ pc_sub, _ = classify_nearest(c_sub, naive_color_sub, COLORS)
187
+ ph_sub, _ = classify_nearest(h_sub, naive_hier_sub, HIERARCHIES)
188
+
189
+ # ── Ensembled subspace ──
190
+ ph_ens_sub, _ = classify_nearest(h_sub, ens_hier_sub, HIERARCHIES)
191
+
192
+ # ── Strategy 3: color-marginalized subspace ──
193
+ ph_marg, _ = classify_nearest(h_sub, marg_hier_sub, HIERARCHIES)
194
+
195
+ row.update({
196
+ "color_sub_naive": pc_sub == target_color,
197
+ "hier_sub_naive": ph_sub == hierarchy,
198
+ "hier_sub_ens": ph_ens_sub == hierarchy,
199
+ "hier_sub_marg": ph_marg == hierarchy,
200
+ })
201
+
202
+ rows.append(row)
203
+
204
+ # Aggregate
205
+ n = len(rows)
206
+ summary = {
207
+ "model": model_name,
208
+ "target_color": target_color,
209
+ "n": n,
210
+ "color_naive": sum(r["color_naive"] for r in rows) / n,
211
+ "hier_naive": sum(r["hier_naive"] for r in rows) / n,
212
+ "color_ens": sum(r["color_ens"] for r in rows) / n,
213
+ "hier_ens": sum(r["hier_ens"] for r in rows) / n,
214
+ }
215
+ if is_gap_clip:
216
+ summary.update({
217
+ "color_sub_naive": sum(r["color_sub_naive"] for r in rows) / n,
218
+ "hier_sub_naive": sum(r["hier_sub_naive"] for r in rows) / n,
219
+ "hier_sub_ens": sum(r["hier_sub_ens"] for r in rows) / n,
220
+ "hier_sub_marg": sum(r["hier_sub_marg"] for r in rows) / n,
221
+ })
222
+ return {"summary": summary, "rows": rows}
223
+
224
+
225
+ # ── Pretty printing ──────────────────────────────────────────────────────────
226
+
227
+
228
+ def print_single_color(bl, gc):
229
+ bs, gs = bl["summary"], gc["summary"]
230
+ color = bs["target_color"]
231
+
232
+ print("\n" + "=" * 92)
233
+ print(f" COLOR ACROSS HIERARCHIES β€” target: \"{color}\"")
234
+ print(f" {bs['n']} queries ({len(HIERARCHIES)} hierarchies x {len(QUERY_TEMPLATES)} templates)")
235
+ print("=" * 92)
236
+
237
+ print(f"\n {'Strategy':<40} {'Baseline':<14} {'GAP-CLIP':<14}")
238
+ print(f" {'-' * 68}")
239
+
240
+ def row(label, bk, gk):
241
+ print(f" {label:<40} {bs[bk]:>8.1%}{'':6} {gs[gk]:>8.1%}")
242
+
243
+ row("Color acc β€” naive (512D)", "color_naive", "color_naive")
244
+ row("Color acc β€” ensembled (512D)", "color_ens", "color_ens")
245
+ print(f" {'Color acc β€” subspace (16D)':<40} {'N/A':>8}{'':6} {gs['color_sub_naive']:>8.1%}")
246
+ print()
247
+ row("Hier acc β€” naive (512D)", "hier_naive", "hier_naive")
248
+ row("Hier acc β€” ensembled (512D)", "hier_ens", "hier_ens")
249
+ print(f" {'Hier acc β€” subspace naive (64D)':<40} {'N/A':>8}{'':6} {gs['hier_sub_naive']:>8.1%}")
250
+ print(f" {'Hier acc β€” subspace ensembled (64D)':<40} {'N/A':>8}{'':6} {gs['hier_sub_ens']:>8.1%}")
251
+ print(f" {'Hier acc β€” subspace marginalized (64D)':<40} {'N/A':>8}{'':6} {gs['hier_sub_marg']:>8.1%}")
252
+
253
+ # Per-hierarchy breakdown for the best strategies
254
+ print(f"\n Per-hierarchy (best strategies):")
255
+ print(f" {'Hierarchy':<12} {'BL ens(512)':<14} {'GC ens(512)':<14} {'GC marg(64)':<14}")
256
+ print(f" {'-' * 54}")
257
+ for h in HIERARCHIES:
258
+ bl_rows = [r for r in bl["rows"] if r["true_hierarchy"] == h]
259
+ gc_rows = [r for r in gc["rows"] if r["true_hierarchy"] == h]
260
+ nh = len(bl_rows)
261
+ b = sum(r["hier_ens"] for r in bl_rows) / nh
262
+ g512 = sum(r["hier_ens"] for r in gc_rows) / nh
263
+ g64 = sum(r["hier_sub_marg"] for r in gc_rows) / nh
264
+ print(f" {h:<12} {b:>8.1%}{'':6} {g512:>8.1%}{'':6} {g64:>8.1%}")
265
+
266
+ print("=" * 92)
267
+
268
+
269
+ # ── Graph ────────────────────────────────────────────────────────────────────
270
+
271
+
272
+ def plot_all_colors_graph(all_bl, all_gc):
273
+ """Create a publication-quality comparison chart."""
274
+ FIGURES_DIR.mkdir(exist_ok=True)
275
+
276
+ bl_color_naive = [all_bl[c]["color_naive"] for c in COLORS]
277
+ bl_hier_naive = [all_bl[c]["hier_naive"] for c in COLORS]
278
+ bl_hier_ens = [all_bl[c]["hier_ens"] for c in COLORS]
279
+
280
+ gc_color_naive = [all_gc[c]["color_naive"] for c in COLORS]
281
+ gc_color_sub = [all_gc[c]["color_sub_naive"] for c in COLORS]
282
+ gc_hier_naive = [all_gc[c]["hier_naive"] for c in COLORS]
283
+ gc_hier_ens = [all_gc[c]["hier_ens"] for c in COLORS]
284
+ gc_hier_marg = [all_gc[c]["hier_sub_marg"] for c in COLORS]
285
+
286
+ # Use a clean style
287
+ plt.rcParams.update({
288
+ "font.family": "sans-serif",
289
+ "axes.facecolor": "#FAFAFA",
290
+ "figure.facecolor": "white",
291
+ })
292
+
293
+ fig = plt.figure(figsize=(20, 14))
294
+ gs = fig.add_gridspec(2, 2, hspace=0.42, wspace=0.28,
295
+ height_ratios=[1, 1.1])
296
+
297
+ x = np.arange(len(COLORS))
298
+ color_labels = [c.capitalize() for c in COLORS]
299
+
300
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
301
+ # TOP-LEFT: Color accuracy (zoomed to 85-102%)
302
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
303
+ ax1 = fig.add_subplot(gs[0, 0])
304
+ bar_w = 0.22
305
+ b1 = ax1.bar(x - bar_w, bl_color_naive, bar_w, label="Baseline (512D)",
306
+ color="#5B9BD5", edgecolor="white", linewidth=0.6, zorder=3)
307
+ b2 = ax1.bar(x, gc_color_naive, bar_w, label="GAP-CLIP (512D)",
308
+ color="#ED7D31", edgecolor="white", linewidth=0.6, zorder=3)
309
+ b3 = ax1.bar(x + bar_w, gc_color_sub, bar_w, label="GAP-CLIP 16D subspace",
310
+ color="#70AD47", edgecolor="white", linewidth=0.6, zorder=3)
311
+
312
+ ax1.set_title("A. Color Classification Accuracy", fontsize=14, fontweight="bold",
313
+ loc="left", pad=12)
314
+ ax1.set_xticks(x)
315
+ ax1.set_xticklabels(color_labels, rotation=35, ha="right", fontsize=10)
316
+ ax1.set_ylabel("Accuracy", fontsize=11)
317
+ ax1.set_ylim(0.85, 1.04)
318
+ ax1.yaxis.set_major_formatter(mtick.PercentFormatter(1.0, decimals=0))
319
+ ax1.legend(fontsize=9, framealpha=0.95, loc="lower left")
320
+ ax1.grid(axis="y", alpha=0.25, linestyle="--", zorder=0)
321
+ ax1.spines["top"].set_visible(False)
322
+ ax1.spines["right"].set_visible(False)
323
+
324
+ # Annotate means
325
+ for vals, clr, lbl, yoff in [
326
+ (bl_color_naive, "#5B9BD5", "BL", 0.006),
327
+ (gc_color_sub, "#70AD47", "GC-16D", -0.012),
328
+ ]:
329
+ m = np.mean(vals)
330
+ ax1.axhline(m, color=clr, linestyle=":", alpha=0.5, linewidth=1.0, zorder=1)
331
+ ax1.text(len(COLORS) - 0.3, m + yoff, f"{lbl} mean: {m:.1%}",
332
+ fontsize=8, color=clr, ha="right", fontstyle="italic")
333
+
334
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
335
+ # TOP-RIGHT: Hierarchy accuracy β€” zoomed to 70-102%
336
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
337
+ ax2 = fig.add_subplot(gs[0, 1])
338
+ bar_w = 0.14
339
+ offsets = np.array([-2, -1, 0, 1, 2])
340
+
341
+ bars_cfg = [
342
+ (bl_hier_naive, "Baseline naive (512D)", "#93C4ED"),
343
+ (bl_hier_ens, "Baseline ensembled (512D)", "#2E75B6"),
344
+ (gc_hier_naive, "GAP-CLIP naive (512D)", "#F4B183"),
345
+ (gc_hier_ens, "GAP-CLIP ensembled (512D)", "#C55A11"),
346
+ (gc_hier_marg, "GAP-CLIP structured (64D)", "#70AD47"),
347
+ ]
348
+
349
+ for i, (data, label, color) in enumerate(bars_cfg):
350
+ ax2.bar(x + offsets[i] * bar_w, data, bar_w, label=label, color=color,
351
+ edgecolor="white", linewidth=0.6, zorder=3)
352
+
353
+ ax2.set_title("B. Hierarchy Classification Accuracy", fontsize=14,
354
+ fontweight="bold", loc="left", pad=12)
355
+ ax2.set_xticks(x)
356
+ ax2.set_xticklabels(color_labels, rotation=35, ha="right", fontsize=10)
357
+ ax2.set_ylabel("Accuracy", fontsize=11)
358
+ ax2.set_ylim(0.70, 1.05)
359
+ ax2.yaxis.set_major_formatter(mtick.PercentFormatter(1.0, decimals=0))
360
+ ax2.legend(fontsize=8.5, framealpha=0.95, loc="lower left", ncol=1)
361
+ ax2.grid(axis="y", alpha=0.25, linestyle="--", zorder=0)
362
+ ax2.spines["top"].set_visible(False)
363
+ ax2.spines["right"].set_visible(False)
364
+
365
+ bl_hm = np.mean(bl_hier_ens)
366
+ gc_hm = np.mean(gc_hier_marg)
367
+ ax2.axhline(bl_hm, color="#2E75B6", linestyle="--", alpha=0.6, linewidth=1.2, zorder=1)
368
+ ax2.axhline(gc_hm, color="#70AD47", linestyle="--", alpha=0.6, linewidth=1.2, zorder=1)
369
+ ax2.text(len(COLORS) - 0.3, bl_hm - 0.016, f"BL-ens mean: {bl_hm:.1%}",
370
+ fontsize=8.5, color="#2E75B6", ha="right", fontweight="bold")
371
+ ax2.text(len(COLORS) - 0.3, gc_hm + 0.006, f"GC-struct mean: {gc_hm:.1%}",
372
+ fontsize=8.5, color="#70AD47", ha="right", fontweight="bold")
373
+
374
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
375
+ # BOTTOM: Mean accuracy summary bar chart
376
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
377
+ ax3 = fig.add_subplot(gs[1, :])
378
+
379
+ metrics = [
380
+ ("Color\n(Naive 512D)", np.mean(bl_color_naive), np.mean(gc_color_naive)),
381
+ ("Color\n(16D Subspace)", None, np.mean(gc_color_sub)),
382
+ ("Hierarchy\n(Naive 512D)", np.mean(bl_hier_naive), np.mean(gc_hier_naive)),
383
+ ("Hierarchy\n(Ens. 512D)", np.mean(bl_hier_ens), np.mean(gc_hier_ens)),
384
+ ("Hierarchy\n(Structured 64D)", None, np.mean(gc_hier_marg)),
385
+ ]
386
+
387
+ xm = np.arange(len(metrics))
388
+ bar_w = 0.30
389
+ bl_vals = [m[1] for m in metrics]
390
+ gc_vals = [m[2] for m in metrics]
391
+
392
+ for i, (label, bv, gv) in enumerate(metrics):
393
+ if bv is not None:
394
+ bar_bl = ax3.bar(i - bar_w / 2, bv, bar_w, color="#2E75B6",
395
+ edgecolor="white", linewidth=0.8, zorder=3,
396
+ label="Baseline" if i == 0 else "")
397
+ ax3.text(i - bar_w / 2, bv + 0.008, f"{bv:.1%}", ha="center",
398
+ fontsize=10, fontweight="bold", color="#2E75B6", zorder=4)
399
+ bar_gc = ax3.bar(i + (bar_w / 2 if bv is not None else 0), gv, bar_w,
400
+ color="#70AD47", edgecolor="white", linewidth=0.8, zorder=3,
401
+ label="GAP-CLIP" if i == 0 else "")
402
+ xpos = i + (bar_w / 2 if bv is not None else 0)
403
+ ax3.text(xpos, gv + 0.008, f"{gv:.1%}", ha="center",
404
+ fontsize=10, fontweight="bold", color="#70AD47", zorder=4)
405
+
406
+ # Delta annotation for hierarchy metrics where both exist
407
+ if bv is not None and "Hierarchy" in label:
408
+ delta = gv - bv
409
+ sign = "+" if delta >= 0 else ""
410
+ clr = "#70AD47" if delta > 0 else "#C00000"
411
+ ax3.annotate(
412
+ f"{sign}{delta:.1%}",
413
+ xy=(i + bar_w / 2, gv),
414
+ xytext=(i + bar_w / 2 + 0.25, gv + 0.03),
415
+ fontsize=9, fontweight="bold", color=clr,
416
+ arrowprops=dict(arrowstyle="->", color=clr, lw=1.2),
417
+ zorder=5,
418
+ )
419
+
420
+ ax3.set_title("C. Mean Accuracy Summary (across all 11 colors)",
421
+ fontsize=14, fontweight="bold", loc="left", pad=12)
422
+ ax3.set_xticks(xm)
423
+ ax3.set_xticklabels([m[0] for m in metrics], fontsize=10.5)
424
+ ax3.set_ylabel("Mean Accuracy", fontsize=11)
425
+ ax3.set_ylim(0.75, 1.08)
426
+ ax3.yaxis.set_major_formatter(mtick.PercentFormatter(1.0, decimals=0))
427
+ ax3.legend(fontsize=11, framealpha=0.95, loc="lower left")
428
+ ax3.grid(axis="y", alpha=0.25, linestyle="--", zorder=0)
429
+ ax3.spines["top"].set_visible(False)
430
+ ax3.spines["right"].set_visible(False)
431
+
432
+ # Global title
433
+ fig.suptitle(
434
+ "Color Retrieval Test β€” Baseline (Fashion-CLIP) vs GAP-CLIP\n"
435
+ f"{len(COLORS)} colors x {len(HIERARCHIES)} hierarchies x "
436
+ f"{len(QUERY_TEMPLATES)} templates = {len(COLORS)*len(HIERARCHIES)*len(QUERY_TEMPLATES)} queries per model",
437
+ fontsize=16, fontweight="bold", y=1.01,
438
+ )
439
+
440
+ out_path = FIGURES_DIR / "color_across_hierarchies.png"
441
+ fig.savefig(out_path, dpi=200, bbox_inches="tight", facecolor="white")
442
+ plt.close(fig)
443
+ print(f"\nFigure saved -> {out_path}")
444
+ return out_path
445
+
446
+
447
+ # ── All-colors sweep ─────���───────────────────────────────────────────────────
448
+
449
+
450
+ def run_all_colors(device):
451
+ print("Loading models...")
452
+ bl_model, bl_proc = load_baseline_fashion_clip(device)
453
+ gc_model, gc_proc = load_gap_clip(config.main_model_path, device)
454
+
455
+ all_bl, all_gc = {}, {}
456
+
457
+ for color in COLORS:
458
+ print(f"\n--- Evaluating: {color} ---")
459
+ bl = evaluate_model(bl_model, bl_proc, device, color, "Baseline")
460
+ gc = evaluate_model(
461
+ gc_model, gc_proc, device, color, "GAP-CLIP",
462
+ color_dim=config.color_emb_dim,
463
+ hier_start=config.color_emb_dim,
464
+ hier_end=config.color_emb_dim + config.hierarchy_emb_dim,
465
+ )
466
+ all_bl[color] = bl["summary"]
467
+ all_gc[color] = gc["summary"]
468
+
469
+ # ── Summary table ──
470
+ print("\n" + "=" * 115)
471
+ print(" ALL-COLORS SUMMARY")
472
+ print("=" * 115)
473
+
474
+ print(f"\n {'':12}"
475
+ f"{'--- COLOR ACC ---':^36}"
476
+ f"{'--- HIERARCHY ACC ---':^60}")
477
+ print(f" {'Color':<12}"
478
+ f"{'BL(512)':>10} {'GC(512)':>10} {'GC(16D)':>10} "
479
+ f"{'BL naive':>10} {'BL ens':>10} {'GC naive':>10} {'GC ens':>10} {'GC struct':>10}")
480
+ print(f" {'-' * 105}")
481
+
482
+ totals = {k: 0.0 for k in [
483
+ "bl_cn", "gc_cn", "gc_cs",
484
+ "bl_hn", "bl_he", "gc_hn", "gc_he", "gc_hm",
485
+ ]}
486
+
487
+ for color in COLORS:
488
+ b, g = all_bl[color], all_gc[color]
489
+ totals["bl_cn"] += b["color_naive"]
490
+ totals["gc_cn"] += g["color_naive"]
491
+ totals["gc_cs"] += g["color_sub_naive"]
492
+ totals["bl_hn"] += b["hier_naive"]
493
+ totals["bl_he"] += b["hier_ens"]
494
+ totals["gc_hn"] += g["hier_naive"]
495
+ totals["gc_he"] += g["hier_ens"]
496
+ totals["gc_hm"] += g["hier_sub_marg"]
497
+
498
+ print(
499
+ f" {color:<12}"
500
+ f"{b['color_naive']:>9.1%} {g['color_naive']:>10.1%} {g['color_sub_naive']:>10.1%} "
501
+ f"{b['hier_naive']:>9.1%} {b['hier_ens']:>10.1%} {g['hier_naive']:>10.1%} "
502
+ f"{g['hier_ens']:>10.1%} {g['hier_sub_marg']:>10.1%}"
503
+ )
504
+
505
+ n = len(COLORS)
506
+ print(f" {'-' * 105}")
507
+ print(
508
+ f" {'MEAN':<12}"
509
+ f"{totals['bl_cn']/n:>9.1%} {totals['gc_cn']/n:>10.1%} {totals['gc_cs']/n:>10.1%} "
510
+ f"{totals['bl_hn']/n:>9.1%} {totals['bl_he']/n:>10.1%} {totals['gc_hn']/n:>10.1%} "
511
+ f"{totals['gc_he']/n:>10.1%} {totals['gc_hm']/n:>10.1%}"
512
+ )
513
+ print("=" * 115)
514
+
515
+ # ── Graph ──
516
+ plot_all_colors_graph(all_bl, all_gc)
517
+
518
+
519
+ # ── Main ─────────────────────────────────────────────────────────────────────
520
+
521
+
522
+ def main():
523
+ parser = argparse.ArgumentParser(
524
+ description="Color retrieval accuracy across hierarchies β€” Baseline vs GAP-CLIP"
525
+ )
526
+ parser.add_argument(
527
+ "--color", type=str, default="red",
528
+ help=f"Target color (default: red). Choices: {', '.join(COLORS)}",
529
+ )
530
+ parser.add_argument(
531
+ "--all-colors", action="store_true",
532
+ help="Run for all 11 colors and produce a comparison graph",
533
+ )
534
+ args = parser.parse_args()
535
+
536
+ device = config.device
537
+ print(f"Device: {device}")
538
+
539
+ if args.all_colors:
540
+ run_all_colors(device)
541
+ return
542
+
543
+ target_color = args.color.lower()
544
+ if target_color not in COLORS:
545
+ print(f"Error: '{target_color}' not in {COLORS}")
546
+ sys.exit(1)
547
+
548
+ print("Loading Baseline (Fashion-CLIP)...")
549
+ bl_model, bl_proc = load_baseline_fashion_clip(device)
550
+ print("Loading GAP-CLIP...")
551
+ gc_model, gc_proc = load_gap_clip(config.main_model_path, device)
552
+
553
+ print(f"\nEvaluating \"{target_color}\" across {len(HIERARCHIES)} hierarchies...\n")
554
+
555
+ bl = evaluate_model(bl_model, bl_proc, device, target_color, "Baseline")
556
+ gc = evaluate_model(
557
+ gc_model, gc_proc, device, target_color, "GAP-CLIP",
558
+ color_dim=config.color_emb_dim,
559
+ hier_start=config.color_emb_dim,
560
+ hier_end=config.color_emb_dim + config.hierarchy_emb_dim,
561
+ )
562
+
563
+ print_single_color(bl, gc)
564
+
565
+
566
+ if __name__ == "__main__":
567
+ main()