Upload evaluation/annex92_color_heatmaps.py with huggingface_hub
Browse files- evaluation/annex92_color_heatmaps.py +269 -51
evaluation/annex92_color_heatmaps.py
CHANGED
|
@@ -24,6 +24,7 @@ import pandas as pd
|
|
| 24 |
import numpy as np
|
| 25 |
import matplotlib.pyplot as plt
|
| 26 |
import seaborn as sns
|
|
|
|
| 27 |
from sklearn.metrics.pairwise import cosine_similarity
|
| 28 |
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
|
| 29 |
from sklearn.model_selection import train_test_split
|
|
@@ -42,6 +43,14 @@ PRIMARY_COLORS = [
|
|
| 42 |
'orange', 'purple', 'brown', 'gray', 'black', 'white'
|
| 43 |
]
|
| 44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
class ColorEncoder:
|
| 46 |
def __init__(self, main_model_path, device='mps'):
|
| 47 |
self.device = torch.device(device)
|
|
@@ -63,6 +72,13 @@ class ColorEncoder:
|
|
| 63 |
|
| 64 |
# Create processor
|
| 65 |
self.processor = CLIPProcessor.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
# Load dataset
|
| 68 |
self._load_dataset()
|
|
@@ -115,8 +131,14 @@ class ColorEncoder:
|
|
| 115 |
|
| 116 |
return dataloader
|
| 117 |
|
| 118 |
-
def extract_color_embeddings(self, dataloader, embedding_type='text', max_samples=10000):
|
| 119 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
all_embeddings = []
|
| 121 |
all_colors = []
|
| 122 |
|
|
@@ -131,12 +153,20 @@ class ColorEncoder:
|
|
| 131 |
images = images.to(self.device)
|
| 132 |
images = images.expand(-1, 3, -1, -1) # Ensure 3 channels
|
| 133 |
|
| 134 |
-
#
|
| 135 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
|
| 137 |
|
| 138 |
# Forward pass through main model
|
| 139 |
-
outputs =
|
| 140 |
|
| 141 |
# Extract embeddings based on type
|
| 142 |
if embedding_type == 'text':
|
|
@@ -324,7 +354,19 @@ class ColorEncoder:
|
|
| 324 |
|
| 325 |
return results
|
| 326 |
|
| 327 |
-
def create_color_similarity_heatmap(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 328 |
"""
|
| 329 |
Create a heatmap of similarities between encoded colors
|
| 330 |
"""
|
|
@@ -338,34 +380,114 @@ class ColorEncoder:
|
|
| 338 |
if len(color_indices) > 0:
|
| 339 |
color_embeddings = embeddings[color_indices]
|
| 340 |
centroids[color] = np.mean(color_embeddings, axis=0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 341 |
|
| 342 |
similarity_matrix = np.zeros((len(unique_colors), len(unique_colors)))
|
| 343 |
|
| 344 |
for i, color1 in enumerate(unique_colors):
|
| 345 |
for j, color2 in enumerate(unique_colors):
|
| 346 |
if i == j:
|
|
|
|
| 347 |
similarity_matrix[i, j] = 1.0
|
| 348 |
else:
|
| 349 |
similarity = cosine_similarity([centroids[color1]], [centroids[color2]])[0][0]
|
| 350 |
similarity_matrix[i, j] = similarity
|
| 351 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 352 |
plt.figure(figsize=(12, 10))
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
xticklabels=unique_colors,
|
| 360 |
yticklabels=unique_colors,
|
| 361 |
square=True,
|
| 362 |
-
cbar_kws={
|
| 363 |
linewidths=0.5,
|
| 364 |
-
vmin=-0.6,
|
| 365 |
-
vmax=1.0
|
| 366 |
)
|
| 367 |
-
|
| 368 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 369 |
fontsize=16, fontweight='bold', pad=20)
|
| 370 |
plt.xlabel('Colors', fontsize=14, fontweight='bold')
|
| 371 |
plt.ylabel('Colors', fontsize=14, fontweight='bold')
|
|
@@ -377,6 +499,99 @@ class ColorEncoder:
|
|
| 377 |
print(f"💾 Heatmap saved: {save_path}")
|
| 378 |
|
| 379 |
return plt.gcf(), similarity_matrix
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 380 |
|
| 381 |
|
| 382 |
|
|
@@ -518,38 +733,41 @@ if __name__ == "__main__":
|
|
| 518 |
device=device
|
| 519 |
)
|
| 520 |
|
| 521 |
-
# Evaluate primary color classification
|
| 522 |
results = color_encoder.evaluate_color_classification(
|
| 523 |
-
color_encoder.val_df,
|
| 524 |
-
max_samples=10000
|
| 525 |
)
|
| 526 |
-
|
| 527 |
-
if results:
|
| 528 |
-
print(
|
| 529 |
-
|
| 530 |
-
|
| 531 |
-
|
| 532 |
-
|
| 533 |
-
|
| 534 |
-
|
| 535 |
-
|
| 536 |
-
|
| 537 |
-
|
| 538 |
-
|
| 539 |
-
|
| 540 |
-
|
| 541 |
-
|
| 542 |
-
|
| 543 |
-
|
| 544 |
-
|
| 545 |
-
|
| 546 |
-
|
| 547 |
-
|
| 548 |
-
|
| 549 |
-
|
| 550 |
-
|
| 551 |
-
|
| 552 |
-
|
| 553 |
-
|
| 554 |
-
|
| 555 |
-
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
import numpy as np
|
| 25 |
import matplotlib.pyplot as plt
|
| 26 |
import seaborn as sns
|
| 27 |
+
from matplotlib.colors import TwoSlopeNorm
|
| 28 |
from sklearn.metrics.pairwise import cosine_similarity
|
| 29 |
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
|
| 30 |
from sklearn.model_selection import train_test_split
|
|
|
|
| 43 |
'orange', 'purple', 'brown', 'gray', 'black', 'white'
|
| 44 |
]
|
| 45 |
|
| 46 |
+
# Fashion-CLIP baseline (used for "baseline" heatmaps).
|
| 47 |
+
BASELINE_MODEL_NAME = "patrickjohncyh/fashion-clip"
|
| 48 |
+
|
| 49 |
+
# Degradation strength for the similarity heatmaps.
|
| 50 |
+
# Higher values mix each color centroid more strongly towards the global centroid,
|
| 51 |
+
# which increases cross-color confusion ("more degraded colors").
|
| 52 |
+
COLOR_CENTROID_DEGRADATION_STRENGTH = 0.30
|
| 53 |
+
|
| 54 |
class ColorEncoder:
|
| 55 |
def __init__(self, main_model_path, device='mps'):
|
| 56 |
self.device = torch.device(device)
|
|
|
|
| 72 |
|
| 73 |
# Create processor
|
| 74 |
self.processor = CLIPProcessor.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')
|
| 75 |
+
|
| 76 |
+
# Load baseline Fashion-CLIP model (for baseline heatmaps).
|
| 77 |
+
print(f"📦 Loading Baseline Fashion-CLIP model from {BASELINE_MODEL_NAME} ...")
|
| 78 |
+
self.baseline_model = CLIPModel_transformers.from_pretrained(BASELINE_MODEL_NAME).to(self.device)
|
| 79 |
+
self.baseline_model.eval()
|
| 80 |
+
self.baseline_processor = CLIPProcessor.from_pretrained(BASELINE_MODEL_NAME)
|
| 81 |
+
print("✅ Baseline Fashion-CLIP model loaded successfully")
|
| 82 |
|
| 83 |
# Load dataset
|
| 84 |
self._load_dataset()
|
|
|
|
| 131 |
|
| 132 |
return dataloader
|
| 133 |
|
| 134 |
+
def extract_color_embeddings(self, dataloader, embedding_type='text', model_kind='main', max_samples=10000):
|
| 135 |
+
"""
|
| 136 |
+
Extract color embeddings (first 16 dimensions) from text or image.
|
| 137 |
+
|
| 138 |
+
model_kind:
|
| 139 |
+
- "main": GAP-CLIP specialized checkpoint (self.main_model)
|
| 140 |
+
- "baseline": Fashion-CLIP baseline (self.baseline_model)
|
| 141 |
+
"""
|
| 142 |
all_embeddings = []
|
| 143 |
all_colors = []
|
| 144 |
|
|
|
|
| 153 |
images = images.to(self.device)
|
| 154 |
images = images.expand(-1, 3, -1, -1) # Ensure 3 channels
|
| 155 |
|
| 156 |
+
# Select model/processor.
|
| 157 |
+
if model_kind == 'baseline':
|
| 158 |
+
model = self.baseline_model
|
| 159 |
+
processor = self.baseline_processor
|
| 160 |
+
else:
|
| 161 |
+
model = self.main_model
|
| 162 |
+
processor = self.processor
|
| 163 |
+
|
| 164 |
+
# Process text inputs.
|
| 165 |
+
text_inputs = processor(text=texts, padding=True, return_tensors="pt")
|
| 166 |
text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
|
| 167 |
|
| 168 |
# Forward pass through main model
|
| 169 |
+
outputs = model(**text_inputs, pixel_values=images)
|
| 170 |
|
| 171 |
# Extract embeddings based on type
|
| 172 |
if embedding_type == 'text':
|
|
|
|
| 354 |
|
| 355 |
return results
|
| 356 |
|
| 357 |
+
def create_color_similarity_heatmap(
|
| 358 |
+
self,
|
| 359 |
+
embeddings,
|
| 360 |
+
colors,
|
| 361 |
+
embedding_type='text',
|
| 362 |
+
save_path='evaluation/color_similarity_results/color_similarity_heatmap.png',
|
| 363 |
+
centroid_degradation_strength: float = 0.0,
|
| 364 |
+
heatmap_metric: str = "similarity",
|
| 365 |
+
annot: bool = True,
|
| 366 |
+
mask_diagonal: bool = True,
|
| 367 |
+
contrast_percentiles: tuple[float, float] = (5.0, 95.0),
|
| 368 |
+
print_stats: bool = True,
|
| 369 |
+
):
|
| 370 |
"""
|
| 371 |
Create a heatmap of similarities between encoded colors
|
| 372 |
"""
|
|
|
|
| 380 |
if len(color_indices) > 0:
|
| 381 |
color_embeddings = embeddings[color_indices]
|
| 382 |
centroids[color] = np.mean(color_embeddings, axis=0)
|
| 383 |
+
|
| 384 |
+
# Degrade colors by mixing each color centroid toward the global centroid.
|
| 385 |
+
# This increases cross-color similarity and visually "degrades" the color separation.
|
| 386 |
+
centroid_degradation_strength = float(centroid_degradation_strength)
|
| 387 |
+
if centroid_degradation_strength > 0 and len(centroids) > 1:
|
| 388 |
+
global_centroid = np.mean(np.stack(list(centroids.values())), axis=0)
|
| 389 |
+
for c in centroids:
|
| 390 |
+
centroids[c] = (1 - centroid_degradation_strength) * centroids[c] + centroid_degradation_strength * global_centroid
|
| 391 |
|
| 392 |
similarity_matrix = np.zeros((len(unique_colors), len(unique_colors)))
|
| 393 |
|
| 394 |
for i, color1 in enumerate(unique_colors):
|
| 395 |
for j, color2 in enumerate(unique_colors):
|
| 396 |
if i == j:
|
| 397 |
+
# Cosine between a vector and itself is 1 (centroids are fixed points).
|
| 398 |
similarity_matrix[i, j] = 1.0
|
| 399 |
else:
|
| 400 |
similarity = cosine_similarity([centroids[color1]], [centroids[color2]])[0][0]
|
| 401 |
similarity_matrix[i, j] = similarity
|
| 402 |
+
|
| 403 |
+
# For visualization: masking diagonal + using off-diagonal auto-contrast
|
| 404 |
+
# makes cross-color differences much more visible.
|
| 405 |
+
n = len(unique_colors)
|
| 406 |
+
mask = np.eye(n, dtype=bool) if mask_diagonal else np.zeros((n, n), dtype=bool)
|
| 407 |
+
|
| 408 |
+
if print_stats:
|
| 409 |
+
off_diag_similarity = similarity_matrix[~mask]
|
| 410 |
+
# Most similar off-diagonal pair = where the model confuses colors most.
|
| 411 |
+
masked_similarity = np.where(mask, -np.inf, similarity_matrix)
|
| 412 |
+
max_i, max_j = np.unravel_index(np.argmax(masked_similarity), similarity_matrix.shape)
|
| 413 |
+
# Least similar off-diagonal pair = most separated colors.
|
| 414 |
+
masked_similarity_min = np.where(mask, np.inf, similarity_matrix)
|
| 415 |
+
min_i, min_j = np.unravel_index(np.argmin(masked_similarity_min), similarity_matrix.shape)
|
| 416 |
+
print(
|
| 417 |
+
f"📈 {embedding_type.upper()} | off-diagonal cosine similarity: "
|
| 418 |
+
f"mean={float(off_diag_similarity.mean()):.3f}, std={float(off_diag_similarity.std()):.3f}"
|
| 419 |
+
)
|
| 420 |
+
print(
|
| 421 |
+
f"📍 {embedding_type.upper()} | most similar pair: "
|
| 422 |
+
f"{unique_colors[max_i]} ↔ {unique_colors[max_j]} = {float(similarity_matrix[max_i, max_j]):.3f}"
|
| 423 |
+
)
|
| 424 |
+
print(
|
| 425 |
+
f"📍 {embedding_type.upper()} | least similar pair: "
|
| 426 |
+
f"{unique_colors[min_i]} ↔ {unique_colors[min_j]} = {float(similarity_matrix[min_i, min_j]):.3f}"
|
| 427 |
+
)
|
| 428 |
+
|
| 429 |
+
if heatmap_metric == "similarity":
|
| 430 |
+
plot_matrix = similarity_matrix
|
| 431 |
+
cbar_label = "Cosine Similarity"
|
| 432 |
+
cmap = "RdYlBu_r"
|
| 433 |
+
# Use off-diagonal values to compute contrast.
|
| 434 |
+
off_diag_vals = plot_matrix[~mask]
|
| 435 |
+
elif heatmap_metric == "separation":
|
| 436 |
+
# Higher values => colors are less similar (more separated).
|
| 437 |
+
plot_matrix = 1.0 - similarity_matrix
|
| 438 |
+
cbar_label = "Separation (1 - Cosine Similarity)"
|
| 439 |
+
cmap = "magma"
|
| 440 |
+
off_diag_vals = plot_matrix[~mask]
|
| 441 |
+
else:
|
| 442 |
+
raise ValueError(f"Unsupported heatmap_metric: {heatmap_metric}")
|
| 443 |
+
|
| 444 |
+
# Robust auto-contrast: percentiles avoid single extreme values dominating.
|
| 445 |
+
lo_p, hi_p = contrast_percentiles
|
| 446 |
+
vmin = float(np.percentile(off_diag_vals, lo_p)) if off_diag_vals.size > 0 else None
|
| 447 |
+
vmax = float(np.percentile(off_diag_vals, hi_p)) if off_diag_vals.size > 0 else None
|
| 448 |
+
|
| 449 |
plt.figure(figsize=(12, 10))
|
| 450 |
+
|
| 451 |
+
heatmap_kwargs = dict(
|
| 452 |
+
data=plot_matrix,
|
| 453 |
+
mask=mask,
|
| 454 |
+
annot=annot,
|
| 455 |
+
fmt=".3f" if annot else "",
|
| 456 |
xticklabels=unique_colors,
|
| 457 |
yticklabels=unique_colors,
|
| 458 |
square=True,
|
| 459 |
+
cbar_kws={"label": cbar_label},
|
| 460 |
linewidths=0.5,
|
|
|
|
|
|
|
| 461 |
)
|
| 462 |
+
|
| 463 |
+
if heatmap_metric == "similarity":
|
| 464 |
+
# Diverging scale centered at 0 to emphasize "opposite" directions.
|
| 465 |
+
if vmin is not None and vmax is not None and vmin != vmax:
|
| 466 |
+
# TwoSlopeNorm requires: vmin < vcenter < vmax
|
| 467 |
+
if vmin < 0.0 < vmax:
|
| 468 |
+
vcenter = 0.0
|
| 469 |
+
else:
|
| 470 |
+
# If all values are one-sided (e.g. all positive), pick midpoint.
|
| 471 |
+
vcenter = (vmin + vmax) / 2.0
|
| 472 |
+
|
| 473 |
+
if vmin < vcenter < vmax:
|
| 474 |
+
norm = TwoSlopeNorm(vmin=vmin, vcenter=vcenter, vmax=vmax)
|
| 475 |
+
heatmap_kwargs["norm"] = norm
|
| 476 |
+
else:
|
| 477 |
+
heatmap_kwargs["vmin"] = vmin
|
| 478 |
+
heatmap_kwargs["vmax"] = vmax
|
| 479 |
+
else:
|
| 480 |
+
heatmap_kwargs["vmin"] = vmin
|
| 481 |
+
heatmap_kwargs["vmax"] = vmax
|
| 482 |
+
else:
|
| 483 |
+
# Sequential scale for "separation" (>=0).
|
| 484 |
+
heatmap_kwargs["vmin"] = vmin
|
| 485 |
+
heatmap_kwargs["vmax"] = vmax
|
| 486 |
+
|
| 487 |
+
sns.heatmap(cmap=cmap, **heatmap_kwargs)
|
| 488 |
+
|
| 489 |
+
title_suffix = "separation" if heatmap_metric == "separation" else "similarity"
|
| 490 |
+
plt.title(f"Color {title_suffix} ({embedding_type} embeddings)",
|
| 491 |
fontsize=16, fontweight='bold', pad=20)
|
| 492 |
plt.xlabel('Colors', fontsize=14, fontweight='bold')
|
| 493 |
plt.ylabel('Colors', fontsize=14, fontweight='bold')
|
|
|
|
| 499 |
print(f"💾 Heatmap saved: {save_path}")
|
| 500 |
|
| 501 |
return plt.gcf(), similarity_matrix
|
| 502 |
+
|
| 503 |
+
def generate_similarity_heatmaps(
    self,
    dataloader,
    model_kind: str,
    max_samples: int,
    centroid_degradation_strength: float,
    output_dir: str = 'evaluation/color_similarity_results',
):
    """
    Generate similarity and separation heatmaps (text + image) for one model kind.

    For each embedding type ('text' first, then 'image') the embeddings are
    extracted once with ``self.extract_color_embeddings`` and then rendered
    twice with ``self.create_color_similarity_heatmap``: first with
    ``heatmap_metric="similarity"``, then with ``heatmap_metric="separation"``.
    Every returned figure is closed right after it has been produced.

    Args:
        dataloader: Passed unchanged to ``self.extract_color_embeddings``.
        model_kind: Either 'main' or 'baseline'. For 'main' the file names are
            ``<type>_color_<metric>_heatmap.png``; for 'baseline' the same
            names are prefixed with ``fashion_clip_baseline_``.
        max_samples: Forwarded to ``self.extract_color_embeddings``.
        centroid_degradation_strength: Forwarded to
            ``self.create_color_similarity_heatmap``.
        output_dir: Directory used to build every ``save_path``; it is created
            if it does not exist. The default is the location that was
            previously hard-coded in this method.

    Raises:
        ValueError: If ``model_kind`` is not 'main' or 'baseline'. The check
            runs before the output directory is created.
    """
    if model_kind not in {'main', 'baseline'}:
        raise ValueError(f"Unsupported model_kind: {model_kind}")

    os.makedirs(output_dir, exist_ok=True)

    print(f"\n🎨 Generating similarity heatmaps for model_kind={model_kind} "
          f"(degradation_strength={centroid_degradation_strength})...")

    # The main model keeps the plain file names; the baseline gets a prefix so
    # that both sets of heatmaps can live in the same directory.
    file_prefix = '' if model_kind == 'main' else 'fashion_clip_baseline_'

    for embedding_type in ('text', 'image'):
        # Extract once per embedding type and reuse for both metrics.
        embeddings, colors = self.extract_color_embeddings(
            dataloader,
            embedding_type=embedding_type,
            model_kind=model_kind,
            max_samples=max_samples,
        )

        for metric in ('similarity', 'separation'):
            save_path = os.path.join(
                output_dir,
                f'{file_prefix}{embedding_type}_color_{metric}_heatmap.png',
            )
            fig, _ = self.create_color_similarity_heatmap(
                embeddings,
                colors,
                embedding_type=embedding_type,
                save_path=save_path,
                centroid_degradation_strength=centroid_degradation_strength,
                heatmap_metric=metric,
            )
            # Close each figure so repeated calls do not accumulate open figures.
            plt.close(fig)
|
| 595 |
|
| 596 |
|
| 597 |
|
|
|
|
| 733 |
device=device
|
| 734 |
)
|
| 735 |
|
| 736 |
+
# Evaluate primary color classification for the main model (keeps previous behavior).
|
| 737 |
results = color_encoder.evaluate_color_classification(
|
| 738 |
+
color_encoder.val_df,
|
| 739 |
+
max_samples=10000,
|
| 740 |
)
|
| 741 |
+
|
| 742 |
+
if not results:
|
| 743 |
+
print("❌ No results generated - check if primary colors exist in dataset")
|
| 744 |
+
raise SystemExit(1)
|
| 745 |
+
|
| 746 |
+
print(f"\n✅ Primary color encoding and confusion matrix generation completed!")
|
| 747 |
+
print(f"📊 Results saved in 'evaluation/color_evaluation_results/' directory")
|
| 748 |
+
print(f"🎨 Text Primary Color Accuracy: {results['text']['accuracy']*100:.1f}%")
|
| 749 |
+
print(f"🖼️ Image Primary Color Accuracy: {results['image']['accuracy']*100:.1f}%")
|
| 750 |
+
|
| 751 |
+
# Heatmaps with additional centroid degradation (main model + baseline).
|
| 752 |
+
dataloader = color_encoder.create_dataloader(color_encoder.val_df, batch_size=8)
|
| 753 |
+
max_samples = 10000
|
| 754 |
+
centroid_degradation_strength = COLOR_CENTROID_DEGRADATION_STRENGTH
|
| 755 |
+
|
| 756 |
+
# Your model (GAP-CLIP main checkpoint): overwrites the existing heatmap filenames.
|
| 757 |
+
color_encoder.generate_similarity_heatmaps(
|
| 758 |
+
dataloader=dataloader,
|
| 759 |
+
model_kind='main',
|
| 760 |
+
max_samples=max_samples,
|
| 761 |
+
centroid_degradation_strength=centroid_degradation_strength,
|
| 762 |
+
)
|
| 763 |
+
|
| 764 |
+
# Baseline Fashion-CLIP: saved as fashion_clip_baseline_* heatmaps.
|
| 765 |
+
color_encoder.generate_similarity_heatmaps(
|
| 766 |
+
dataloader=dataloader,
|
| 767 |
+
model_kind='baseline',
|
| 768 |
+
max_samples=max_samples,
|
| 769 |
+
centroid_degradation_strength=centroid_degradation_strength,
|
| 770 |
+
)
|
| 771 |
+
|
| 772 |
+
print("\n✅ Color similarity analysis completed!")
|
| 773 |
+
print("📊 Similarity heatmaps saved in 'evaluation/color_similarity_results/' directory")
|