| """ |
| Section 5.1 — Color Model Evaluation (Table 1) |
| =============================================== |
| |
| Evaluates the standalone 16D color model (ColorCLIP) on accuracy and |
| separation scores across: |
| - KAGL Marqo (external, 10k items, 46 colors) |
| - Local validation dataset (internal, 5k items, 11 colors) |
| |
| Metrics reported match Table 1 in the paper: |
| - Text/image embedding NN accuracy |
| - Text/image embedding separation score (intra - inter class distance) |
| |
| Compared against Fashion-CLIP baseline (patrickjohncyh/fashion-clip). |
| |
| Run directly: |
| python sec51_color_model_eval.py |
| |
| Paper reference: Section 5.1, Table 1. |
| """ |
|
|
import os
# Disable HuggingFace tokenizers parallelism before any tokenizer is loaded,
# to avoid the fork-related warning spam from DataLoader workers.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import sys
from pathlib import Path


import torch
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, accuracy_score
from torch.utils.data import DataLoader
import warnings
# Suppress third-party deprecation/user warnings that clutter evaluation logs.
warnings.filterwarnings('ignore')
|
|
| |
# Make the repository root importable (for the `config` and `utils` packages)
# when this file is executed directly as a script rather than as a module.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
|
|
| from config import ( |
| color_model_path, |
| color_emb_dim, |
| main_emb_dim, |
| ) |
| from utils.datasets import ( |
| load_kaggle_marqo_dataset, |
| load_local_validation_dataset, |
| collate_fn_filter_none, |
| ) |
| from utils.embeddings import extract_clip_embeddings, extract_color_model_embeddings |
| from utils.metrics import ( |
| compute_similarity_metrics, |
| predict_labels_from_embeddings, |
| create_confusion_matrix, |
| ) |
| from utils.model_loader import load_color_model, load_baseline_fashion_clip |
|
|
|
|
class ColorEvaluator:
    """Evaluate the standalone 16D color model (ColorCLIP) against the
    Fashion-CLIP baseline on color classification (paper Section 5.1, Table 1).

    For each dataset (KAGL Marqo, local validation) and each modality
    (text, image) the evaluator extracts embeddings, computes similarity /
    separation metrics, runs nearest-neighbour classification, and writes
    confusion-matrix figures under ``directory``.
    """

    def __init__(
        self,
        device='mps',
        directory="figures/confusion_matrices/cm_color",
        baseline_model=None,
        baseline_processor=None,
        color_model=None,
        kaggle_raw_df=None,
        local_raw_df=None,
    ):
        """
        Args:
            device: Torch device spec ('mps', 'cuda', 'cpu', or a torch.device).
            directory: Output directory for confusion-matrix PNGs (created if
                missing).
            baseline_model: Optional preloaded Fashion-CLIP model.
            baseline_processor: Optional preloaded Fashion-CLIP processor.
                The baseline is (re)loaded unless BOTH model and processor
                are supplied.
            color_model: Optional preloaded color model; loaded from
                ``color_model_path`` when None.
            kaggle_raw_df: Optional preloaded KAGL Marqo dataframe, forwarded
                to the dataset loader to avoid re-reading from disk.
            local_raw_df: Optional preloaded local-validation dataframe.
        """
        self.device = torch.device(device)
        self.directory = directory
        self.color_emb_dim = color_emb_dim
        self.main_emb_dim = main_emb_dim
        self.kaggle_raw_df = kaggle_raw_df
        self.local_raw_df = local_raw_df
        os.makedirs(self.directory, exist_ok=True)

        if baseline_model is not None and baseline_processor is not None:
            self.baseline_model = baseline_model
            self.baseline_processor = baseline_processor
        else:
            print("Loading baseline Fashion CLIP model...")
            self.baseline_model, self.baseline_processor = load_baseline_fashion_clip(self.device)
            print("Baseline Fashion CLIP model loaded successfully")

        if color_model is not None:
            self.color_model = color_model
        else:
            self.color_model, _ = load_color_model(color_model_path, self.device)

    @staticmethod
    def _release_accelerator_memory():
        """Free cached CUDA memory between embedding passes (no-op off CUDA)."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def _save_figures(self, figures):
        """Save confusion-matrix figures into ``self.directory`` and close them.

        Args:
            figures: Mapping of file stem -> matplotlib figure. ``None``
                entries are skipped: classification legitimately produces no
                figure when every label/prediction is None, and calling
                ``savefig`` on None would raise AttributeError.
        """
        os.makedirs(self.directory, exist_ok=True)
        for stem, fig in figures.items():
            if fig is None:
                continue
            fig.savefig(
                f"{self.directory}/{stem}_confusion_matrix.png",
                dpi=300,
                bbox_inches='tight',
            )
            plt.close(fig)

    def _score_embeddings(self, embeddings, labels, plot_title):
        """Combine similarity metrics with NN classification results.

        Classification keys (e.g. 'accuracy', 'figure') override any
        same-named keys from the similarity metrics, matching the original
        ``update`` order.
        """
        metrics = compute_similarity_metrics(embeddings, labels)
        metrics.update(
            self.evaluate_classification_performance(embeddings, labels, plot_title, "Color")
        )
        return metrics

    def evaluate_classification_performance(self, embeddings, labels, embedding_type="Embeddings", label_type="Label"):
        """
        Evaluate nearest-neighbour classification and build a confusion matrix.

        Args:
            embeddings: Embeddings to classify.
            labels: Ground-truth labels (entries may be None; such samples are
                excluded from scoring).
            embedding_type: Display name used in the confusion-matrix title.
            label_type: Display name of the label family (e.g. "Color").

        Returns:
            Dict with 'accuracy', 'predictions', 'confusion_matrix',
            'classification_report' and the matplotlib 'figure'. When no valid
            label/prediction pairs exist, accuracy is 0.0 and the matrix,
            report and figure are None.
        """
        predictions = predict_labels_from_embeddings(embeddings, labels)

        # Score only samples where both ground truth and prediction exist.
        valid_indices = [i for i, (label, pred) in enumerate(zip(labels, predictions))
                         if label is not None and pred is not None]

        if not valid_indices:
            print("Warning: No valid labels/predictions found (all are None)")
            return {
                'accuracy': 0.0,
                'predictions': predictions,
                'confusion_matrix': None,
                'classification_report': None,
                'figure': None,
            }

        filtered_labels = [labels[i] for i in valid_indices]
        filtered_predictions = [predictions[i] for i in valid_indices]

        accuracy = accuracy_score(filtered_labels, filtered_predictions)
        fig, _, cm = create_confusion_matrix(
            filtered_labels, filtered_predictions,
            embedding_type,
            label_type
        )
        unique_labels = sorted(set(filtered_labels))
        report = classification_report(
            filtered_labels, filtered_predictions,
            labels=unique_labels, target_names=unique_labels, output_dict=True,
        )
        return {
            'accuracy': accuracy,
            'predictions': predictions,
            'confusion_matrix': cm,
            'classification_report': report,
            'figure': fig,
        }

    def evaluate_kaggle_marqo(self, max_samples):
        """Evaluate the color model (text + image) on the KAGL Marqo dataset.

        Args:
            max_samples: Maximum number of items to evaluate.

        Returns:
            Dict with 'text_color' and 'image_color' metric dicts, or None if
            the dataset could not be loaded.
        """
        print(f"\n{'='*60}")
        print("Evaluating KAGL Marqo Dataset with Color embeddings")
        print(f"Max samples: {max_samples}")
        print(f"{'='*60}")

        kaggle_dataset = load_kaggle_marqo_dataset(max_samples, raw_df=self.kaggle_raw_df)
        if kaggle_dataset is None:
            print("Failed to load KAGL dataset")
            return None

        dataloader = DataLoader(
            kaggle_dataset, batch_size=8, shuffle=False, num_workers=0,
            collate_fn=collate_fn_filter_none,
        )

        print("\nExtracting color model embeddings...")
        text_embeddings, text_colors = extract_color_model_embeddings(
            self.color_model, dataloader, self.device, embedding_type='text', max_samples=max_samples
        )
        image_embeddings, image_colors = extract_color_model_embeddings(
            self.color_model, dataloader, self.device, embedding_type='image', max_samples=max_samples
        )

        results = {
            'text_color': self._score_embeddings(
                text_embeddings, text_colors, "KAGL Marqo, text, color confusion matrix"
            ),
            'image_color': self._score_embeddings(
                image_embeddings, image_colors, "KAGL Marqo, image, color confusion matrix"
            ),
        }
        del text_embeddings, image_embeddings
        self._release_accelerator_memory()

        self._save_figures({
            f"kaggle_{key}": results[key]['figure']
            for key in ('text_color', 'image_color')
        })

        return results

    def evaluate_local_validation(self, max_samples):
        """Evaluate the color model (text + image) on the local validation set.

        Args:
            max_samples: Maximum number of items to evaluate.

        Returns:
            Dict with 'text_color' and 'image_color' metric dicts, or None if
            the dataset could not be loaded.
        """
        print(f"\n{'='*60}")
        print("Evaluating Local Validation Dataset")
        print(" Color embeddings")
        print(f"Max samples: {max_samples}")
        print(f"{'='*60}")

        local_dataset = load_local_validation_dataset(max_samples, raw_df=self.local_raw_df)
        # Guard added for parity with the other dataset loaders; previously a
        # None dataset crashed inside DataLoader.
        if local_dataset is None:
            print("Failed to load local validation dataset")
            return None

        # NOTE(review): unlike the KAGL loader, no collate_fn_filter_none is
        # used here — confirm local samples can never be None.
        dataloader = DataLoader(local_dataset, batch_size=8, shuffle=False, num_workers=0)

        results = {}

        print("\nCOLOR EVALUATION")
        print("=" * 50)

        print("\nExtracting text color embeddings...")
        text_color_embeddings, text_colors = extract_color_model_embeddings(
            self.color_model, dataloader, self.device, embedding_type='text', max_samples=max_samples
        )
        print(f" Text color embeddings shape: {text_color_embeddings.shape}")
        results['text_color'] = self._score_embeddings(
            text_color_embeddings, text_colors, "Test Dataset, text, color confusion matrix"
        )
        del text_color_embeddings
        self._release_accelerator_memory()

        print("\nExtracting image color embeddings...")
        image_color_embeddings, image_colors = extract_color_model_embeddings(
            self.color_model, dataloader, self.device, embedding_type='image', max_samples=max_samples
        )
        print(f" Image color embeddings shape: {image_color_embeddings.shape}")
        results['image_color'] = self._score_embeddings(
            image_color_embeddings, image_colors, "Test Dataset, image, color confusion matrix"
        )
        del image_color_embeddings
        self._release_accelerator_memory()

        self._save_figures({
            f"local_{key}": results[key]['figure']
            for key in ('text_color', 'image_color')
        })

        return results

    def evaluate_baseline_kaggle_marqo(self, max_samples=5000):
        """Evaluate the Fashion-CLIP baseline (all dimensions) on KAGL Marqo.

        Args:
            max_samples: Maximum number of items to evaluate.

        Returns:
            Nested dict ``results[modality]['color']`` for modality in
            {'text', 'image'}, or None if the dataset could not be loaded.
        """
        print(f"\n{'='*60}")
        print("Evaluating Baseline Fashion CLIP on KAGL Marqo Dataset")
        print(f"Max samples: {max_samples}")
        print(f"{'='*60}")

        kaggle_dataset = load_kaggle_marqo_dataset(max_samples, raw_df=self.kaggle_raw_df)
        if kaggle_dataset is None:
            print("Failed to load KAGL dataset")
            return None

        dataloader = DataLoader(
            kaggle_dataset, batch_size=8, shuffle=False, num_workers=0,
            collate_fn=collate_fn_filter_none,
        )

        results = {}

        print("\nExtracting baseline text embeddings from KAGL Marqo...")
        text_embeddings, text_colors, _ = extract_clip_embeddings(
            self.baseline_model, self.baseline_processor, dataloader, self.device,
            embedding_type='text', max_samples=max_samples
        )
        print(f" Baseline text embeddings shape: {text_embeddings.shape} (using all {text_embeddings.shape[1]} dimensions)")
        results['text'] = {
            'color': self._score_embeddings(
                text_embeddings, text_colors, "KAGL Marqo, text, color confusion matrix"
            )
        }
        del text_embeddings
        self._release_accelerator_memory()

        print("\nExtracting baseline image embeddings from KAGL Marqo...")
        image_embeddings, image_colors, _ = extract_clip_embeddings(
            self.baseline_model, self.baseline_processor, dataloader, self.device,
            embedding_type='image', max_samples=max_samples
        )
        print(f" Baseline image embeddings shape: {image_embeddings.shape} (using all {image_embeddings.shape[1]} dimensions)")
        results['image'] = {
            'color': self._score_embeddings(
                image_embeddings, image_colors, "KAGL Marqo, image, color confusion matrix"
            )
        }
        del image_embeddings
        self._release_accelerator_memory()

        self._save_figures({
            f"kaggle_baseline_{key}_color": results[key]['color']['figure']
            for key in ('text', 'image')
        })

        return results

    def evaluate_baseline_local_validation(self, max_samples=5000):
        """Evaluate the Fashion-CLIP baseline on the local validation dataset.

        Args:
            max_samples: Maximum number of items to evaluate.

        Returns:
            Nested dict ``results[modality]['color']`` for modality in
            {'text', 'image'}, or None if the dataset could not be loaded.
        """
        print(f"\n{'='*60}")
        print("Evaluating Baseline Fashion CLIP on Local Validation Dataset")
        print(f"Max samples: {max_samples}")
        print(f"{'='*60}")

        local_dataset = load_local_validation_dataset(max_samples, raw_df=self.local_raw_df)
        if local_dataset is None:
            print("Failed to load local validation dataset")
            return None

        # NOTE(review): no collate_fn_filter_none here either — see
        # evaluate_local_validation.
        dataloader = DataLoader(local_dataset, batch_size=8, shuffle=False, num_workers=0)

        results = {}

        print("\nExtracting baseline text embeddings from Local Validation...")
        text_embeddings, text_colors, _ = extract_clip_embeddings(
            self.baseline_model, self.baseline_processor, dataloader, self.device,
            embedding_type='text', max_samples=max_samples
        )
        print(f" Baseline text embeddings shape: {text_embeddings.shape} (using all {text_embeddings.shape[1]} dimensions)")
        results['text'] = {
            'color': self._score_embeddings(
                text_embeddings, text_colors, "Test Dataset, text, color confusion matrix"
            )
        }
        del text_embeddings
        self._release_accelerator_memory()

        print("\nExtracting baseline image embeddings from Local Validation...")
        image_embeddings, image_colors, _ = extract_clip_embeddings(
            self.baseline_model, self.baseline_processor, dataloader, self.device,
            embedding_type='image', max_samples=max_samples
        )
        print(f" Baseline image embeddings shape: {image_embeddings.shape} (using all {image_embeddings.shape[1]} dimensions)")
        results['image'] = {
            'color': self._score_embeddings(
                image_embeddings, image_colors, "Test Dataset, image, color confusion matrix"
            )
        }
        del image_embeddings
        self._release_accelerator_memory()

        self._save_figures({
            f"local_baseline_{key}_color": results[key]['color']['figure']
            for key in ('text', 'image')
        })

        return results

    def analyze_baseline_vs_trained_performance(self, results_trained, results_baseline, dataset_name):
        """Compare trained-model vs baseline color accuracies.

        Args:
            results_trained: Output of evaluate_kaggle_marqo /
                evaluate_local_validation (keys 'text_color', 'image_color').
            results_baseline: Output of the corresponding baseline method
                (nested keys results[modality]['color']).
            dataset_name: Display name for the printed header.

        Returns:
            List of comparison dicts (one per modality with both accuracies
            present), each with 'type', 'trained', 'baseline', 'diff'
            (baseline - trained) and the embedding-dimension descriptions.
        """
        print(f"\n{'='*60}")
        print(f"ANALYSE: Baseline vs Trained - {dataset_name}")
        print(f"{'='*60}")

        comparisons = []
        modality_keys = [
            ('Text Color', 'text_color', 'text'),
            ('Image Color', 'image_color', 'image'),
        ]
        for display_name, trained_key, baseline_key in modality_keys:
            trained_acc = results_trained.get(trained_key, {}).get('accuracy', 0)
            baseline_acc = results_baseline.get(baseline_key, {}).get('color', {}).get('accuracy', 0)
            # Only compare when both runs actually produced a score.
            if trained_acc > 0 and baseline_acc > 0:
                comparisons.append({
                    'type': display_name,
                    'trained': trained_acc,
                    'baseline': baseline_acc,
                    'diff': baseline_acc - trained_acc,
                    'trained_dims': f'0-{self.color_emb_dim - 1} ({self.color_emb_dim} dims)',
                    'baseline_dims': f'All dimensions ({self.main_emb_dim} dims)',
                })

        return comparisons
|
|
|
|
|
|
def main():
    """Run the Section 5.1 evaluation on the local validation dataset.

    Evaluates the trained color model, then the Fashion-CLIP baseline, and
    prints the Table 1 summary metrics (NN accuracy, centroid accuracy,
    separation score) for text and image embeddings.
    """
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    print(f"Using device: {device}")

    directory = 'figures/confusion_matrices/cm_color'
    # Sample cap; the local validation set (~5k items) is smaller than this.
    local_max_samples = 10000

    evaluator = ColorEvaluator(device=device, directory=directory)

    print("\n" + "="*60)
    print("Starting evaluation of Local Validation Dataset with Color embeddings")
    print("="*60)
    results_local = evaluator.evaluate_local_validation(max_samples=local_max_samples)

    if results_local is not None:
        print(f"\n{'='*60}")
        print("LOCAL VALIDATION DATASET EVALUATION SUMMARY")
        print(f"{'='*60}")

        print("\nCOLOR CLASSIFICATION RESULTS:")
        print(f" Text - NN Acc: {results_local['text_color']['accuracy']*100:.1f}% | Centroid Acc: {results_local['text_color']['centroid_accuracy']*100:.1f}% | Separation: {results_local['text_color']['separation_score']:.4f}")
        print(f" Image - NN Acc: {results_local['image_color']['accuracy']*100:.1f}% | Centroid Acc: {results_local['image_color']['centroid_accuracy']*100:.1f}% | Separation: {results_local['image_color']['separation_score']:.4f}")

    print("\n" + "="*60)
    print("Starting evaluation of Baseline Fashion CLIP on Local Validation")
    print("="*60)
    results_baseline_local = evaluator.evaluate_baseline_local_validation(max_samples=local_max_samples)

    if results_baseline_local is not None:
        print(f"\n{'='*60}")
        print("BASELINE LOCAL VALIDATION EVALUATION SUMMARY")
        print(f"{'='*60}")

        print("\nCOLOR CLASSIFICATION RESULTS (Baseline):")
        print(f" Text - NN Acc: {results_baseline_local['text']['color']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline_local['text']['color']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline_local['text']['color']['separation_score']:.4f}")
        print(f" Image - NN Acc: {results_baseline_local['image']['color']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline_local['image']['color']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline_local['image']['color']['separation_score']:.4f}")

    print(f"\nEvaluation completed! Check '{directory}/' for visualization files.")


if __name__ == "__main__":
    main()
|