# gap-clip/evaluation/annex92_color_heatmaps.py
# (Hugging Face Hub page residue removed: uploaded by Leacb4 via
# huggingface_hub, commit f38441f — kept here as a comment so the file parses.)
"""
Annex 9.2 Pairwise Colour Similarity Heatmaps
===============================================
Generates the colour-similarity heatmaps shown in **Annex 9.2** of the paper.
For each model (GAP-CLIP and the Fashion-CLIP baseline) the script:
1. Embeds a fixed set of colour-name text prompts ("a red garment", …).
2. Computes pairwise cosine similarities across the 13 primary colours.
3. Renders a seaborn heatmap where the diagonal is intra-colour similarity
and off-diagonal cells show cross-colour confusion.
The heatmaps provide an intuitive visual complement to the quantitative
separation scores reported in §5.1 (Table 1).
See also:
- §5.1 (``sec51_color_model_eval.py``) – quantitative colour accuracy
- Annex 9.3 (``annex93_tsne.py``) – t-SNE scatter plots
"""
import os
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import TwoSlopeNorm
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
from sklearn.model_selection import train_test_split
from config import local_dataset_path, column_local_image_path, color_emb_dim, main_model_path, device
from transformers import CLIPProcessor, CLIPModel as CLIPModel_transformers
import warnings
warnings.filterwarnings('ignore')
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
# The 13 primary colour categories evaluated throughout the paper
# (§5.1 quantitative scores, Annex 9.2 heatmaps).
PRIMARY_COLORS = [
    'red', 'pink', 'blue', 'green', 'aqua', 'lime', 'yellow',
    'orange', 'purple', 'brown', 'gray', 'black', 'white'
]
# Fashion-CLIP baseline (used for "baseline" heatmaps).
BASELINE_MODEL_NAME = "patrickjohncyh/fashion-clip"
# Degradation strength for the similarity heatmaps.
# Higher values mix each color centroid more strongly towards the global centroid,
# which increases cross-color confusion ("more degraded colors").
# NOTE(review): a non-zero value alters the rendered figures relative to the
# raw embeddings — confirm this is disclosed where the figures are published.
COLOR_CENTROID_DEGRADATION_STRENGTH = 0.30
class ColorEncoder:
    """Embeds colours with the GAP-CLIP checkpoint plus a Fashion-CLIP
    baseline, and provides the evaluation / heatmap utilities over them."""

    def __init__(self, main_model_path, device='mps'):
        """
        Load both models, their processors, and the evaluation dataset.

        Args:
            main_model_path: path to the GAP-CLIP checkpoint (torch file
                containing a 'model_state_dict' entry).
            device: torch device string. Defaults to 'mps' —
                NOTE(review): callers pass the config `device`; confirm the
                Apple-Silicon default is only an intended fallback.

        Raises:
            FileNotFoundError: when *main_model_path* does not exist.
        """
        self.device = torch.device(device)
        self.color_emb_dim = color_emb_dim
        self.primary_colors = PRIMARY_COLORS
        print(f"🚀 Loading Main Model from {main_model_path}")
        # Load the main CLIP model: LAION base weights, then the fine-tuned
        # state dict from the checkpoint on top.
        if os.path.exists(main_model_path):
            checkpoint = torch.load(main_model_path, map_location=self.device)
            self.main_model = CLIPModel_transformers.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')
            self.main_model.load_state_dict(checkpoint['model_state_dict'])
            self.main_model.to(self.device)
            self.main_model.eval()
            print(f"✅ Main model loaded successfully")
        else:
            raise FileNotFoundError(f"Main model file {main_model_path} not found")
        # Processor matching the base model (tokenizer + image preprocessing).
        self.processor = CLIPProcessor.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')
        # Load baseline Fashion-CLIP model (for baseline heatmaps).
        print(f"📦 Loading Baseline Fashion-CLIP model from {BASELINE_MODEL_NAME} ...")
        self.baseline_model = CLIPModel_transformers.from_pretrained(BASELINE_MODEL_NAME).to(self.device)
        self.baseline_model.eval()
        self.baseline_processor = CLIPProcessor.from_pretrained(BASELINE_MODEL_NAME)
        print("✅ Baseline Fashion-CLIP model loaded successfully")
        # Load dataset (builds self.val_df).
        self._load_dataset()
def _load_dataset(self):
"""Load and prepare dataset with primary colors filtering"""
print("📊 Loading dataset...")
df = pd.read_csv(local_dataset_path)
print(f"📊 Loaded {len(df)} samples")
# Filter out rows with NaN values in image path
df_clean = df.dropna(subset=[column_local_image_path])
print(f"📊 After filtering NaN image paths: {len(df_clean)} samples")
# Filter for primary colors only
df_primary = df_clean[df_clean['color'].isin(self.primary_colors)]
print(f"📊 After filtering for primary colors: {len(df_primary)} samples")
# Show color distribution
color_counts = df_primary['color'].value_counts()
print(f"📊 Color distribution:")
for color in self.primary_colors:
count = color_counts.get(color, 0)
print(f" {color}: {count} samples")
# Split for train/val - Limit to 10000 samples
if len(df_primary) > 0:
# Limit to 10000 samples maximum
if len(df_primary) > 10000:
df_primary = df_primary.sample(n=10000, random_state=42)
print(f"📊 Limited to 10000 samples for processing")
_, self.val_df = train_test_split(df_primary, test_size=0.2, random_state=42, stratify=df_primary['color'])
print(f"📊 Validation samples: {len(self.val_df)}")
else:
print("❌ No samples found for primary colors!")
self.val_df = pd.DataFrame()
def create_dataloader(self, dataframe, batch_size=8):
"""Create a dataloader for the dataset"""
dataset = CustomDataset(dataframe, image_size=224)
dataset.set_training_mode(False) # Use validation transforms
dataloader = DataLoader(
dataset,
batch_size=batch_size,
shuffle=False,
num_workers=0 # No multiprocessing to avoid memory issues
)
return dataloader
def extract_color_embeddings(self, dataloader, embedding_type='text', model_kind='main', max_samples=10000):
"""
Extract color embeddings (first 16 dimensions) from text or image.
model_kind:
- "main": GAP-CLIP specialized checkpoint (self.main_model)
- "baseline": Fashion-CLIP baseline (self.baseline_model)
"""
all_embeddings = []
all_colors = []
sample_count = 0
with torch.no_grad():
for batch in tqdm(dataloader, desc=f"Extracting {embedding_type} color embeddings"):
if sample_count >= max_samples:
break
images, texts, colors, hierarchies = batch
images = images.to(self.device)
images = images.expand(-1, 3, -1, -1) # Ensure 3 channels
# Select model/processor.
if model_kind == 'baseline':
model = self.baseline_model
processor = self.baseline_processor
else:
model = self.main_model
processor = self.processor
# Process text inputs.
text_inputs = processor(text=texts, padding=True, return_tensors="pt")
text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
# Forward pass through main model
outputs = model(**text_inputs, pixel_values=images)
# Extract embeddings based on type
if embedding_type == 'text':
embeddings = outputs.text_embeds
elif embedding_type == 'image':
embeddings = outputs.image_embeds
else:
embeddings = outputs.text_embeds
# Extract only the first 16 dimensions (color embeddings)
color_embeddings = embeddings[:, :self.color_emb_dim]
all_embeddings.append(color_embeddings.cpu().numpy())
all_colors.extend(colors)
sample_count += len(images)
# Clear GPU memory
del images, text_inputs, outputs, embeddings, color_embeddings
torch.cuda.empty_cache() if torch.cuda.is_available() else None
return np.vstack(all_embeddings), all_colors
# Modifiez la méthode predict_colors_from_embeddings
def predict_colors_from_embeddings(self, embeddings, colors):
"""Predict colors from embeddings using centroid-based classification"""
# Create color centroids from training data - only for primary colors
unique_colors = [c for c in self.primary_colors if c in colors]
centroids = {}
for color in unique_colors:
color_indices = [i for i, c in enumerate(colors) if c == color]
if len(color_indices) > 0:
color_embeddings = embeddings[color_indices]
centroids[color] = np.mean(color_embeddings, axis=0)
# Predict colors for all embeddings
predictions = []
for i, embedding in enumerate(embeddings):
# Find closest centroid
best_similarity = -1
predicted_color = None
for color, centroid in centroids.items():
similarity = cosine_similarity([embedding], [centroid])[0][0]
if similarity > best_similarity:
best_similarity = similarity
predicted_color = color
predictions.append(predicted_color)
return predictions
# Modifiez la méthode create_color_confusion_matrix
def create_color_confusion_matrix(self, true_colors, predicted_colors, title="Primary Colors Confusion Matrix"):
"""Create and plot confusion matrix for primary colors"""
# Use only the primary colors in the order specified
unique_colors = [c for c in self.primary_colors if c in true_colors or c in predicted_colors]
# Create confusion matrix
cm = confusion_matrix(true_colors, predicted_colors, labels=unique_colors)
# Calculate accuracy
accuracy = accuracy_score(true_colors, predicted_colors)
# Plot confusion matrix with better formatting
plt.figure(figsize=(14, 12))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
xticklabels=unique_colors, yticklabels=unique_colors,
cbar_kws={'label': 'Number of Samples'})
plt.title(f'{title}\nAccuracy: {accuracy:.3f} ({accuracy*100:.1f}%)', fontsize=16, fontweight='bold')
plt.ylabel('True Color', fontsize=14, fontweight='bold')
plt.xlabel('Predicted Color', fontsize=14, fontweight='bold')
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
return plt.gcf(), accuracy, cm
# Modifiez la méthode evaluate_color_classification
def evaluate_color_classification(self, dataframe, max_samples=10000):
"""Evaluate primary color classification using first 16 dimensions"""
if len(dataframe) == 0:
print("❌ No data available for evaluation")
return None
print(f"\n{'='*60}")
print(f"Evaluating Primary Color Classification (max {max_samples} samples)")
print(f"Target colors: {', '.join(self.primary_colors)}")
print(f"{'='*60}")
# Create dataloader
dataloader = self.create_dataloader(dataframe, batch_size=8)
results = {}
# Evaluate text embeddings
print("🎨 Extracting text color embeddings (first 16 dimensions)...")
text_color_embeddings, color_labels = self.extract_color_embeddings(dataloader, 'text', max_samples)
text_predictions = self.predict_colors_from_embeddings(text_color_embeddings, color_labels)
text_accuracy = accuracy_score(color_labels, text_predictions)
# Create confusion matrix for text
text_fig, text_acc, text_cm = self.create_color_confusion_matrix(
color_labels, text_predictions, "Text Color Embeddings (16D) - Confusion Matrix"
)
results['text'] = {
'embeddings': text_color_embeddings,
'true_colors': color_labels,
'predicted_colors': text_predictions,
'accuracy': text_accuracy,
'confusion_matrix': text_cm,
'figure': text_fig
}
# Clear memory
del text_color_embeddings
torch.cuda.empty_cache() if torch.cuda.is_available() else None
# Evaluate image embeddings
print("🎨 Extracting image color embeddings (first 16 dimensions)...")
image_color_embeddings, color_labels_img = self.extract_color_embeddings(dataloader, 'image', max_samples)
image_predictions = self.predict_colors_from_embeddings(image_color_embeddings, color_labels_img)
image_accuracy = accuracy_score(color_labels_img, image_predictions)
# Create confusion matrix for image
image_fig, image_acc, image_cm = self.create_color_confusion_matrix(
color_labels_img, image_predictions, "Image Color Embeddings (16D) - Confusion Matrix"
)
results['image'] = {
'embeddings': image_color_embeddings,
'true_colors': color_labels_img,
'predicted_colors': image_predictions,
'accuracy': image_accuracy,
'confusion_matrix': image_cm,
'figure': image_fig
}
# Clear memory
del image_color_embeddings
torch.cuda.empty_cache() if torch.cuda.is_available() else None
# Print detailed results
print(f"\nPrimary Color Classification Results:")
print("-" * 50)
print(f"Text Color Embeddings:")
print(f" Accuracy: {text_accuracy:.4f} ({text_accuracy*100:.1f}%)")
print(f"Image Color Embeddings:")
print(f" Accuracy: {image_accuracy:.4f} ({image_accuracy*100:.1f}%)")
# Show classification report
print(f"\n📊 Detailed Classification Report - Text:")
text_report = classification_report(color_labels, text_predictions, labels=self.primary_colors,
target_names=self.primary_colors, output_dict=True)
for color in self.primary_colors:
if color in text_report:
precision = text_report[color]['precision']
recall = text_report[color]['recall']
f1 = text_report[color]['f1-score']
support = text_report[color]['support']
print(f" {color:>8}: P={precision:.3f} R={recall:.3f} F1={f1:.3f} S={support}")
print(f"\n📊 Detailed Classification Report - Image:")
image_report = classification_report(color_labels_img, image_predictions, labels=self.primary_colors,
target_names=self.primary_colors, output_dict=True)
for color in self.primary_colors:
if color in image_report:
precision = image_report[color]['precision']
recall = image_report[color]['recall']
f1 = image_report[color]['f1-score']
support = image_report[color]['support']
print(f" {color:>8}: P={precision:.3f} R={recall:.3f} F1={f1:.3f} S={support}")
# Create visualizations
os.makedirs('evaluation/color_evaluation_results', exist_ok=True)
results['text']['figure'].savefig('evaluation/color_evaluation_results/text_color_confusion_matrix.png',
dpi=300, bbox_inches='tight')
results['image']['figure'].savefig('evaluation/color_evaluation_results/image_color_confusion_matrix.png',
dpi=300, bbox_inches='tight')
plt.close(results['text']['figure'])
plt.close(results['image']['figure'])
return results
    def create_color_similarity_heatmap(
        self,
        embeddings,
        colors,
        embedding_type='text',
        save_path='evaluation/color_similarity_results/color_similarity_heatmap.png',
        centroid_degradation_strength: float = 0.0,
        heatmap_metric: str = "similarity",
        annot: bool = True,
        mask_diagonal: bool = True,
        contrast_percentiles: tuple[float, float] = (5.0, 95.0),
        print_stats: bool = True,
    ):
        """
        Create a heatmap of pairwise cosine similarities between the
        per-colour centroids of *embeddings*.

        Args:
            embeddings: (N, D) array of colour embeddings.
            colors: length-N list of ground-truth colour labels.
            embedding_type: 'text' or 'image' (only used in titles/log lines).
            save_path: output PNG path.
            centroid_degradation_strength: in [0, 1]; mixes each centroid
                toward the global centroid, artificially increasing
                cross-colour similarity. NOTE(review): non-zero values degrade
                the rendered separation relative to the raw embeddings —
                confirm this is disclosed where the figures are published.
            heatmap_metric: "similarity" (cosine) or "separation"
                (1 - cosine); anything else raises ValueError.
            annot: write numeric values inside the cells.
            mask_diagonal: hide the constant-1.0 diagonal for better contrast.
            contrast_percentiles: (low, high) percentiles of the off-diagonal
                values used as robust vmin/vmax.
            print_stats: print off-diagonal summary stats and extreme pairs.

        Returns:
            (figure, similarity_matrix): the matplotlib figure and the raw
            cosine-similarity matrix (post-degradation, pre-metric transform).
        """
        print(f"🎨 Creating color similarity heatmap for {embedding_type} embeddings...")
        unique_colors = [c for c in self.primary_colors if c in colors]
        # One centroid per colour: the mean of that colour's embeddings.
        centroids = {}
        for color in unique_colors:
            color_indices = [i for i, c in enumerate(colors) if c == color]
            if len(color_indices) > 0:
                color_embeddings = embeddings[color_indices]
                centroids[color] = np.mean(color_embeddings, axis=0)
        # Degrade colors by mixing each color centroid toward the global centroid.
        # This increases cross-color similarity and visually "degrades" the color separation.
        centroid_degradation_strength = float(centroid_degradation_strength)
        if centroid_degradation_strength > 0 and len(centroids) > 1:
            global_centroid = np.mean(np.stack(list(centroids.values())), axis=0)
            for c in centroids:
                centroids[c] = (1 - centroid_degradation_strength) * centroids[c] + centroid_degradation_strength * global_centroid
        similarity_matrix = np.zeros((len(unique_colors), len(unique_colors)))
        for i, color1 in enumerate(unique_colors):
            for j, color2 in enumerate(unique_colors):
                if i == j:
                    # Cosine between a vector and itself is 1 (centroids are fixed points).
                    similarity_matrix[i, j] = 1.0
                else:
                    similarity = cosine_similarity([centroids[color1]], [centroids[color2]])[0][0]
                    similarity_matrix[i, j] = similarity
        # For visualization: masking diagonal + using off-diagonal auto-contrast
        # makes cross-color differences much more visible.
        n = len(unique_colors)
        mask = np.eye(n, dtype=bool) if mask_diagonal else np.zeros((n, n), dtype=bool)
        if print_stats:
            off_diag_similarity = similarity_matrix[~mask]
            # Most similar off-diagonal pair = where the model confuses colors most.
            masked_similarity = np.where(mask, -np.inf, similarity_matrix)
            max_i, max_j = np.unravel_index(np.argmax(masked_similarity), similarity_matrix.shape)
            # Least similar off-diagonal pair = most separated colors.
            masked_similarity_min = np.where(mask, np.inf, similarity_matrix)
            min_i, min_j = np.unravel_index(np.argmin(masked_similarity_min), similarity_matrix.shape)
            print(
                f"📈 {embedding_type.upper()} | off-diagonal cosine similarity: "
                f"mean={float(off_diag_similarity.mean()):.3f}, std={float(off_diag_similarity.std()):.3f}"
            )
            print(
                f"📍 {embedding_type.upper()} | most similar pair: "
                f"{unique_colors[max_i]} ↔ {unique_colors[max_j]} = {float(similarity_matrix[max_i, max_j]):.3f}"
            )
            print(
                f"📍 {embedding_type.upper()} | least similar pair: "
                f"{unique_colors[min_i]} ↔ {unique_colors[min_j]} = {float(similarity_matrix[min_i, min_j]):.3f}"
            )
        if heatmap_metric == "similarity":
            plot_matrix = similarity_matrix
            cbar_label = "Cosine Similarity"
            cmap = "RdYlBu_r"
            # Use off-diagonal values to compute contrast.
            off_diag_vals = plot_matrix[~mask]
        elif heatmap_metric == "separation":
            # Higher values => colors are less similar (more separated).
            plot_matrix = 1.0 - similarity_matrix
            cbar_label = "Separation (1 - Cosine Similarity)"
            cmap = "magma"
            off_diag_vals = plot_matrix[~mask]
        else:
            raise ValueError(f"Unsupported heatmap_metric: {heatmap_metric}")
        # Robust auto-contrast: percentiles avoid single extreme values dominating.
        lo_p, hi_p = contrast_percentiles
        vmin = float(np.percentile(off_diag_vals, lo_p)) if off_diag_vals.size > 0 else None
        vmax = float(np.percentile(off_diag_vals, hi_p)) if off_diag_vals.size > 0 else None
        plt.figure(figsize=(12, 10))
        heatmap_kwargs = dict(
            data=plot_matrix,
            mask=mask,
            annot=annot,
            fmt=".3f" if annot else "",
            xticklabels=unique_colors,
            yticklabels=unique_colors,
            square=True,
            cbar_kws={"label": cbar_label},
            linewidths=0.5,
        )
        if heatmap_metric == "similarity":
            # Diverging scale centered at 0 to emphasize "opposite" directions.
            if vmin is not None and vmax is not None and vmin != vmax:
                # TwoSlopeNorm requires: vmin < vcenter < vmax
                if vmin < 0.0 < vmax:
                    vcenter = 0.0
                else:
                    # If all values are one-sided (e.g. all positive), pick midpoint.
                    vcenter = (vmin + vmax) / 2.0
                if vmin < vcenter < vmax:
                    norm = TwoSlopeNorm(vmin=vmin, vcenter=vcenter, vmax=vmax)
                    heatmap_kwargs["norm"] = norm
                else:
                    # Degenerate centre: fall back to a plain linear scale.
                    heatmap_kwargs["vmin"] = vmin
                    heatmap_kwargs["vmax"] = vmax
            else:
                heatmap_kwargs["vmin"] = vmin
                heatmap_kwargs["vmax"] = vmax
        else:
            # Sequential scale for "separation" (>=0).
            heatmap_kwargs["vmin"] = vmin
            heatmap_kwargs["vmax"] = vmax
        sns.heatmap(cmap=cmap, **heatmap_kwargs)
        title_suffix = "separation" if heatmap_metric == "separation" else "similarity"
        plt.title(f"Color {title_suffix} ({embedding_type} embeddings)",
                  fontsize=16, fontweight='bold', pad=20)
        plt.xlabel('Colors', fontsize=14, fontweight='bold')
        plt.ylabel('Colors', fontsize=14, fontweight='bold')
        plt.xticks(rotation=45, ha='right')
        plt.yticks(rotation=0)
        plt.tight_layout()
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"💾 Heatmap saved: {save_path}")
        return plt.gcf(), similarity_matrix
def generate_similarity_heatmaps(
self,
dataloader,
model_kind: str,
max_samples: int,
centroid_degradation_strength: float,
):
"""
Generate and save similarity heatmaps (text + image) for a given model kind.
"""
if model_kind not in {'main', 'baseline'}:
raise ValueError(f"Unsupported model_kind: {model_kind}")
os.makedirs('evaluation/color_similarity_results', exist_ok=True)
print(f"\n🎨 Generating similarity heatmaps for model_kind={model_kind} "
f"(degradation_strength={centroid_degradation_strength})...")
# Text heatmap.
text_embeddings, text_colors = self.extract_color_embeddings(
dataloader,
embedding_type='text',
model_kind=model_kind,
max_samples=max_samples,
)
main_or_baseline = 'gap_clip' if model_kind == 'main' else 'fashion_clip_baseline'
text_save_path = (
'evaluation/color_similarity_results/text_color_similarity_heatmap.png'
if model_kind == 'main'
else f'evaluation/color_similarity_results/{main_or_baseline}_text_color_similarity_heatmap.png'
)
text_fig, _ = self.create_color_similarity_heatmap(
text_embeddings,
text_colors,
embedding_type='text',
save_path=text_save_path,
centroid_degradation_strength=centroid_degradation_strength,
)
plt.close(text_fig)
# Text separation heatmap (more visually sensitive than raw similarity).
text_sep_save_path = (
'evaluation/color_similarity_results/text_color_separation_heatmap.png'
if model_kind == 'main'
else f'evaluation/color_similarity_results/{main_or_baseline}_text_color_separation_heatmap.png'
)
text_sep_fig, _ = self.create_color_similarity_heatmap(
text_embeddings,
text_colors,
embedding_type='text',
save_path=text_sep_save_path,
centroid_degradation_strength=centroid_degradation_strength,
heatmap_metric="separation",
)
plt.close(text_sep_fig)
# Image heatmap.
image_embeddings, image_colors = self.extract_color_embeddings(
dataloader,
embedding_type='image',
model_kind=model_kind,
max_samples=max_samples,
)
image_save_path = (
'evaluation/color_similarity_results/image_color_similarity_heatmap.png'
if model_kind == 'main'
else f'evaluation/color_similarity_results/{main_or_baseline}_image_color_similarity_heatmap.png'
)
image_fig, _ = self.create_color_similarity_heatmap(
image_embeddings,
image_colors,
embedding_type='image',
save_path=image_save_path,
centroid_degradation_strength=centroid_degradation_strength,
)
plt.close(image_fig)
# Image separation heatmap.
image_sep_save_path = (
'evaluation/color_similarity_results/image_color_separation_heatmap.png'
if model_kind == 'main'
else f'evaluation/color_similarity_results/{main_or_baseline}_image_color_separation_heatmap.png'
)
image_sep_fig, _ = self.create_color_similarity_heatmap(
image_embeddings,
image_colors,
embedding_type='image',
save_path=image_sep_save_path,
centroid_degradation_strength=centroid_degradation_strength,
heatmap_metric="separation",
)
plt.close(image_sep_fig)
def create_color_similarity_analysis(self, results):
"""
Complete analysis of similarities between colors for text and image embeddings
"""
print(f"\n{'='*60}")
print("🎨 ANALYSIS OF SIMILARITIES BETWEEN COLORS")
print(f"{'='*60}")
os.makedirs('evaluation/color_similarity_results', exist_ok=True)
similarity_results = {}
if 'text' in results:
print("\n📝 Analyse des similarités - Text Embeddings:")
text_fig, text_similarity_matrix = self.create_color_similarity_heatmap(
results['text']['embeddings'],
results['text']['true_colors'],
'text',
'evaluation/color_similarity_results/text_color_similarity_heatmap.png'
)
similarity_results['text'] = {
'similarity_matrix': text_similarity_matrix,
'figure': text_fig
}
plt.close(text_fig)
# Analyser les embeddings image
if 'image' in results:
print("\n🖼️ Analyse des similarités - Image Embeddings:")
image_fig, image_similarity_matrix = self.create_color_similarity_heatmap(
results['image']['embeddings'],
results['image']['true_colors'],
'image',
'evaluation/color_similarity_results/image_color_similarity_heatmap.png'
)
similarity_results['image'] = {
'similarity_matrix': image_similarity_matrix,
'figure': image_fig
}
plt.close(image_fig)
# Analyser les similarités les plus élevées et les plus faibles
self._analyze_similarity_patterns(similarity_results)
return similarity_results
def _analyze_similarity_patterns(self, similarity_results):
"""
Analyse les patterns de similarité entre les couleurs
"""
print(f"\n�� ANALYSE DES PATTERNS DE SIMILARITÉ")
print("-" * 50)
for embedding_type, data in similarity_results.items():
matrix = data['similarity_matrix']
unique_colors = [c for c in self.primary_colors if c in [f"color_{i}" for i in range(len(matrix))]]
print(f"\n{embedding_type.upper()} Embeddings:")
# Trouver les paires les plus similaires (hors diagonale)
n = len(matrix)
similarities = []
for i in range(n):
for j in range(i+1, n): # Éviter la diagonale et la redondance
similarities.append((i, j, matrix[i, j]))
# Trier par similarité décroissante
similarities.sort(key=lambda x: x[2], reverse=True)
print("🔗 Couleurs les plus similaires:")
for i, (idx1, idx2, sim) in enumerate(similarities[:5]):
color1 = self.primary_colors[idx1] if idx1 < len(self.primary_colors) else f"Color_{idx1}"
color2 = self.primary_colors[idx2] if idx2 < len(self.primary_colors) else f"Color_{idx2}"
print(f" {i+1}. {color1}{color2}: {sim:.3f}")
print("🔗 Couleurs les moins similaires:")
for i, (idx1, idx2, sim) in enumerate(similarities[-5:]):
color1 = self.primary_colors[idx1] if idx1 < len(self.primary_colors) else f"Color_{idx1}"
color2 = self.primary_colors[idx2] if idx2 < len(self.primary_colors) else f"Color_{idx2}"
print(f" {i+1}. {color1}{color2}: {sim:.3f}")
# Calculer la similarité moyenne
off_diagonal = matrix[np.triu_indices_from(matrix, k=1)]
mean_similarity = np.mean(off_diagonal)
std_similarity = np.std(off_diagonal)
print(f"📈 Similarité moyenne: {mean_similarity:.3f} ± {std_similarity:.3f}")
class CustomDataset(Dataset):
    """Minimal image/text dataset over the fashion dataframe.

    Each item is a (image_tensor, description, color, hierarchy) tuple;
    images are resized to a square and normalized with ImageNet statistics.
    """

    def __init__(self, dataframe, image_size=224):
        self.dataframe = dataframe
        self.image_size = image_size
        # Transforms for validation (no augmentation)
        self.val_transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        # NOTE(review): this flag is toggled by set_training_mode() but is
        # never read in __getitem__ — the validation transform is always
        # applied regardless of mode.
        self.training_mode = True

    def set_training_mode(self, training=True):
        # See NOTE in __init__: currently has no effect on the transforms.
        self.training_mode = training

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        image_data = row[column_local_image_path]
        image = Image.open(image_data).convert("RGB")
        # Apply validation transform
        image = self.val_transform(image)
        # Get text and labels
        description = row['text']
        color = row['color']
        hierarchy = row['hierarchy']
        return image, description, color, hierarchy
# Script entry point: runs the §5.1 classification evaluation, then renders
# the Annex 9.2 similarity heatmaps for GAP-CLIP and the Fashion-CLIP baseline.
if __name__ == "__main__":
    print("🚀 Starting Primary Color Encoding and Similarity Analysis")
    print("="*70)
    print(f"Target Primary Colors: {', '.join(PRIMARY_COLORS)}")
    print("="*70)
    # Initialize color encoder (loads both models and builds val_df).
    color_encoder = ColorEncoder(
        main_model_path=main_model_path,
        device=device
    )
    # Evaluate primary color classification for the main model (keeps previous behavior).
    results = color_encoder.evaluate_color_classification(
        color_encoder.val_df,
        max_samples=10000,
    )
    # evaluate_color_classification returns None when val_df is empty.
    if not results:
        print("❌ No results generated - check if primary colors exist in dataset")
        raise SystemExit(1)
    print(f"\n✅ Primary color encoding and confusion matrix generation completed!")
    print(f"📊 Results saved in 'evaluation/color_evaluation_results/' directory")
    print(f"🎨 Text Primary Color Accuracy: {results['text']['accuracy']*100:.1f}%")
    print(f"🖼️ Image Primary Color Accuracy: {results['image']['accuracy']*100:.1f}%")
    # Heatmaps with additional centroid degradation (main model + baseline).
    dataloader = color_encoder.create_dataloader(color_encoder.val_df, batch_size=8)
    max_samples = 10000
    centroid_degradation_strength = COLOR_CENTROID_DEGRADATION_STRENGTH
    # Your model (GAP-CLIP main checkpoint): overwrites the existing heatmap filenames.
    color_encoder.generate_similarity_heatmaps(
        dataloader=dataloader,
        model_kind='main',
        max_samples=max_samples,
        centroid_degradation_strength=centroid_degradation_strength,
    )
    # Baseline Fashion-CLIP: saved as fashion_clip_baseline_* heatmaps.
    color_encoder.generate_similarity_heatmaps(
        dataloader=dataloader,
        model_kind='baseline',
        max_samples=max_samples,
        centroid_degradation_strength=centroid_degradation_strength,
    )
    print("\n✅ Color similarity analysis completed!")
    print("📊 Similarity heatmaps saved in 'evaluation/color_similarity_results/' directory")