# gap-clip/evaluation/sec52_category_model_eval.py
# Hugging Face Hub page artifact (author: Leacb4; uploaded via
# huggingface_hub, commit 2dd8fc6 verified) — kept as comments so the
# file parses as valid Python.
"""
Section 5.2 — Category Model Evaluation (Table 2)
==================================================
Evaluates GAP-CLIP vs the Fashion-CLIP baseline on hierarchy (category)
classification using three datasets:
- Fashion-MNIST (10 categories)
- KAGL Marqo (external, real-world fashion e-commerce)
- Internal validation dataset
Produces hierarchy confusion matrices (text + image) for both models on each
dataset.
Metrics match Table 2 in the paper:
- Text/image embedding NN accuracy
- Text/image embedding separation score
Run directly:
python sec52_category_model_eval.py
Paper reference: Section 5.2, Table 2.
"""
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import difflib
from collections import defaultdict
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import classification_report, accuracy_score
from sklearn.preprocessing import normalize
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from io import BytesIO
import warnings
warnings.filterwarnings('ignore')
from config import (
ROOT_DIR,
main_model_path,
main_emb_dim,
hierarchy_model_path,
color_emb_dim,
hierarchy_emb_dim,
local_dataset_path,
column_local_image_path,
)
from utils.datasets import (
load_fashion_mnist_dataset,
)
from utils.embeddings import extract_clip_embeddings
from utils.metrics import (
compute_similarity_metrics,
compute_embedding_accuracy,
compute_centroid_accuracy,
predict_labels_from_embeddings,
create_confusion_matrix,
)
from utils.model_loader import load_gap_clip, load_baseline_fashion_clip
# ============================================================================
# 1b. KAGL Marqo utilities
# ============================================================================
class KaggleHierarchyDataset(Dataset):
    """KAGL Marqo dataset returning (image, description, color, hierarchy)."""

    def __init__(self, dataframe, image_size=224):
        # Reset the index so positional access via iloc is stable.
        self.dataframe = dataframe.reset_index(drop=True)
        # Standard ImageNet-style preprocessing pipeline.
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        record = self.dataframe.iloc[idx]
        raw = record["image"]
        # The image column may hold a {"bytes": ...} dict, a PIL image,
        # or raw bytes; normalise all three to an RGB PIL image.
        if isinstance(raw, dict) and "bytes" in raw:
            pil_img = Image.open(BytesIO(raw["bytes"])).convert("RGB")
        elif hasattr(raw, "convert"):
            pil_img = raw.convert("RGB")
        else:
            pil_img = Image.open(BytesIO(raw)).convert("RGB")
        return (
            self.transform(pil_img),
            str(record["text"]),
            str(record.get("baseColour", "unknown")).lower(),
            str(record["hierarchy"]),
        )
def _match_hierarchy_class(kagl_type, hierarchy_classes, hierarchy_classes_lower):
    """Map one lowercase KAGL article type onto a model hierarchy class.

    Tries, in order: exact match, bidirectional substring match, then a
    difflib fuzzy match (cutoff 0.6). Returns None when nothing matches.

    Args:
        kagl_type: Lowercased KAGL category/article-type string.
        hierarchy_classes: Model hierarchy class names (original casing).
        hierarchy_classes_lower: Same list, pre-lowered by the caller
            (hoisted out of the per-row loop for speed).
    """
    if kagl_type in hierarchy_classes_lower:
        return hierarchy_classes[hierarchy_classes_lower.index(kagl_type)]
    for h_class in hierarchy_classes:
        h_lower = h_class.lower()
        if h_lower in kagl_type or kagl_type in h_lower:
            return h_class
    close = difflib.get_close_matches(kagl_type, hierarchy_classes_lower, n=1, cutoff=0.6)
    if close:
        return hierarchy_classes[hierarchy_classes_lower.index(close[0])]
    return None


def load_kaggle_marqo_with_hierarchy(max_samples=10000, hierarchy_classes=None, raw_df=None):
    """Load KAGL Marqo dataset with hierarchy labels derived from its
    category columns.

    Args:
        max_samples: Cap on the number of rows kept (random sample, seed 42).
        hierarchy_classes: Optional model hierarchy classes; KAGL types are
            mapped onto them and unmapped rows dropped.
        raw_df: Pre-downloaded DataFrame to skip the HuggingFace download.

    Returns:
        A KaggleHierarchyDataset, or None when the data has no usable
        hierarchy column.
    """
    if raw_df is not None:
        df = raw_df.copy()
        print(f"Using cached KAGL DataFrame for hierarchy evaluation: {len(df)} samples")
    else:
        from datasets import load_dataset
        print("Loading KAGL Marqo dataset for hierarchy evaluation...")
        dataset = load_dataset("Marqo/KAGL")
        df = dataset["data"].to_pandas()
        print(f"Dataset loaded: {len(df)} samples, columns: {list(df.columns)}")
    # Use the most specific category column actually present as hierarchy
    # source. BUG FIX: the column was previously hard-coded to 'category2',
    # which made the "no hierarchy column" guard below unreachable and
    # raised a KeyError on frames without that column.
    hierarchy_col = next(
        (col for col in ("category2", "category1", "articleType") if col in df.columns),
        None,
    )
    if hierarchy_col is None:
        print("WARNING: No hierarchy column found in KAGL dataset")
        return None
    print(f"Using '{hierarchy_col}' as hierarchy source")
    df = df.dropna(subset=["text", "image", hierarchy_col])
    df["hierarchy"] = df[hierarchy_col].astype(str).str.strip()
    # If hierarchy_classes provided, map KAGL types to model hierarchy classes
    if hierarchy_classes:
        hierarchy_classes_lower = [h.lower() for h in hierarchy_classes]
        df["hierarchy"] = [
            _match_hierarchy_class(h.lower(), hierarchy_classes, hierarchy_classes_lower)
            for h in df["hierarchy"]
        ]
        df = df.dropna(subset=["hierarchy"])
        print(f"After hierarchy mapping: {len(df)} samples")
    if len(df) > max_samples:
        df = df.sample(n=max_samples, random_state=42)
    print(f"Using {len(df)} samples, {df['hierarchy'].nunique()} hierarchy classes: "
          f"{sorted(df['hierarchy'].unique())}")
    return KaggleHierarchyDataset(df)
# ============================================================================
# 1c. Local validation dataset utilities
# ============================================================================
class LocalHierarchyDataset(Dataset):
    """Local validation dataset returning (image, description, color, hierarchy)."""

    def __init__(self, dataframe, image_size=224):
        # Reset the index so positional access via iloc is stable.
        self.dataframe = dataframe.reset_index(drop=True)
        # Standard ImageNet-style preprocessing pipeline.
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        record = self.dataframe.iloc[idx]
        # Any failure to read the file falls back to a solid gray
        # placeholder so the dataloader never crashes mid-epoch.
        try:
            path = record[column_local_image_path]
            if not os.path.isabs(path):
                path = os.path.join(ROOT_DIR, path)
            pil_img = Image.open(path).convert("RGB")
        except Exception:
            pil_img = Image.new("RGB", (224, 224), color="gray")
        return (
            self.transform(pil_img),
            str(record["text"]),
            str(record.get("color", "unknown")),
            str(record["hierarchy"]),
        )
def load_local_validation_with_hierarchy(max_samples=10000, hierarchy_classes=None, raw_df=None):
    """Load internal validation dataset with hierarchy labels.

    Args:
        max_samples: Cap on rows kept (random sample, seed 42).
        hierarchy_classes: Optional whitelist; rows whose hierarchy
            (case-insensitive) is not listed are dropped, and the canonical
            casing from the list is restored on the survivors.
        raw_df: Pre-loaded DataFrame to skip CSV read.
    """
    if raw_df is None:
        print("Loading local validation dataset for hierarchy evaluation...")
        df = pd.read_csv(local_dataset_path)
        print(f"Dataset loaded: {len(df)} samples")
    else:
        df = raw_df.copy()
        print(f"Using cached local DataFrame for hierarchy evaluation: {len(df)} samples")
    # Keep only rows with a usable image path and a non-empty hierarchy label.
    df = df.dropna(subset=[column_local_image_path, "hierarchy"])
    df["hierarchy"] = df["hierarchy"].astype(str).str.strip()
    df = df[df["hierarchy"].str.len() > 0]
    if hierarchy_classes:
        allowed_lower = [h.lower() for h in hierarchy_classes]
        df["hierarchy_lower"] = df["hierarchy"].str.lower()
        df = df[df["hierarchy_lower"].isin(allowed_lower)]
        # Restore proper casing from hierarchy_classes
        canonical = {h.lower(): h for h in hierarchy_classes}
        df["hierarchy"] = df["hierarchy_lower"].map(canonical)
        df = df.drop(columns=["hierarchy_lower"])
        print(f"After filtering: {len(df)} samples, {df['hierarchy'].nunique()} classes")
    if len(df) > max_samples:
        df = df.sample(n=max_samples, random_state=42)
    print(f"Using {len(df)} samples, classes: {sorted(df['hierarchy'].unique())}")
    return LocalHierarchyDataset(df)
# ============================================================================
# 2. Evaluator
# ============================================================================
class CategoryModelEvaluator:
    """
    Produces hierarchy confusion matrices for GAP-CLIP and the
    baseline Fashion-CLIP on Fashion-MNIST, KAGL Marqo, and internal datasets.

    Reports the Table 2 (Section 5.2) metrics — nearest-neighbour accuracy
    and separation score for both text and image embeddings — and saves
    confusion-matrix figures/CSVs under ``directory``.
    """
    def __init__(self, device='mps', directory='gap_clip_confusion_matrices',
                 gap_clip_model=None, gap_clip_processor=None,
                 baseline_model=None, baseline_processor=None,
                 hierarchy_classes=None,
                 kaggle_raw_df=None, local_raw_df=None):
        """Set up devices, output dir, hierarchy classes, and both models.

        Args:
            device: torch device or device string (default Apple MPS).
            directory: Output directory for confusion-matrix PNGs/CSVs.
            gap_clip_model / gap_clip_processor: Optional pre-loaded GAP-CLIP;
                loaded from ``main_model_path`` when either is missing.
            baseline_model / baseline_processor: Optional pre-loaded
                Fashion-CLIP baseline; loaded from scratch otherwise.
            hierarchy_classes: Optional explicit class list; otherwise read
                from the hierarchy model checkpoint.
            kaggle_raw_df / local_raw_df: Optional cached DataFrames passed
                through to the dataset loaders.

        Raises:
            FileNotFoundError: when hierarchy classes must be loaded but the
                hierarchy model checkpoint file does not exist.
        """
        self.device = torch.device(device) if isinstance(device, str) else device
        self.directory = directory
        self.kaggle_raw_df = kaggle_raw_df
        self.local_raw_df = local_raw_df
        self.color_emb_dim = color_emb_dim
        self.hierarchy_emb_dim = hierarchy_emb_dim
        self.main_emb_dim = main_emb_dim
        # End index of the hierarchy subspace in the full embedding
        # (color dims come first, then hierarchy dims).
        self.hierarchy_end_dim = self.color_emb_dim + self.hierarchy_emb_dim
        os.makedirs(self.directory, exist_ok=True)
        # --- hierarchy classes ---
        if hierarchy_classes is not None:
            self.hierarchy_classes = hierarchy_classes
            print(f"Using provided hierarchy classes: {len(self.hierarchy_classes)} classes")
        else:
            print("Loading hierarchy classes from hierarchy model...")
            if not os.path.exists(hierarchy_model_path):
                raise FileNotFoundError(f"Hierarchy model file {hierarchy_model_path} not found")
            hierarchy_checkpoint = torch.load(hierarchy_model_path, map_location=self.device)
            self.hierarchy_classes = hierarchy_checkpoint.get('hierarchy_classes', [])
            print(f"Found {len(self.hierarchy_classes)} hierarchy classes: {sorted(self.hierarchy_classes)}")
        # Prefer the classes actually present in the validation CSV; fall
        # back to the hierarchy-model classes when the CSV is unavailable.
        self.validation_hierarchy_classes = self._load_validation_hierarchy_classes()
        if self.validation_hierarchy_classes:
            print(f"Validation dataset hierarchies ({len(self.validation_hierarchy_classes)} classes): "
                  f"{sorted(self.validation_hierarchy_classes)}")
        else:
            print("Unable to load validation hierarchy classes, falling back to hierarchy model classes.")
            self.validation_hierarchy_classes = self.hierarchy_classes
        # --- load GAP-CLIP (accept pre-loaded or load from scratch) ---
        if gap_clip_model is not None and gap_clip_processor is not None:
            self.model = gap_clip_model
            self.processor = gap_clip_processor
            print("Using pre-loaded GAP-CLIP model")
        else:
            self.model, self.processor = load_gap_clip(main_model_path, self.device)
            print("GAP-CLIP model loaded successfully")
        # --- baseline Fashion-CLIP (accept pre-loaded or load from scratch) ---
        if baseline_model is not None and baseline_processor is not None:
            self.baseline_model = baseline_model
            self.baseline_processor = baseline_processor
            print("Using pre-loaded baseline Fashion-CLIP model")
        else:
            self.baseline_model, self.baseline_processor = load_baseline_fashion_clip(self.device)
            print("Baseline Fashion-CLIP model loaded successfully")
    # ------------------------------------------------------------------
    # helpers
    # ------------------------------------------------------------------
    def _load_validation_hierarchy_classes(self):
        """Read the distinct hierarchy labels from the local validation CSV.

        Returns:
            Sorted list of unique, stripped, non-empty hierarchy strings, or
            an empty list when the CSV is missing, unreadable, or lacks a
            'hierarchy' column.
        """
        if not os.path.exists(local_dataset_path):
            print(f"Validation dataset not found at {local_dataset_path}")
            return []
        try:
            df = pd.read_csv(local_dataset_path)
        except Exception as exc:
            print(f"Failed to read validation dataset: {exc}")
            return []
        if 'hierarchy' not in df.columns:
            print("Validation dataset does not contain 'hierarchy' column.")
            return []
        hierarchies = df['hierarchy'].dropna().astype(str).str.strip()
        hierarchies = [h for h in hierarchies if h]
        return sorted(set(hierarchies))
    def prepare_shared_fashion_mnist(self, max_samples=10000, batch_size=8):
        """
        Build one shared Fashion-MNIST dataset/dataloader to ensure every model
        is evaluated on the exact same items.

        Returns:
            (fashion_dataset, dataloader, per-hierarchy sample-count dict).
        """
        target_classes = self.validation_hierarchy_classes or self.hierarchy_classes
        fashion_dataset = load_fashion_mnist_dataset(max_samples, hierarchy_classes=target_classes)
        # shuffle=False so both models see items in identical order.
        dataloader = DataLoader(fashion_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
        hierarchy_counts = defaultdict(int)
        if len(fashion_dataset.dataframe) > 0 and fashion_dataset.label_mapping:
            for _, row in fashion_dataset.dataframe.iterrows():
                lid = int(row['label'])
                hierarchy_counts[fashion_dataset.label_mapping.get(lid, 'unknown')] += 1
        return fashion_dataset, dataloader, dict(hierarchy_counts)
    @staticmethod
    def _count_labels(labels):
        """Return a plain dict mapping each label to its occurrence count."""
        counts = defaultdict(int)
        for label in labels:
            counts[label] += 1
        return dict(counts)
    def _validate_label_distribution(self, labels, expected_counts, context):
        """Raise ValueError when observed label counts differ from expected.

        Guards against extraction/dataloader mismatches between runs that
        are supposed to see exactly the same items.
        """
        observed = self._count_labels(labels)
        if observed != expected_counts:
            raise ValueError(
                f"Label distribution mismatch in {context}. "
                f"Expected {expected_counts}, observed {observed}"
            )
    # ------------------------------------------------------------------
    # embedding extraction (delegates to shared utils)
    # ------------------------------------------------------------------
    def extract_full_embeddings(self, dataloader, embedding_type='text', max_samples=10000):
        """Full 512D embeddings from GAP-CLIP (text or image)."""
        return extract_clip_embeddings(
            self.model, self.processor, dataloader, self.device,
            embedding_type=embedding_type, max_samples=max_samples,
            desc=f"GAP-CLIP {embedding_type} embeddings",
        )
    def extract_baseline_embeddings_batch(self, dataloader, embedding_type='text', max_samples=10000):
        """L2-normalised embeddings from baseline Fashion-CLIP."""
        return extract_clip_embeddings(
            self.baseline_model, self.baseline_processor, dataloader, self.device,
            embedding_type=embedding_type, max_samples=max_samples,
            desc=f"Baseline {embedding_type} embeddings",
        )
    def predict_labels_nearest_neighbor(self, embeddings, labels):
        """
        Predict labels using 1-NN on the same embedding set.
        This matches the accuracy logic used in the evaluation pipeline.
        """
        similarities = cosine_similarity(embeddings)
        preds = []
        for i in range(len(embeddings)):
            sims = similarities[i].copy()
            # Exclude self-similarity so an item cannot be its own neighbour.
            sims[i] = -1.0
            nearest_neighbor_idx = int(np.argmax(sims))
            preds.append(labels[nearest_neighbor_idx])
        return preds
    # ------------------------------------------------------------------
    # image + text ensemble
    # ------------------------------------------------------------------
    def _compute_img_centroids(self, embeddings, labels):
        """Per-class L2-normalised mean embeddings, keyed by label."""
        emb_norm = normalize(embeddings, norm='l2')
        centroids = {}
        for label in sorted(set(labels)):
            idx = [i for i, l in enumerate(labels) if l == label]
            centroids[label] = normalize([emb_norm[idx].mean(axis=0)], norm='l2')[0]
        return centroids
    def predict_labels_image_ensemble(self, img_embeddings, labels,
                                      text_protos, cls_names, alpha=0.5):
        """Combine image centroids (512D) with text prototypes (512D).

        ``alpha`` weights the image-centroid similarity; ``1 - alpha``
        weights the text-prototype similarity.
        """
        img_norm = normalize(img_embeddings, norm='l2')
        img_centroids = self._compute_img_centroids(img_norm, labels)
        centroid_mat = np.stack([img_centroids[c] for c in cls_names], axis=0)
        preds = []
        for i in range(len(img_norm)):
            v = img_norm[i:i + 1]
            sim_img = cosine_similarity(v, centroid_mat)[0]
            sim_txt = cosine_similarity(v, text_protos)[0]
            scores = alpha * sim_img + (1 - alpha) * sim_txt
            preds.append(cls_names[int(np.argmax(scores))])
        return preds
    # ------------------------------------------------------------------
    # confusion matrix & classification report
    # ------------------------------------------------------------------
    def evaluate_classification_performance(self, embeddings, labels,
                                            embedding_type="Embeddings",
                                            label_type="Label",
                                            method="nn"):
        """Classify embeddings ('nn' 1-NN or 'centroid') and report metrics.

        Returns a dict with accuracy, predictions, confusion matrix, sorted
        label list, sklearn classification report, and the matplotlib figure.

        Raises:
            ValueError: for an unknown ``method``.
        """
        if method == "nn":
            preds = self.predict_labels_nearest_neighbor(embeddings, labels)
        elif method == "centroid":
            preds = predict_labels_from_embeddings(embeddings, labels)
        else:
            raise ValueError(f"Unknown classification method: {method}")
        acc = accuracy_score(labels, preds)
        unique_labels = sorted(set(labels))
        fig, _, cm = create_confusion_matrix(
            labels, preds,
            f"{embedding_type} - {label_type} Classification ({method.upper()})",
            label_type,
        )
        report = classification_report(labels, preds, labels=unique_labels,
                                       target_names=unique_labels, output_dict=True)
        return {
            'accuracy': acc,
            'predictions': preds,
            'confusion_matrix': cm,
            'labels': unique_labels,
            'classification_report': report,
            'figure': fig,
        }
    def save_confusion_matrix_table(self, cm, labels, output_csv_path):
        """
        Save confusion matrix values with per-row totals to CSV for auditing.
        """
        cm_df = pd.DataFrame(cm, index=labels, columns=labels)
        cm_df["row_total"] = cm_df.sum(axis=1)
        cm_df.loc["column_total"] = list(cm_df[labels].sum(axis=0)) + [cm_df["row_total"].sum()]
        cm_df.to_csv(output_csv_path)
    # ==================================================================
    # 3. GAP-CLIP evaluation on Fashion-MNIST
    # ==================================================================
    def evaluate_gap_clip_fashion_mnist(self, max_samples=10000, dataloader=None, expected_counts=None):
        """Evaluate GAP-CLIP hierarchy classification on Fashion-MNIST.

        Text uses the specialized hierarchy subspace; for image the better
        of the 64D subspace and the full 512D embedding is kept, plus an
        image+text prototype ensemble. Saves confusion-matrix PNG/CSV pairs.

        Args:
            max_samples: Cap on evaluated items.
            dataloader: Optional pre-built (shared) dataloader.
            expected_counts: Per-class counts to validate extraction against;
                required when ``dataloader`` is supplied.

        Returns:
            Dict with 'text_hierarchy' and 'image_hierarchy' metric dicts.
        """
        print(f"\n{'=' * 60}")
        print("Evaluating GAP-CLIP on Fashion-MNIST")
        print(" Hierarchy embeddings (dims 16-79)")
        print(f" Max samples: {max_samples}")
        print(f"{'=' * 60}")
        if dataloader is None:
            fashion_dataset, dataloader, dataset_counts = self.prepare_shared_fashion_mnist(max_samples=max_samples)
            expected_counts = expected_counts or dataset_counts
        else:
            fashion_dataset = getattr(dataloader, "dataset", None)
            if expected_counts is None:
                raise ValueError("expected_counts must be provided when using a custom dataloader.")
        if fashion_dataset is not None and len(fashion_dataset.dataframe) > 0 and fashion_dataset.label_mapping:
            print(f"\nHierarchy distribution in dataset:")
            for h in sorted(expected_counts):
                print(f" {h}: {expected_counts[h]} samples")
        results = {}
        # --- full 512D embeddings (text & image) ---
        print("\nExtracting full 512-dimensional GAP-CLIP embeddings...")
        text_full, _, text_hier = self.extract_full_embeddings(dataloader, 'text', max_samples)
        img_full, _, img_hier = self.extract_full_embeddings(dataloader, 'image', max_samples)
        self._validate_label_distribution(text_hier, expected_counts, "GAP-CLIP text")
        self._validate_label_distribution(img_hier, expected_counts, "GAP-CLIP image")
        print(f" Text shape: {text_full.shape} | Image shape: {img_full.shape}")
        # --- TEXT: hierarchy on specialized 64D (dims 16-79) ---
        print("\n--- GAP-CLIP TEXT HIERARCHY (dims 16-79) ---")
        text_hier_spec = text_full[:, self.color_emb_dim:self.color_emb_dim + self.hierarchy_emb_dim]
        print(f" Specialized text hierarchy shape: {text_hier_spec.shape}")
        text_metrics = compute_similarity_metrics(text_hier_spec, text_hier)
        text_class = self.evaluate_classification_performance(
            text_hier_spec, text_hier,
            "GAP-CLIP Text Hierarchy (64D)", "Hierarchy",
            method="nn",
        )
        text_metrics.update(text_class)
        results['text_hierarchy'] = text_metrics
        # --- IMAGE: 64D vs 512D + ensemble ---
        print("\n--- GAP-CLIP IMAGE HIERARCHY (64D vs 512D) ---")
        img_hier_spec = img_full[:, self.color_emb_dim:self.color_emb_dim + self.hierarchy_emb_dim]
        print(f" Specialized image hierarchy shape: {img_hier_spec.shape}")
        print(" Testing specialized 64D...")
        spec_metrics = compute_similarity_metrics(img_hier_spec, img_hier)
        spec_class = self.evaluate_classification_performance(
            img_hier_spec, img_hier,
            "GAP-CLIP Image Hierarchy (64D)", "Hierarchy",
            method="nn",
        )
        print(" Testing full 512D...")
        full_metrics = compute_similarity_metrics(img_full, img_hier)
        full_class = self.evaluate_classification_performance(
            img_full, img_hier,
            "GAP-CLIP Image Hierarchy (512D full)", "Hierarchy",
            method="nn",
        )
        # Keep whichever representation achieves the higher NN accuracy.
        if full_class['accuracy'] >= spec_class['accuracy']:
            print(f" 512D wins: {full_class['accuracy'] * 100:.1f}% vs {spec_class['accuracy'] * 100:.1f}%")
            img_metrics, img_class = full_metrics, full_class
        else:
            print(f" 64D wins: {spec_class['accuracy'] * 100:.1f}% vs {full_class['accuracy'] * 100:.1f}%")
            img_metrics, img_class = spec_metrics, spec_class
        # --- ensemble image + text prototypes ---
        print("\n Testing GAP-CLIP image + text ensemble (prototypes per class)...")
        cls_names = sorted(set(img_hier))
        prompts = [f"a photo of a {c}" for c in cls_names]
        text_inputs = self.processor(text=prompts, return_tensors="pt", padding=True, truncation=True)
        text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
        with torch.no_grad():
            txt_feats = self.model.get_text_features(**text_inputs)
            txt_feats = txt_feats / txt_feats.norm(dim=-1, keepdim=True)
        text_protos = txt_feats.cpu().numpy()
        ensemble_preds = self.predict_labels_image_ensemble(
            img_full, img_hier, text_protos, cls_names, alpha=0.7,
        )
        ensemble_acc = accuracy_score(img_hier, ensemble_preds)
        print(f" Ensemble accuracy (alpha=0.7): {ensemble_acc * 100:.2f}%")
        img_metrics.update(img_class)
        img_metrics['ensemble_accuracy'] = ensemble_acc
        results['image_hierarchy'] = img_metrics
        # --- save confusion matrix figures ---
        for key in ['text_hierarchy', 'image_hierarchy']:
            fig = results[key]['figure']
            fig.savefig(
                os.path.join(self.directory, f"gap_clip_{key}_confusion_matrix.png"),
                dpi=300, bbox_inches='tight',
            )
            self.save_confusion_matrix_table(
                results[key]['confusion_matrix'],
                results[key]['labels'],
                os.path.join(self.directory, f"gap_clip_{key}_confusion_matrix.csv"),
            )
            plt.close(fig)
        # Free large arrays before the next evaluation round.
        del text_full, img_full, text_hier_spec, img_hier_spec
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return results
    # ==================================================================
    # 4. Baseline Fashion-CLIP evaluation on Fashion-MNIST
    # ==================================================================
    def evaluate_baseline_fashion_mnist(self, max_samples=10000, dataloader=None, expected_counts=None):
        """Evaluate baseline Fashion-CLIP hierarchy NN accuracy on Fashion-MNIST.

        Args mirror :meth:`evaluate_gap_clip_fashion_mnist`. Returns a dict
        with 'text' and 'image' entries, each holding a 'hierarchy' metrics
        dict; saves confusion-matrix PNG/CSV pairs.
        """
        print(f"\n{'=' * 60}")
        print("Evaluating Baseline Fashion-CLIP on Fashion-MNIST")
        print(f" Max samples: {max_samples}")
        print(f"{'=' * 60}")
        if dataloader is None:
            _, dataloader, dataset_counts = self.prepare_shared_fashion_mnist(max_samples=max_samples)
            expected_counts = expected_counts or dataset_counts
        elif expected_counts is None:
            raise ValueError("expected_counts must be provided when using a custom dataloader.")
        results = {}
        # --- text ---
        print("\nExtracting baseline text embeddings...")
        text_emb, _, text_hier = self.extract_baseline_embeddings_batch(dataloader, 'text', max_samples)
        self._validate_label_distribution(text_hier, expected_counts, "baseline text")
        print(f" Baseline text shape: {text_emb.shape}")
        text_metrics = compute_similarity_metrics(text_emb, text_hier)
        text_class = self.evaluate_classification_performance(
            text_emb, text_hier,
            "Baseline Fashion-CLIP Text - Hierarchy", "Hierarchy",
            method="nn",
        )
        text_metrics.update(text_class)
        results['text'] = {'hierarchy': text_metrics}
        del text_emb
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        # --- image ---
        print("\nExtracting baseline image embeddings...")
        img_emb, _, img_hier = self.extract_baseline_embeddings_batch(dataloader, 'image', max_samples)
        self._validate_label_distribution(img_hier, expected_counts, "baseline image")
        print(f" Baseline image shape: {img_emb.shape}")
        img_metrics = compute_similarity_metrics(img_emb, img_hier)
        img_class = self.evaluate_classification_performance(
            img_emb, img_hier,
            "Baseline Fashion-CLIP Image - Hierarchy", "Hierarchy",
            method="nn",
        )
        img_metrics.update(img_class)
        results['image'] = {'hierarchy': img_metrics}
        del img_emb
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        for key in ['text', 'image']:
            fig = results[key]['hierarchy']['figure']
            fig.savefig(
                os.path.join(self.directory, f"baseline_{key}_hierarchy_confusion_matrix.png"),
                dpi=300, bbox_inches='tight',
            )
            self.save_confusion_matrix_table(
                results[key]['hierarchy']['confusion_matrix'],
                results[key]['hierarchy']['labels'],
                os.path.join(self.directory, f"baseline_{key}_hierarchy_confusion_matrix.csv"),
            )
            plt.close(fig)
        return results
    # ==================================================================
    # 5. Generic dataset evaluation (KAGL Marqo / Internal)
    # ==================================================================
    def evaluate_gap_clip_generic(self, dataloader, dataset_name, max_samples=10000):
        """Evaluate GAP-CLIP hierarchy performance on any dataset."""
        print(f"\n{'=' * 60}")
        print(f"Evaluating GAP-CLIP on {dataset_name}")
        print(f" Hierarchy embeddings (dims 16-79)")
        print(f"{'=' * 60}")
        results = {}
        # --- text hierarchy (64D specialized) ---
        print("\nExtracting GAP-CLIP text embeddings...")
        text_full, _, text_hier = self.extract_full_embeddings(dataloader, 'text', max_samples)
        text_hier_spec = text_full[:, self.color_emb_dim:self.color_emb_dim + self.hierarchy_emb_dim]
        print(f" Text shape: {text_full.shape}, hierarchy subspace: {text_hier_spec.shape}")
        text_metrics = compute_similarity_metrics(text_hier_spec, text_hier)
        text_class = self.evaluate_classification_performance(
            text_hier_spec, text_hier,
            f"GAP-CLIP Text Hierarchy – {dataset_name}", "Hierarchy", method="nn",
        )
        text_metrics.update(text_class)
        results['text_hierarchy'] = text_metrics
        # --- image hierarchy (best of 64D vs 512D) ---
        print("\nExtracting GAP-CLIP image embeddings...")
        img_full, _, img_hier = self.extract_full_embeddings(dataloader, 'image', max_samples)
        img_hier_spec = img_full[:, self.color_emb_dim:self.color_emb_dim + self.hierarchy_emb_dim]
        spec_metrics = compute_similarity_metrics(img_hier_spec, img_hier)
        spec_class = self.evaluate_classification_performance(
            img_hier_spec, img_hier,
            f"GAP-CLIP Image Hierarchy (64D) – {dataset_name}", "Hierarchy", method="nn",
        )
        full_metrics = compute_similarity_metrics(img_full, img_hier)
        full_class = self.evaluate_classification_performance(
            img_full, img_hier,
            f"GAP-CLIP Image Hierarchy (512D) – {dataset_name}", "Hierarchy", method="nn",
        )
        # Keep whichever representation achieves the higher NN accuracy.
        if full_class['accuracy'] >= spec_class['accuracy']:
            print(f" 512D wins: {full_class['accuracy']*100:.1f}% vs {spec_class['accuracy']*100:.1f}%")
            img_metrics, img_class = full_metrics, full_class
        else:
            print(f" 64D wins: {spec_class['accuracy']*100:.1f}% vs {full_class['accuracy']*100:.1f}%")
            img_metrics, img_class = spec_metrics, spec_class
        img_metrics.update(img_class)
        results['image_hierarchy'] = img_metrics
        # --- save confusion matrices ---
        prefix = dataset_name.lower().replace(" ", "_")
        for key in ['text_hierarchy', 'image_hierarchy']:
            fig = results[key]['figure']
            fig.savefig(
                os.path.join(self.directory, f"gap_clip_{prefix}_{key}_confusion_matrix.png"),
                dpi=300, bbox_inches='tight',
            )
            self.save_confusion_matrix_table(
                results[key]['confusion_matrix'], results[key]['labels'],
                os.path.join(self.directory, f"gap_clip_{prefix}_{key}_confusion_matrix.csv"),
            )
            plt.close(fig)
        del text_full, img_full, text_hier_spec, img_hier_spec
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return results
    def evaluate_baseline_generic(self, dataloader, dataset_name, max_samples=10000):
        """Evaluate baseline Fashion-CLIP hierarchy performance on any dataset."""
        print(f"\n{'=' * 60}")
        print(f"Evaluating Baseline Fashion-CLIP on {dataset_name}")
        print(f"{'=' * 60}")
        results = {}
        # --- text ---
        print("\nExtracting baseline text embeddings...")
        text_emb, _, text_hier = self.extract_baseline_embeddings_batch(dataloader, 'text', max_samples)
        print(f" Baseline text shape: {text_emb.shape}")
        text_metrics = compute_similarity_metrics(text_emb, text_hier)
        text_class = self.evaluate_classification_performance(
            text_emb, text_hier,
            f"Baseline Text Hierarchy – {dataset_name}", "Hierarchy", method="nn",
        )
        text_metrics.update(text_class)
        results['text'] = {'hierarchy': text_metrics}
        del text_emb
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        # --- image ---
        print("\nExtracting baseline image embeddings...")
        img_emb, _, img_hier = self.extract_baseline_embeddings_batch(dataloader, 'image', max_samples)
        print(f" Baseline image shape: {img_emb.shape}")
        img_metrics = compute_similarity_metrics(img_emb, img_hier)
        img_class = self.evaluate_classification_performance(
            img_emb, img_hier,
            f"Baseline Image Hierarchy – {dataset_name}", "Hierarchy", method="nn",
        )
        img_metrics.update(img_class)
        results['image'] = {'hierarchy': img_metrics}
        del img_emb
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        prefix = dataset_name.lower().replace(" ", "_")
        for key in ['text', 'image']:
            fig = results[key]['hierarchy']['figure']
            fig.savefig(
                os.path.join(self.directory, f"baseline_{prefix}_{key}_hierarchy_confusion_matrix.png"),
                dpi=300, bbox_inches='tight',
            )
            self.save_confusion_matrix_table(
                results[key]['hierarchy']['confusion_matrix'],
                results[key]['hierarchy']['labels'],
                os.path.join(self.directory, f"baseline_{prefix}_{key}_hierarchy_confusion_matrix.csv"),
            )
            plt.close(fig)
        return results
    # ==================================================================
    # 6. Full evaluation across all datasets
    # ==================================================================
    def run_full_evaluation(self, max_samples=10000, batch_size=8):
        """Run hierarchy evaluation on all 3 datasets for both models.

        Fashion-MNIST always runs (both models share one dataloader so they
        see identical items); KAGL Marqo and the internal dataset are
        best-effort — failures are logged as warnings and skipped. Prints a
        summary and returns a dict keyed by '<dataset>_<model>'.
        """
        all_results = {}
        # --- Fashion-MNIST ---
        shared_dataset, shared_dataloader, shared_counts = self.prepare_shared_fashion_mnist(
            max_samples=max_samples, batch_size=batch_size,
        )
        all_results['fashion_mnist_gap'] = self.evaluate_gap_clip_fashion_mnist(
            max_samples=max_samples, dataloader=shared_dataloader, expected_counts=shared_counts,
        )
        all_results['fashion_mnist_baseline'] = self.evaluate_baseline_fashion_mnist(
            max_samples=max_samples, dataloader=shared_dataloader, expected_counts=shared_counts,
        )
        # --- KAGL Marqo ---
        try:
            kaggle_dataset = load_kaggle_marqo_with_hierarchy(
                max_samples=max_samples,
                hierarchy_classes=self.validation_hierarchy_classes or self.hierarchy_classes,
                raw_df=self.kaggle_raw_df,
            )
            if kaggle_dataset is not None and len(kaggle_dataset) > 0:
                kaggle_dataloader = DataLoader(kaggle_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
                all_results['kaggle_gap'] = self.evaluate_gap_clip_generic(
                    kaggle_dataloader, "KAGL Marqo", max_samples,
                )
                all_results['kaggle_baseline'] = self.evaluate_baseline_generic(
                    kaggle_dataloader, "KAGL Marqo", max_samples,
                )
            else:
                print("WARNING: KAGL Marqo dataset empty after hierarchy mapping, skipping.")
        except Exception as e:
            print(f"WARNING: Could not evaluate on KAGL Marqo: {e}")
        # --- Internal (local validation) ---
        try:
            local_dataset = load_local_validation_with_hierarchy(
                max_samples=max_samples,
                hierarchy_classes=self.validation_hierarchy_classes or self.hierarchy_classes,
                raw_df=self.local_raw_df,
            )
            if local_dataset is not None and len(local_dataset) > 0:
                local_dataloader = DataLoader(local_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
                all_results['local_gap'] = self.evaluate_gap_clip_generic(
                    local_dataloader, "Internal", max_samples,
                )
                all_results['local_baseline'] = self.evaluate_baseline_generic(
                    local_dataloader, "Internal", max_samples,
                )
            else:
                print("WARNING: Local validation dataset empty after hierarchy filtering, skipping.")
        except Exception as e:
            print(f"WARNING: Could not evaluate on internal dataset: {e}")
        # --- Print summary ---
        print(f"\n{'=' * 70}")
        print("CATEGORY MODEL EVALUATION SUMMARY")
        print(f"{'=' * 70}")
        for dataset_key, label in [
            ('fashion_mnist_gap', 'Fashion-MNIST (GAP-CLIP)'),
            ('fashion_mnist_baseline', 'Fashion-MNIST (Baseline)'),
            ('kaggle_gap', 'KAGL Marqo (GAP-CLIP)'),
            ('kaggle_baseline', 'KAGL Marqo (Baseline)'),
            ('local_gap', 'Internal (GAP-CLIP)'),
            ('local_baseline', 'Internal (Baseline)'),
        ]:
            if dataset_key not in all_results:
                continue
            res = all_results[dataset_key]
            print(f"\n{label}:")
            # GAP-CLIP results use flat keys; baseline results nest under
            # 'text'/'image' — handle both layouts.
            if 'text_hierarchy' in res:
                t = res['text_hierarchy']
                i = res['image_hierarchy']
                print(f" Text NN Acc: {t['accuracy']*100:.1f}% | Separation: {t['separation_score']:.4f}")
                print(f" Image NN Acc: {i['accuracy']*100:.1f}% | Separation: {i['separation_score']:.4f}")
            elif 'text' in res:
                t = res['text']['hierarchy']
                i = res['image']['hierarchy']
                print(f" Text NN Acc: {t['accuracy']*100:.1f}% | Separation: {t['separation_score']:.4f}")
                print(f" Image NN Acc: {i['accuracy']*100:.1f}% | Separation: {i['separation_score']:.4f}")
        return all_results
# ============================================================================
# 7. Main
# ============================================================================
if __name__ == "__main__":
    # Prefer Apple-silicon MPS when available, otherwise fall back to CPU.
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    print(f"Using device: {device}")
    output_dir = 'gap_clip_confusion_matrices'
    sample_cap = 10000
    runner = CategoryModelEvaluator(device=device, directory=output_dir)
    runner.run_full_evaluation(max_samples=sample_cap, batch_size=8)