# gap-clip / evaluation / sec51_color_model_eval.py
# NOTE: Hugging Face Hub page residue ("Leacb4's picture", upload banner,
# commit 942267d) converted to comments so the module parses as Python.
"""
Section 5.1 — Color Model Evaluation (Table 1)
===============================================
Evaluates the standalone 16D color model (ColorCLIP) on accuracy and
separation scores across:
- KAGL Marqo (external, 10k items, 46 colors)
- Local validation dataset (internal, 5k items, 11 colors)
Metrics reported match Table 1 in the paper:
- Text/image embedding NN accuracy
- Text/image embedding separation score (intra - inter class distance)
Compared against Fashion-CLIP baseline (patrickjohncyh/fashion-clip).
Run directly:
python sec51_color_model_eval.py
Paper reference: Section 5.1, Table 1.
"""
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import sys
from pathlib import Path
import torch
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, accuracy_score
from torch.utils.data import DataLoader
import warnings
warnings.filterwarnings('ignore')
# Ensure project root is importable when running this file directly.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
from config import (
color_model_path,
color_emb_dim,
main_emb_dim,
)
from utils.datasets import (
load_kaggle_marqo_dataset,
load_local_validation_dataset,
collate_fn_filter_none,
)
from utils.embeddings import extract_clip_embeddings, extract_color_model_embeddings
from utils.metrics import (
compute_similarity_metrics,
predict_labels_from_embeddings,
create_confusion_matrix,
)
from utils.model_loader import load_color_model, load_baseline_fashion_clip
class ColorEvaluator:
    """Evaluate the standalone 16D color embeddings (ColorCLIP).

    Computes nearest-neighbour accuracy, centroid accuracy and separation
    scores for text and image color embeddings on the KAGL Marqo and local
    validation datasets, and compares them against the Fashion-CLIP baseline
    (Section 5.1, Table 1 of the paper). Confusion-matrix figures are saved
    under ``directory``.
    """

    def __init__(
        self,
        device='mps',
        directory="figures/confusion_matrices/cm_color",
        baseline_model=None,
        baseline_processor=None,
        color_model=None,
        kaggle_raw_df=None,
        local_raw_df=None,
    ):
        """
        Args:
            device: Torch device spec (e.g. 'mps', 'cuda', 'cpu').
            directory: Output directory for confusion-matrix PNG files.
            baseline_model: Optional pre-loaded Fashion-CLIP model; reused
                only when ``baseline_processor`` is also given.
            baseline_processor: Optional pre-loaded Fashion-CLIP processor.
            color_model: Optional pre-loaded color model; loaded from
                ``color_model_path`` when omitted.
            kaggle_raw_df: Optional pre-loaded raw KAGL Marqo dataframe.
            local_raw_df: Optional pre-loaded raw local validation dataframe.
        """
        self.device = torch.device(device)
        self.directory = directory
        self.color_emb_dim = color_emb_dim
        self.main_emb_dim = main_emb_dim
        self.kaggle_raw_df = kaggle_raw_df
        self.local_raw_df = local_raw_df
        os.makedirs(self.directory, exist_ok=True)

        # Load baseline Fashion CLIP model (or reuse the pre-loaded pair).
        if baseline_model is not None and baseline_processor is not None:
            self.baseline_model = baseline_model
            self.baseline_processor = baseline_processor
        else:
            print("Loading baseline Fashion CLIP model...")
            self.baseline_model, self.baseline_processor = load_baseline_fashion_clip(self.device)
            print("Baseline Fashion CLIP model loaded successfully")

        # Load specialized color model (or reuse the pre-loaded one).
        if color_model is not None:
            self.color_model = color_model
        else:
            self.color_model, _ = load_color_model(color_model_path, self.device)

    @staticmethod
    def _free_cuda_cache():
        """Release cached CUDA memory between extractions (no-op off CUDA)."""
        # Replaces the former `expr if cond else None` expression-statement idiom.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def _save_confusion_figures(self, named_figures):
        """Save (filename, figure) pairs into ``self.directory`` and close them.

        Entries whose figure is ``None`` (no valid label/prediction pairs were
        found) are skipped instead of crashing on ``savefig``.
        """
        os.makedirs(self.directory, exist_ok=True)
        for filename, fig in named_figures:
            if fig is None:
                continue
            fig.savefig(
                f"{self.directory}/{filename}",
                dpi=300,
                bbox_inches='tight',
            )
            plt.close(fig)

    def evaluate_classification_performance(self, embeddings, labels, embedding_type="Embeddings", label_type="Label"):
        """
        Evaluate classification performance and create a confusion matrix.

        Args:
            embeddings: Embeddings used for nearest-neighbour prediction.
            labels: True labels; ``None`` entries are filtered out.
            embedding_type: Description of the embeddings for figure titles.
            label_type: Type of labels (e.g. "Color").

        Returns:
            dict with 'accuracy', 'predictions', 'confusion_matrix',
            'classification_report' and 'figure'. The matrix/report/figure
            are ``None`` when no valid label/prediction pairs exist.
        """
        predictions = predict_labels_from_embeddings(embeddings, labels)
        # Keep only positions where both the label and the prediction exist.
        valid_indices = [i for i, (label, pred) in enumerate(zip(labels, predictions))
                         if label is not None and pred is not None]
        if len(valid_indices) == 0:
            print("Warning: No valid labels/predictions found (all are None)")
            return {
                'accuracy': 0.0,
                'predictions': predictions,
                'confusion_matrix': None,
                'classification_report': None,
                'figure': None,
            }
        filtered_labels = [labels[i] for i in valid_indices]
        filtered_predictions = [predictions[i] for i in valid_indices]
        accuracy = accuracy_score(filtered_labels, filtered_predictions)
        fig, _, cm = create_confusion_matrix(
            filtered_labels, filtered_predictions,
            embedding_type,
            label_type
        )
        unique_labels = sorted(set(filtered_labels))
        report = classification_report(filtered_labels, filtered_predictions, labels=unique_labels, target_names=unique_labels, output_dict=True)
        return {
            'accuracy': accuracy,
            'predictions': predictions,
            'confusion_matrix': cm,
            'classification_report': report,
            'figure': fig,
        }

    def evaluate_kaggle_marqo(self, max_samples):
        """Evaluate both color embeddings on KAGL Marqo dataset.

        Returns:
            dict with 'text_color' / 'image_color' metric dicts, or ``None``
            when the dataset fails to load.
        """
        print(f"\n{'='*60}")
        print("Evaluating KAGL Marqo Dataset with Color embeddings")
        print(f"Max samples: {max_samples}")
        print(f"{'='*60}")
        kaggle_dataset = load_kaggle_marqo_dataset(max_samples, raw_df=self.kaggle_raw_df)
        if kaggle_dataset is None:
            print("Failed to load KAGL dataset")
            return None
        dataloader = DataLoader(kaggle_dataset, batch_size=8, shuffle=False, num_workers=0, collate_fn=collate_fn_filter_none)
        results = {}

        # ========== EXTRACT COLOR MODEL EMBEDDINGS ==========
        print("\nExtracting color model embeddings...")
        text_full_embeddings, text_colors_full = extract_color_model_embeddings(
            self.color_model, dataloader, self.device, embedding_type='text', max_samples=max_samples
        )
        image_full_embeddings, image_colors_full = extract_color_model_embeddings(
            self.color_model, dataloader, self.device, embedding_type='image', max_samples=max_samples
        )

        text_color_metrics = compute_similarity_metrics(text_full_embeddings, text_colors_full)
        text_color_class = self.evaluate_classification_performance(
            text_full_embeddings, text_colors_full,
            "KAGL Marqo, text, color confusion matrix", "Color",
        )
        text_color_metrics.update(text_color_class)
        results['text_color'] = text_color_metrics

        image_color_metrics = compute_similarity_metrics(image_full_embeddings, image_colors_full)
        image_color_class = self.evaluate_classification_performance(
            image_full_embeddings, image_colors_full,
            "KAGL Marqo, image, color confusion matrix", "Color",
        )
        image_color_metrics.update(image_color_class)
        results['image_color'] = image_color_metrics

        del text_full_embeddings, image_full_embeddings
        self._free_cuda_cache()

        # ========== SAVE VISUALIZATIONS ==========
        # (original code contained a no-op key.replace('_', '_') here)
        self._save_confusion_figures(
            (f"kaggle_{key}_confusion_matrix.png", results[key]['figure'])
            for key in ('text_color', 'image_color')
        )
        return results

    def evaluate_local_validation(self, max_samples):
        """Evaluate both color embeddings on local validation dataset.

        Returns:
            dict with 'text_color' / 'image_color' metric dicts, or ``None``
            when the dataset fails to load.
        """
        print(f"\n{'='*60}")
        print("Evaluating Local Validation Dataset")
        print(" Color embeddings")
        print(f"Max samples: {max_samples}")
        print(f"{'='*60}")
        local_dataset = load_local_validation_dataset(max_samples, raw_df=self.local_raw_df)
        if local_dataset is None:
            # Mirror evaluate_baseline_local_validation: fail gracefully
            # instead of crashing inside DataLoader on a None dataset.
            print("Failed to load local validation dataset")
            return None
        dataloader = DataLoader(local_dataset, batch_size=8, shuffle=False, num_workers=0)
        results = {}

        # ========== COLOR EVALUATION ==========
        print("\nCOLOR EVALUATION")
        print("=" * 50)

        # Text color embeddings
        print("\nExtracting text color embeddings...")
        text_color_embeddings, text_colors = extract_color_model_embeddings(
            self.color_model, dataloader, self.device, embedding_type='text', max_samples=max_samples
        )
        print(f" Text color embeddings shape: {text_color_embeddings.shape}")
        text_color_metrics = compute_similarity_metrics(text_color_embeddings, text_colors)
        text_color_class = self.evaluate_classification_performance(
            text_color_embeddings, text_colors, "Test Dataset, text, color confusion matrix", "Color"
        )
        text_color_metrics.update(text_color_class)
        results['text_color'] = text_color_metrics
        del text_color_embeddings
        self._free_cuda_cache()

        # Image color embeddings
        print("\nExtracting image color embeddings...")
        image_color_embeddings, image_colors = extract_color_model_embeddings(
            self.color_model, dataloader, self.device, embedding_type='image', max_samples=max_samples
        )
        print(f" Image color embeddings shape: {image_color_embeddings.shape}")
        image_color_metrics = compute_similarity_metrics(image_color_embeddings, image_colors)
        image_color_class = self.evaluate_classification_performance(
            image_color_embeddings, image_colors, "Test Dataset, image, color confusion matrix", "Color"
        )
        image_color_metrics.update(image_color_class)
        results['image_color'] = image_color_metrics
        del image_color_embeddings
        self._free_cuda_cache()

        # ========== SAVE VISUALIZATIONS ==========
        self._save_confusion_figures(
            (f"local_{key}_confusion_matrix.png", results[key]['figure'])
            for key in ('text_color', 'image_color')
        )
        return results

    def _evaluate_baseline_on_loader(self, dataloader, dataset_label, title_dataset, file_prefix, max_samples):
        """Shared baseline evaluation over an already-built dataloader.

        Extracts full-dimension Fashion-CLIP text and image embeddings,
        scores them against the color labels, and saves confusion matrices
        as ``{file_prefix}_baseline_{modality}_color_confusion_matrix.png``.

        Args:
            dataloader: Batched dataset iterator.
            dataset_label: Name used in progress messages (e.g. "KAGL Marqo").
            title_dataset: Name used in confusion-matrix titles
                (e.g. "Test Dataset").
            file_prefix: Prefix for saved figure filenames.
            max_samples: Upper bound on items processed.

        Returns:
            dict {'text': {'color': metrics}, 'image': {'color': metrics}}.
        """
        results = {}
        for modality in ('text', 'image'):
            print(f"\nExtracting baseline {modality} embeddings from {dataset_label}...")
            embeddings, colors, _ = extract_clip_embeddings(
                self.baseline_model, self.baseline_processor, dataloader, self.device,
                embedding_type=modality, max_samples=max_samples
            )
            # Baseline uses ALL embedding dimensions (no 16D color slice).
            print(f" Baseline {modality} embeddings shape: {embeddings.shape} (using all {embeddings.shape[1]} dimensions)")
            color_metrics = compute_similarity_metrics(embeddings, colors)
            color_classification = self.evaluate_classification_performance(
                embeddings, colors, f"{title_dataset}, {modality}, color confusion matrix", "Color"
            )
            color_metrics.update(color_classification)
            results[modality] = {
                'color': color_metrics
            }
            # Clear memory before the next modality.
            del embeddings
            self._free_cuda_cache()

        # ========== SAVE VISUALIZATIONS ==========
        self._save_confusion_figures(
            (f"{file_prefix}_baseline_{modality}_color_confusion_matrix.png",
             results[modality]['color']['figure'])
            for modality in ('text', 'image')
        )
        return results

    def evaluate_baseline_kaggle_marqo(self, max_samples=5000):
        """Evaluate baseline Fashion CLIP model on KAGL Marqo dataset.

        Returns the dict from :meth:`_evaluate_baseline_on_loader`, or
        ``None`` when the dataset fails to load.
        """
        print(f"\n{'='*60}")
        print("Evaluating Baseline Fashion CLIP on KAGL Marqo Dataset")
        print(f"Max samples: {max_samples}")
        print(f"{'='*60}")
        # Load KAGL Marqo dataset
        kaggle_dataset = load_kaggle_marqo_dataset(max_samples, raw_df=self.kaggle_raw_df)
        if kaggle_dataset is None:
            print("Failed to load KAGL dataset")
            return None
        # Create dataloader (KAGL items may be None -> filtering collate_fn).
        dataloader = DataLoader(kaggle_dataset, batch_size=8, shuffle=False, num_workers=0, collate_fn=collate_fn_filter_none)
        return self._evaluate_baseline_on_loader(
            dataloader,
            dataset_label="KAGL Marqo",
            title_dataset="KAGL Marqo",
            file_prefix="kaggle",
            max_samples=max_samples,
        )

    def evaluate_baseline_local_validation(self, max_samples=5000):
        """Evaluate baseline Fashion CLIP model on local validation dataset.

        Returns the dict from :meth:`_evaluate_baseline_on_loader`, or
        ``None`` when the dataset fails to load.
        """
        print(f"\n{'='*60}")
        print("Evaluating Baseline Fashion CLIP on Local Validation Dataset")
        print(f"Max samples: {max_samples}")
        print(f"{'='*60}")
        # Load local validation dataset
        local_dataset = load_local_validation_dataset(max_samples, raw_df=self.local_raw_df)
        if local_dataset is None:
            print("Failed to load local validation dataset")
            return None
        # Create dataloader
        dataloader = DataLoader(local_dataset, batch_size=8, shuffle=False, num_workers=0)
        return self._evaluate_baseline_on_loader(
            dataloader,
            dataset_label="Local Validation",
            title_dataset="Test Dataset",
            file_prefix="local",
            max_samples=max_samples,
        )

    def analyze_baseline_vs_trained_performance(self, results_trained, results_baseline, dataset_name):
        """Analyse baseline vs trained model performance.

        Args:
            results_trained: Output of evaluate_kaggle_marqo /
                evaluate_local_validation ('text_color' / 'image_color' keys).
            results_baseline: Output of the baseline evaluators
                ('text' / 'image' keys, each with a 'color' sub-dict).
            dataset_name: Display name for the printed header.

        Returns:
            List of comparison dicts; a modality is included only when both
            models report a strictly positive accuracy for it. 'diff' is
            baseline minus trained accuracy.
        """
        print(f"\n{'='*60}")
        print(f"ANALYSE: Baseline vs Trained - {dataset_name}")
        print(f"{'='*60}")
        comparisons = []
        # (display label, trained-results key, baseline-results key)
        pairs = (
            ('Text Color', 'text_color', 'text'),
            ('Image Color', 'image_color', 'image'),
        )
        for label, trained_key, baseline_key in pairs:
            trained_acc = results_trained.get(trained_key, {}).get('accuracy', 0)
            baseline_acc = results_baseline.get(baseline_key, {}).get('color', {}).get('accuracy', 0)
            if trained_acc > 0 and baseline_acc > 0:
                comparisons.append({
                    'type': label,
                    'trained': trained_acc,
                    'baseline': baseline_acc,
                    'diff': baseline_acc - trained_acc,
                    'trained_dims': f'0-{self.color_emb_dim - 1} ({self.color_emb_dim} dims)',
                    'baseline_dims': f'All dimensions ({self.main_emb_dim} dims)'
                })
        return comparisons
if __name__ == "__main__":
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")
directory = 'figures/confusion_matrices/cm_color'
max_samples = 10000
local_max_samples = 10000
evaluator = ColorEvaluator(device=device, directory=directory)
# Evaluate Local Validation Dataset
print("\n" + "="*60)
print("Starting evaluation of Local Validation Dataset with Color embeddings")
print("="*60)
results_local = evaluator.evaluate_local_validation(max_samples=local_max_samples)
if results_local is not None:
print(f"\n{'='*60}")
print("LOCAL VALIDATION DATASET EVALUATION SUMMARY")
print(f"{'='*60}")
print("\nCOLOR CLASSIFICATION RESULTS:")
print(f" Text - NN Acc: {results_local['text_color']['accuracy']*100:.1f}% | Centroid Acc: {results_local['text_color']['centroid_accuracy']*100:.1f}% | Separation: {results_local['text_color']['separation_score']:.4f}")
print(f" Image - NN Acc: {results_local['image_color']['accuracy']*100:.1f}% | Centroid Acc: {results_local['image_color']['centroid_accuracy']*100:.1f}% | Separation: {results_local['image_color']['separation_score']:.4f}")
# Evaluate Baseline Fashion CLIP on Local Validation
print("\n" + "="*60)
print("Starting evaluation of Baseline Fashion CLIP on Local Validation")
print("="*60)
results_baseline_local = evaluator.evaluate_baseline_local_validation(max_samples=local_max_samples)
if results_baseline_local is not None:
print(f"\n{'='*60}")
print("BASELINE LOCAL VALIDATION EVALUATION SUMMARY")
print(f"{'='*60}")
print("\nCOLOR CLASSIFICATION RESULTS (Baseline):")
print(f" Text - NN Acc: {results_baseline_local['text']['color']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline_local['text']['color']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline_local['text']['color']['separation_score']:.4f}")
print(f" Image - NN Acc: {results_baseline_local['image']['color']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline_local['image']['color']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline_local['image']['color']['separation_score']:.4f}")
print(f"\nEvaluation completed! Check '{directory}/' for visualization files.")