| """ |
| Annex 9.2 Pairwise Colour Similarity Heatmaps |
| =============================================== |
| |
| Generates the colour-similarity heatmaps shown in **Annex 9.2** of the paper. |
| |
| For each model (GAP-CLIP and the Fashion-CLIP baseline) the script: |
| |
1. Extracts the colour sub-embeddings (leading dimensions) of validation
   samples — text captions and images — and averages them into one centroid
   per colour.
2. Computes pairwise cosine similarities between the 13 colour centroids.
3. Renders a seaborn heatmap where the diagonal is intra-colour similarity
   and off-diagonal cells show cross-colour confusion.
| |
| The heatmaps provide an intuitive visual complement to the quantitative |
| separation scores reported in §5.1 (Table 1). |
| |
| See also: |
| - §5.1 (``sec51_color_model_eval.py``) – quantitative colour accuracy |
| - Annex 9.3 (``annex93_tsne.py``) – t-SNE scatter plots |
| """ |
| import os |
| import torch |
| import pandas as pd |
| import numpy as np |
| import matplotlib.pyplot as plt |
| import seaborn as sns |
| from matplotlib.colors import TwoSlopeNorm |
| from sklearn.metrics.pairwise import cosine_similarity |
| from sklearn.metrics import confusion_matrix, classification_report, accuracy_score |
| from sklearn.model_selection import train_test_split |
| from config import local_dataset_path, column_local_image_path, color_emb_dim, main_model_path, device |
| from transformers import CLIPProcessor, CLIPModel as CLIPModel_transformers |
| import warnings |
| warnings.filterwarnings('ignore') |
| from torch.utils.data import Dataset, DataLoader |
| from torchvision import transforms |
| from PIL import Image |
| from tqdm import tqdm |
|
|
|
|
# The 13 primary colour names evaluated throughout this script; values must
# match the dataset's 'color' column exactly.
PRIMARY_COLORS = [
    'red', 'pink', 'blue', 'green', 'aqua', 'lime', 'yellow',
    'orange', 'purple', 'brown', 'gray', 'black', 'white'
]


# Hugging Face model id of the off-the-shelf Fashion-CLIP baseline.
BASELINE_MODEL_NAME = "patrickjohncyh/fashion-clip"


# Blend factor pulling every colour centroid toward the global mean centroid
# before similarities are computed (0.0 = no blending).
# NOTE(review): a non-zero value uniformly inflates off-diagonal similarity in
# the rendered heatmaps (applied to both models) — confirm this "degradation"
# is intentional for presentation.
COLOR_CENTROID_DEGRADATION_STRENGTH = 0.30
|
|
class ColorEncoder:
    """Colour-embedding evaluation harness.

    Loads the fine-tuned GAP-CLIP checkpoint and the Fashion-CLIP baseline,
    prepares a stratified validation split of the dataset, and provides
    utilities to extract colour sub-embeddings, classify colours by nearest
    centroid, and render similarity/confusion heatmaps.
    """

    def __init__(self, main_model_path, device='mps'):
        """Load both models and the dataset.

        Args:
            main_model_path: path to the GAP-CLIP checkpoint — a torch file
                expected to contain a 'model_state_dict' entry.
            device: torch device string (defaults to Apple 'mps').

        Raises:
            FileNotFoundError: if ``main_model_path`` does not exist.
        """
        self.device = torch.device(device)
        # Number of leading embedding dimensions reserved for colour (from config).
        self.color_emb_dim = color_emb_dim
        self.primary_colors = PRIMARY_COLORS

        print(f"🚀 Loading Main Model from {main_model_path}")

        # The checkpoint stores only a state dict, so the architecture is first
        # instantiated from the original pretrained LAION CLIP weights.
        if os.path.exists(main_model_path):
            checkpoint = torch.load(main_model_path, map_location=self.device)
            self.main_model = CLIPModel_transformers.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')
            self.main_model.load_state_dict(checkpoint['model_state_dict'])
            self.main_model.to(self.device)
            self.main_model.eval()
            print(f"✅ Main model loaded successfully")
        else:
            raise FileNotFoundError(f"Main model file {main_model_path} not found")

        # Processor (tokenizer + image preprocessing) matching the main model's backbone.
        self.processor = CLIPProcessor.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')

        # Off-the-shelf Fashion-CLIP used as the comparison baseline.
        print(f"📦 Loading Baseline Fashion-CLIP model from {BASELINE_MODEL_NAME} ...")
        self.baseline_model = CLIPModel_transformers.from_pretrained(BASELINE_MODEL_NAME).to(self.device)
        self.baseline_model.eval()
        self.baseline_processor = CLIPProcessor.from_pretrained(BASELINE_MODEL_NAME)
        print("✅ Baseline Fashion-CLIP model loaded successfully")

        # Populates self.val_df (may be an empty DataFrame if no primary-colour rows exist).
        self._load_dataset()
| |
| def _load_dataset(self): |
| """Load and prepare dataset with primary colors filtering""" |
| print("📊 Loading dataset...") |
| df = pd.read_csv(local_dataset_path) |
| print(f"📊 Loaded {len(df)} samples") |
| |
| |
| df_clean = df.dropna(subset=[column_local_image_path]) |
| print(f"📊 After filtering NaN image paths: {len(df_clean)} samples") |
| |
| |
| df_primary = df_clean[df_clean['color'].isin(self.primary_colors)] |
| print(f"📊 After filtering for primary colors: {len(df_primary)} samples") |
| |
| |
| color_counts = df_primary['color'].value_counts() |
| print(f"📊 Color distribution:") |
| for color in self.primary_colors: |
| count = color_counts.get(color, 0) |
| print(f" {color}: {count} samples") |
| |
| |
| if len(df_primary) > 0: |
| |
| if len(df_primary) > 10000: |
| df_primary = df_primary.sample(n=10000, random_state=42) |
| print(f"📊 Limited to 10000 samples for processing") |
| |
| _, self.val_df = train_test_split(df_primary, test_size=0.2, random_state=42, stratify=df_primary['color']) |
| print(f"📊 Validation samples: {len(self.val_df)}") |
| else: |
| print("❌ No samples found for primary colors!") |
| self.val_df = pd.DataFrame() |
|
|
| def create_dataloader(self, dataframe, batch_size=8): |
| """Create a dataloader for the dataset""" |
| dataset = CustomDataset(dataframe, image_size=224) |
| dataset.set_training_mode(False) |
| |
| dataloader = DataLoader( |
| dataset, |
| batch_size=batch_size, |
| shuffle=False, |
| num_workers=0 |
| ) |
| |
| return dataloader |
| |
    def extract_color_embeddings(self, dataloader, embedding_type='text', model_kind='main', max_samples=10000):
        """
        Extract the colour sub-embedding (first ``color_emb_dim`` dimensions)
        of text or image embeddings for up to ``max_samples`` samples.

        Args:
            dataloader: yields (images, texts, colors, hierarchies) batches.
            embedding_type: 'text' or 'image'; any other value falls back to text.
            model_kind:
                - "main": GAP-CLIP specialized checkpoint (self.main_model)
                - "baseline": Fashion-CLIP baseline (self.baseline_model)
            max_samples: soft cap — checked at batch granularity, so the result
                may overshoot by up to one batch.

        Returns:
            (embeddings, colors): stacked numpy array of shape
            (n_samples, color_emb_dim) and the matching list of colour labels.
        """
        all_embeddings = []
        all_colors = []

        sample_count = 0

        with torch.no_grad():
            for batch in tqdm(dataloader, desc=f"Extracting {embedding_type} color embeddings"):
                if sample_count >= max_samples:
                    break

                # hierarchies is unused here; kept for tuple unpacking.
                images, texts, colors, hierarchies = batch
                images = images.to(self.device)
                # Broadcast to 3 channels; presumably a no-op for the RGB
                # tensors this script's dataset produces — TODO confirm.
                images = images.expand(-1, 3, -1, -1)

                # Select model/processor pair for the requested variant.
                if model_kind == 'baseline':
                    model = self.baseline_model
                    processor = self.baseline_processor
                else:
                    model = self.main_model
                    processor = self.processor

                text_inputs = processor(text=texts, padding=True, return_tensors="pt")
                text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}

                # Single forward pass produces both text and image embeddings.
                outputs = model(**text_inputs, pixel_values=images)

                if embedding_type == 'text':
                    embeddings = outputs.text_embeds
                elif embedding_type == 'image':
                    embeddings = outputs.image_embeds
                else:
                    embeddings = outputs.text_embeds

                # Keep only the leading colour-dedicated dimensions.
                color_embeddings = embeddings[:, :self.color_emb_dim]

                all_embeddings.append(color_embeddings.cpu().numpy())
                all_colors.extend(colors)

                sample_count += len(images)

                # Aggressively release GPU memory between batches.
                del images, text_inputs, outputs, embeddings, color_embeddings
                torch.cuda.empty_cache() if torch.cuda.is_available() else None

        return np.vstack(all_embeddings), all_colors
|
|
| |
| def predict_colors_from_embeddings(self, embeddings, colors): |
| """Predict colors from embeddings using centroid-based classification""" |
| |
| unique_colors = [c for c in self.primary_colors if c in colors] |
| centroids = {} |
| |
| for color in unique_colors: |
| color_indices = [i for i, c in enumerate(colors) if c == color] |
| if len(color_indices) > 0: |
| color_embeddings = embeddings[color_indices] |
| centroids[color] = np.mean(color_embeddings, axis=0) |
| |
| |
| predictions = [] |
| |
| for i, embedding in enumerate(embeddings): |
| |
| best_similarity = -1 |
| predicted_color = None |
| |
| for color, centroid in centroids.items(): |
| similarity = cosine_similarity([embedding], [centroid])[0][0] |
| if similarity > best_similarity: |
| best_similarity = similarity |
| predicted_color = color |
| |
| predictions.append(predicted_color) |
| |
| return predictions |
|
|
| |
| def create_color_confusion_matrix(self, true_colors, predicted_colors, title="Primary Colors Confusion Matrix"): |
| """Create and plot confusion matrix for primary colors""" |
| |
| unique_colors = [c for c in self.primary_colors if c in true_colors or c in predicted_colors] |
| |
| |
| cm = confusion_matrix(true_colors, predicted_colors, labels=unique_colors) |
| |
| |
| accuracy = accuracy_score(true_colors, predicted_colors) |
| |
| |
| plt.figure(figsize=(14, 12)) |
| sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', |
| xticklabels=unique_colors, yticklabels=unique_colors, |
| cbar_kws={'label': 'Number of Samples'}) |
| plt.title(f'{title}\nAccuracy: {accuracy:.3f} ({accuracy*100:.1f}%)', fontsize=16, fontweight='bold') |
| plt.ylabel('True Color', fontsize=14, fontweight='bold') |
| plt.xlabel('Predicted Color', fontsize=14, fontweight='bold') |
| plt.xticks(rotation=45, ha='right') |
| plt.yticks(rotation=0) |
| plt.tight_layout() |
| |
| return plt.gcf(), accuracy, cm |
|
|
| |
| def evaluate_color_classification(self, dataframe, max_samples=10000): |
| """Evaluate primary color classification using first 16 dimensions""" |
| if len(dataframe) == 0: |
| print("❌ No data available for evaluation") |
| return None |
| |
| print(f"\n{'='*60}") |
| print(f"Evaluating Primary Color Classification (max {max_samples} samples)") |
| print(f"Target colors: {', '.join(self.primary_colors)}") |
| print(f"{'='*60}") |
| |
| |
| dataloader = self.create_dataloader(dataframe, batch_size=8) |
| |
| results = {} |
| |
| |
| print("🎨 Extracting text color embeddings (first 16 dimensions)...") |
| text_color_embeddings, color_labels = self.extract_color_embeddings(dataloader, 'text', max_samples) |
| text_predictions = self.predict_colors_from_embeddings(text_color_embeddings, color_labels) |
| text_accuracy = accuracy_score(color_labels, text_predictions) |
| |
| |
| text_fig, text_acc, text_cm = self.create_color_confusion_matrix( |
| color_labels, text_predictions, "Text Color Embeddings (16D) - Confusion Matrix" |
| ) |
| |
| results['text'] = { |
| 'embeddings': text_color_embeddings, |
| 'true_colors': color_labels, |
| 'predicted_colors': text_predictions, |
| 'accuracy': text_accuracy, |
| 'confusion_matrix': text_cm, |
| 'figure': text_fig |
| } |
| |
| |
| del text_color_embeddings |
| torch.cuda.empty_cache() if torch.cuda.is_available() else None |
| |
| |
| print("🎨 Extracting image color embeddings (first 16 dimensions)...") |
| image_color_embeddings, color_labels_img = self.extract_color_embeddings(dataloader, 'image', max_samples) |
| image_predictions = self.predict_colors_from_embeddings(image_color_embeddings, color_labels_img) |
| image_accuracy = accuracy_score(color_labels_img, image_predictions) |
| |
| |
| image_fig, image_acc, image_cm = self.create_color_confusion_matrix( |
| color_labels_img, image_predictions, "Image Color Embeddings (16D) - Confusion Matrix" |
| ) |
| |
| results['image'] = { |
| 'embeddings': image_color_embeddings, |
| 'true_colors': color_labels_img, |
| 'predicted_colors': image_predictions, |
| 'accuracy': image_accuracy, |
| 'confusion_matrix': image_cm, |
| 'figure': image_fig |
| } |
| |
| |
| del image_color_embeddings |
| torch.cuda.empty_cache() if torch.cuda.is_available() else None |
| |
| |
| print(f"\nPrimary Color Classification Results:") |
| print("-" * 50) |
| print(f"Text Color Embeddings:") |
| print(f" Accuracy: {text_accuracy:.4f} ({text_accuracy*100:.1f}%)") |
| print(f"Image Color Embeddings:") |
| print(f" Accuracy: {image_accuracy:.4f} ({image_accuracy*100:.1f}%)") |
| |
| |
| print(f"\n📊 Detailed Classification Report - Text:") |
| text_report = classification_report(color_labels, text_predictions, labels=self.primary_colors, |
| target_names=self.primary_colors, output_dict=True) |
| for color in self.primary_colors: |
| if color in text_report: |
| precision = text_report[color]['precision'] |
| recall = text_report[color]['recall'] |
| f1 = text_report[color]['f1-score'] |
| support = text_report[color]['support'] |
| print(f" {color:>8}: P={precision:.3f} R={recall:.3f} F1={f1:.3f} S={support}") |
| |
| print(f"\n📊 Detailed Classification Report - Image:") |
| image_report = classification_report(color_labels_img, image_predictions, labels=self.primary_colors, |
| target_names=self.primary_colors, output_dict=True) |
| for color in self.primary_colors: |
| if color in image_report: |
| precision = image_report[color]['precision'] |
| recall = image_report[color]['recall'] |
| f1 = image_report[color]['f1-score'] |
| support = image_report[color]['support'] |
| print(f" {color:>8}: P={precision:.3f} R={recall:.3f} F1={f1:.3f} S={support}") |
| |
| |
| os.makedirs('evaluation/color_evaluation_results', exist_ok=True) |
| results['text']['figure'].savefig('evaluation/color_evaluation_results/text_color_confusion_matrix.png', |
| dpi=300, bbox_inches='tight') |
| results['image']['figure'].savefig('evaluation/color_evaluation_results/image_color_confusion_matrix.png', |
| dpi=300, bbox_inches='tight') |
| plt.close(results['text']['figure']) |
| plt.close(results['image']['figure']) |
| |
| return results |
|
|
    def create_color_similarity_heatmap(
        self,
        embeddings,
        colors,
        embedding_type='text',
        save_path='evaluation/color_similarity_results/color_similarity_heatmap.png',
        centroid_degradation_strength: float = 0.0,
        heatmap_metric: str = "similarity",
        annot: bool = True,
        mask_diagonal: bool = True,
        contrast_percentiles: tuple[float, float] = (5.0, 95.0),
        print_stats: bool = True,
    ):
        """
        Create a heatmap of pairwise cosine similarities between per-colour
        centroid embeddings and save it to ``save_path``.

        Args:
            embeddings: array (n_samples, dim) of colour sub-embeddings.
            colors: list of n_samples colour labels aligned with ``embeddings``.
            embedding_type: label used in the title/prints ('text'/'image').
            save_path: output PNG path.
            centroid_degradation_strength: in [0, 1]; blends every centroid
                toward the global mean centroid before computing similarities.
            heatmap_metric: "similarity" (cosine) or "separation" (1 - cosine).
            annot: write cell values into the heatmap.
            mask_diagonal: hide diagonal cells (always 1.0 by construction).
            contrast_percentiles: (low, high) percentiles of the off-diagonal
                values used as colour-scale limits for contrast stretching.
            print_stats: print off-diagonal mean/std and extreme pairs.

        Returns:
            (figure, similarity_matrix) — the matrix is always the raw cosine
            similarity, even when the plotted metric is "separation".
        """
        print(f"🎨 Creating color similarity heatmap for {embedding_type} embeddings...")

        # One centroid per colour actually present, in primary-colour order.
        unique_colors = [c for c in self.primary_colors if c in colors]
        centroids = {}

        for color in unique_colors:
            color_indices = [i for i, c in enumerate(colors) if c == color]
            if len(color_indices) > 0:
                color_embeddings = embeddings[color_indices]
                centroids[color] = np.mean(color_embeddings, axis=0)

        # NOTE(review): this blending pulls all centroids toward the global
        # mean, uniformly raising off-diagonal similarity — confirm this is an
        # intentional presentation choice rather than a measurement artifact.
        centroid_degradation_strength = float(centroid_degradation_strength)
        if centroid_degradation_strength > 0 and len(centroids) > 1:
            global_centroid = np.mean(np.stack(list(centroids.values())), axis=0)
            for c in centroids:
                centroids[c] = (1 - centroid_degradation_strength) * centroids[c] + centroid_degradation_strength * global_centroid

        similarity_matrix = np.zeros((len(unique_colors), len(unique_colors)))

        for i, color1 in enumerate(unique_colors):
            for j, color2 in enumerate(unique_colors):
                if i == j:
                    # Self-similarity of a centroid is 1.0 by definition.
                    similarity_matrix[i, j] = 1.0
                else:
                    similarity = cosine_similarity([centroids[color1]], [centroids[color2]])[0][0]
                    similarity_matrix[i, j] = similarity

        # Boolean mask of cells hidden in the plot (and excluded from stats).
        n = len(unique_colors)
        mask = np.eye(n, dtype=bool) if mask_diagonal else np.zeros((n, n), dtype=bool)

        if print_stats:
            off_diag_similarity = similarity_matrix[~mask]

            # +/- inf sentinels exclude masked cells from argmax/argmin.
            masked_similarity = np.where(mask, -np.inf, similarity_matrix)
            max_i, max_j = np.unravel_index(np.argmax(masked_similarity), similarity_matrix.shape)

            masked_similarity_min = np.where(mask, np.inf, similarity_matrix)
            min_i, min_j = np.unravel_index(np.argmin(masked_similarity_min), similarity_matrix.shape)
            print(
                f"📈 {embedding_type.upper()} | off-diagonal cosine similarity: "
                f"mean={float(off_diag_similarity.mean()):.3f}, std={float(off_diag_similarity.std()):.3f}"
            )
            print(
                f"📍 {embedding_type.upper()} | most similar pair: "
                f"{unique_colors[max_i]} ↔ {unique_colors[max_j]} = {float(similarity_matrix[max_i, max_j]):.3f}"
            )
            print(
                f"📍 {embedding_type.upper()} | least similar pair: "
                f"{unique_colors[min_i]} ↔ {unique_colors[min_j]} = {float(similarity_matrix[min_i, min_j]):.3f}"
            )

        # Select the plotted matrix and styling per metric.
        if heatmap_metric == "similarity":
            plot_matrix = similarity_matrix
            cbar_label = "Cosine Similarity"
            cmap = "RdYlBu_r"

            off_diag_vals = plot_matrix[~mask]
        elif heatmap_metric == "separation":
            # Distance-like view: larger values mean better-separated colours.
            plot_matrix = 1.0 - similarity_matrix
            cbar_label = "Separation (1 - Cosine Similarity)"
            cmap = "magma"
            off_diag_vals = plot_matrix[~mask]
        else:
            raise ValueError(f"Unsupported heatmap_metric: {heatmap_metric}")

        # Contrast-stretch the colour scale to the off-diagonal value range.
        lo_p, hi_p = contrast_percentiles
        vmin = float(np.percentile(off_diag_vals, lo_p)) if off_diag_vals.size > 0 else None
        vmax = float(np.percentile(off_diag_vals, hi_p)) if off_diag_vals.size > 0 else None

        plt.figure(figsize=(12, 10))

        heatmap_kwargs = dict(
            data=plot_matrix,
            mask=mask,
            annot=annot,
            fmt=".3f" if annot else "",
            xticklabels=unique_colors,
            yticklabels=unique_colors,
            square=True,
            cbar_kws={"label": cbar_label},
            linewidths=0.5,
        )

        if heatmap_metric == "similarity":
            # Use a diverging norm when a sensible centre exists; otherwise
            # fall back to plain vmin/vmax limits.
            if vmin is not None and vmax is not None and vmin != vmax:
                # Centre at 0 if the range straddles it, else at the midpoint.
                if vmin < 0.0 < vmax:
                    vcenter = 0.0
                else:
                    vcenter = (vmin + vmax) / 2.0

                # TwoSlopeNorm requires vmin < vcenter < vmax strictly.
                if vmin < vcenter < vmax:
                    norm = TwoSlopeNorm(vmin=vmin, vcenter=vcenter, vmax=vmax)
                    heatmap_kwargs["norm"] = norm
                else:
                    heatmap_kwargs["vmin"] = vmin
                    heatmap_kwargs["vmax"] = vmax
            else:
                heatmap_kwargs["vmin"] = vmin
                heatmap_kwargs["vmax"] = vmax
        else:
            # Sequential colormap for the separation view: no centring needed.
            heatmap_kwargs["vmin"] = vmin
            heatmap_kwargs["vmax"] = vmax

        sns.heatmap(cmap=cmap, **heatmap_kwargs)

        title_suffix = "separation" if heatmap_metric == "separation" else "similarity"
        plt.title(f"Color {title_suffix} ({embedding_type} embeddings)",
                  fontsize=16, fontweight='bold', pad=20)
        plt.xlabel('Colors', fontsize=14, fontweight='bold')
        plt.ylabel('Colors', fontsize=14, fontweight='bold')
        plt.xticks(rotation=45, ha='right')
        plt.yticks(rotation=0)
        plt.tight_layout()

        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"💾 Heatmap saved: {save_path}")

        return plt.gcf(), similarity_matrix
|
|
| def generate_similarity_heatmaps( |
| self, |
| dataloader, |
| model_kind: str, |
| max_samples: int, |
| centroid_degradation_strength: float, |
| ): |
| """ |
| Generate and save similarity heatmaps (text + image) for a given model kind. |
| """ |
| if model_kind not in {'main', 'baseline'}: |
| raise ValueError(f"Unsupported model_kind: {model_kind}") |
|
|
| os.makedirs('evaluation/color_similarity_results', exist_ok=True) |
|
|
| print(f"\n🎨 Generating similarity heatmaps for model_kind={model_kind} " |
| f"(degradation_strength={centroid_degradation_strength})...") |
|
|
| |
| text_embeddings, text_colors = self.extract_color_embeddings( |
| dataloader, |
| embedding_type='text', |
| model_kind=model_kind, |
| max_samples=max_samples, |
| ) |
| main_or_baseline = 'gap_clip' if model_kind == 'main' else 'fashion_clip_baseline' |
| text_save_path = ( |
| 'evaluation/color_similarity_results/text_color_similarity_heatmap.png' |
| if model_kind == 'main' |
| else f'evaluation/color_similarity_results/{main_or_baseline}_text_color_similarity_heatmap.png' |
| ) |
| text_fig, _ = self.create_color_similarity_heatmap( |
| text_embeddings, |
| text_colors, |
| embedding_type='text', |
| save_path=text_save_path, |
| centroid_degradation_strength=centroid_degradation_strength, |
| ) |
| plt.close(text_fig) |
|
|
| |
| text_sep_save_path = ( |
| 'evaluation/color_similarity_results/text_color_separation_heatmap.png' |
| if model_kind == 'main' |
| else f'evaluation/color_similarity_results/{main_or_baseline}_text_color_separation_heatmap.png' |
| ) |
| text_sep_fig, _ = self.create_color_similarity_heatmap( |
| text_embeddings, |
| text_colors, |
| embedding_type='text', |
| save_path=text_sep_save_path, |
| centroid_degradation_strength=centroid_degradation_strength, |
| heatmap_metric="separation", |
| ) |
| plt.close(text_sep_fig) |
|
|
| |
| image_embeddings, image_colors = self.extract_color_embeddings( |
| dataloader, |
| embedding_type='image', |
| model_kind=model_kind, |
| max_samples=max_samples, |
| ) |
| image_save_path = ( |
| 'evaluation/color_similarity_results/image_color_similarity_heatmap.png' |
| if model_kind == 'main' |
| else f'evaluation/color_similarity_results/{main_or_baseline}_image_color_similarity_heatmap.png' |
| ) |
| image_fig, _ = self.create_color_similarity_heatmap( |
| image_embeddings, |
| image_colors, |
| embedding_type='image', |
| save_path=image_save_path, |
| centroid_degradation_strength=centroid_degradation_strength, |
| ) |
| plt.close(image_fig) |
|
|
| |
| image_sep_save_path = ( |
| 'evaluation/color_similarity_results/image_color_separation_heatmap.png' |
| if model_kind == 'main' |
| else f'evaluation/color_similarity_results/{main_or_baseline}_image_color_separation_heatmap.png' |
| ) |
| image_sep_fig, _ = self.create_color_similarity_heatmap( |
| image_embeddings, |
| image_colors, |
| embedding_type='image', |
| save_path=image_sep_save_path, |
| centroid_degradation_strength=centroid_degradation_strength, |
| heatmap_metric="separation", |
| ) |
| plt.close(image_sep_fig) |
| |
| |
|
|
| def create_color_similarity_analysis(self, results): |
| """ |
| Complete analysis of similarities between colors for text and image embeddings |
| """ |
| print(f"\n{'='*60}") |
| print("🎨 ANALYSIS OF SIMILARITIES BETWEEN COLORS") |
| print(f"{'='*60}") |
| |
| os.makedirs('evaluation/color_similarity_results', exist_ok=True) |
| |
| similarity_results = {} |
| |
| if 'text' in results: |
| print("\n📝 Analyse des similarités - Text Embeddings:") |
| text_fig, text_similarity_matrix = self.create_color_similarity_heatmap( |
| results['text']['embeddings'], |
| results['text']['true_colors'], |
| 'text', |
| 'evaluation/color_similarity_results/text_color_similarity_heatmap.png' |
| ) |
| similarity_results['text'] = { |
| 'similarity_matrix': text_similarity_matrix, |
| 'figure': text_fig |
| } |
| plt.close(text_fig) |
| |
| |
| if 'image' in results: |
| print("\n🖼️ Analyse des similarités - Image Embeddings:") |
| image_fig, image_similarity_matrix = self.create_color_similarity_heatmap( |
| results['image']['embeddings'], |
| results['image']['true_colors'], |
| 'image', |
| 'evaluation/color_similarity_results/image_color_similarity_heatmap.png' |
| ) |
| similarity_results['image'] = { |
| 'similarity_matrix': image_similarity_matrix, |
| 'figure': image_fig |
| } |
| plt.close(image_fig) |
| |
| |
| self._analyze_similarity_patterns(similarity_results) |
| |
| return similarity_results |
|
|
| def _analyze_similarity_patterns(self, similarity_results): |
| """ |
| Analyse les patterns de similarité entre les couleurs |
| """ |
| print(f"\n�� ANALYSE DES PATTERNS DE SIMILARITÉ") |
| print("-" * 50) |
| |
| for embedding_type, data in similarity_results.items(): |
| matrix = data['similarity_matrix'] |
| unique_colors = [c for c in self.primary_colors if c in [f"color_{i}" for i in range(len(matrix))]] |
| |
| print(f"\n{embedding_type.upper()} Embeddings:") |
| |
| |
| n = len(matrix) |
| similarities = [] |
| |
| for i in range(n): |
| for j in range(i+1, n): |
| similarities.append((i, j, matrix[i, j])) |
| |
| |
| similarities.sort(key=lambda x: x[2], reverse=True) |
| |
| print("🔗 Couleurs les plus similaires:") |
| for i, (idx1, idx2, sim) in enumerate(similarities[:5]): |
| color1 = self.primary_colors[idx1] if idx1 < len(self.primary_colors) else f"Color_{idx1}" |
| color2 = self.primary_colors[idx2] if idx2 < len(self.primary_colors) else f"Color_{idx2}" |
| print(f" {i+1}. {color1} ↔ {color2}: {sim:.3f}") |
| |
| print("🔗 Couleurs les moins similaires:") |
| for i, (idx1, idx2, sim) in enumerate(similarities[-5:]): |
| color1 = self.primary_colors[idx1] if idx1 < len(self.primary_colors) else f"Color_{idx1}" |
| color2 = self.primary_colors[idx2] if idx2 < len(self.primary_colors) else f"Color_{idx2}" |
| print(f" {i+1}. {color1} ↔ {color2}: {sim:.3f}") |
| |
| |
| off_diagonal = matrix[np.triu_indices_from(matrix, k=1)] |
| mean_similarity = np.mean(off_diagonal) |
| std_similarity = np.std(off_diagonal) |
| |
| print(f"📈 Similarité moyenne: {mean_similarity:.3f} ± {std_similarity:.3f}") |
|
|
class CustomDataset(Dataset):
    """Dataset yielding ``(image_tensor, description, color, hierarchy)``
    tuples from a dataframe of local image paths and metadata."""

    def __init__(self, dataframe, image_size=224):
        self.dataframe = dataframe
        self.image_size = image_size

        # Deterministic evaluation transform (resize + ImageNet normalization).
        self.val_transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        self.training_mode = True

    def set_training_mode(self, training=True):
        # Flag kept for interface compatibility; in this evaluation script the
        # same transform is applied regardless of mode.
        self.training_mode = training

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        record = self.dataframe.iloc[idx]

        pil_image = Image.open(record[column_local_image_path]).convert("RGB")
        tensor = self.val_transform(pil_image)

        return tensor, record['text'], record['color'], record['hierarchy']
|
|
| |
if __name__ == "__main__":
    print("🚀 Starting Primary Color Encoding and Similarity Analysis")
    print("="*70)
    print(f"Target Primary Colors: {', '.join(PRIMARY_COLORS)}")
    print("="*70)

    # Loads both models (GAP-CLIP checkpoint + Fashion-CLIP baseline) and the
    # validation split; raises if the checkpoint file is missing.
    color_encoder = ColorEncoder(
        main_model_path=main_model_path,
        device=device
    )

    # Stage 1: nearest-centroid colour classification + confusion matrices.
    results = color_encoder.evaluate_color_classification(
        color_encoder.val_df,
        max_samples=10000,
    )

    if not results:
        print("❌ No results generated - check if primary colors exist in dataset")
        raise SystemExit(1)

    print(f"\n✅ Primary color encoding and confusion matrix generation completed!")
    print(f"📊 Results saved in 'evaluation/color_evaluation_results/' directory")
    print(f"🎨 Text Primary Color Accuracy: {results['text']['accuracy']*100:.1f}%")
    print(f"🖼️ Image Primary Color Accuracy: {results['image']['accuracy']*100:.1f}%")

    # Stage 2: pairwise colour-similarity heatmaps (Annex 9.2), for the main
    # model and the baseline, using the same validation dataloader.
    dataloader = color_encoder.create_dataloader(color_encoder.val_df, batch_size=8)
    max_samples = 10000
    centroid_degradation_strength = COLOR_CENTROID_DEGRADATION_STRENGTH

    color_encoder.generate_similarity_heatmaps(
        dataloader=dataloader,
        model_kind='main',
        max_samples=max_samples,
        centroid_degradation_strength=centroid_degradation_strength,
    )

    color_encoder.generate_similarity_heatmaps(
        dataloader=dataloader,
        model_kind='baseline',
        max_samples=max_samples,
        centroid_degradation_strength=centroid_degradation_strength,
    )

    print("\n✅ Color similarity analysis completed!")
    print("📊 Similarity heatmaps saved in 'evaluation/color_similarity_results/' directory")