# gap-clip / evaluation / sec533_clip_nn_accuracy.py
# (Hugging Face Hub header: uploaded by Leacb4 via huggingface_hub, commit 26c2a79 verified.)
"""
Section 5.3.3 Nearest-Neighbour Classification Accuracy (Table 3)
==================================================================
Evaluates the full GAP-CLIP embedding on three datasets and compares with the
patrickjohncyh/fashion-clip baseline — **color and hierarchy**.
- Fashion-MNIST (public benchmark, 10 clothing categories)
- KAGL Marqo HuggingFace dataset (diverse fashion, colour + category labels)
- Internal local validation set (50 k images)
For each dataset the ``ColorHierarchyEvaluator`` class extracts:
* **Color slice** (dims 0–15): nearest-neighbour accuracy per colour class.
* **Hierarchy slice** (dims 16–79): nearest-neighbour accuracy per category,
plus 64-D vs 512-D comparison and image + text-prototype ensemble.
Results feed directly into **Table 3** of the paper.
The hierarchy mapping for Kaggle uses the same approach as in
``sec52_category_model_eval.py`` (exact match -> substring -> fuzzy on
``category2``).
See also:
- Section 5.1 (``sec51_color_model_eval.py``) – standalone colour model
- Section 5.2 (``sec52_category_model_eval.py``) – confusion-matrix analysis
- Section 5.3.6 (``sec536_embedding_structure.py``) – embedding-structure validation
"""
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import difflib
import warnings
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from collections import defaultdict
from io import BytesIO
from PIL import Image
from sklearn.metrics import accuracy_score, classification_report
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import normalize
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
warnings.filterwarnings('ignore')
from config import (
ROOT_DIR,
color_emb_dim,
column_local_image_path,
hierarchy_emb_dim,
hierarchy_model_path,
local_dataset_path,
main_emb_dim,
main_model_path,
)
from utils.datasets import (
FashionMNISTDataset,
LocalDataset,
collate_fn_filter_none,
load_fashion_mnist_dataset,
load_local_validation_dataset,
)
from utils.embeddings import extract_clip_embeddings
from utils.metrics import (
compute_similarity_metrics,
compute_centroid_accuracy,
predict_labels_from_embeddings,
create_confusion_matrix,
)
from utils.model_loader import load_gap_clip, load_baseline_fashion_clip
from training.hierarchy_model import HierarchyExtractor
# ---------------------------------------------------------------------------
# Hierarchy label normalisation (same as sec536_embedding_structure.py)
# Maps long internal taxonomy strings -> clean labels like "top", "pant", etc.
# ---------------------------------------------------------------------------
NORMALIZED_HIERARCHY_CLASSES = [
"accessories", "bodysuits", "bras", "coat", "dress", "jacket",
"legging", "pant", "polo", "shirt", "shoes", "short", "skirt",
"socks", "sweater", "swimwear", "top", "underwear",
]
_HIERARCHY_EXTRACTOR = HierarchyExtractor(NORMALIZED_HIERARCHY_CLASSES, verbose=False)
_SYNONYMS = {
"t-shirt/top": "top", "top": "top", "tee": "top", "t-shirt": "top",
"shirt": "shirt", "shirts": "shirt",
"pullover": "sweater", "sweater": "sweater",
"coat": "coat", "jacket": "jacket", "outerwear": "coat", "outer": "coat",
"trouser": "pant", "trousers": "pant", "pants": "pant", "pant": "pant", "jeans": "pant",
"dress": "dress", "skirt": "skirt",
"shorts": "short", "short": "short",
"sandal": "shoes", "sneaker": "shoes", "ankle boot": "shoes",
"shoe": "shoes", "shoes": "shoes", "flip flops": "shoes",
"footwear": "shoes", "shoe accessories": "shoes", "boots": "shoes",
"bag": "accessories", "bags": "accessories",
"accessory": "accessories", "accessories": "accessories",
"belts": "accessories", "eyewear": "accessories",
"jewellery": "accessories", "jewelry": "accessories",
"headwear": "accessories", "wallets": "accessories",
"watches": "accessories", "mufflers": "accessories",
"scarves": "accessories", "stoles": "accessories",
"ties": "accessories", "sunglasses": "accessories",
"scarf & tie": "accessories", "scarf/tie": "accessories", "belt": "accessories",
"topwear": "top", "bottomwear": "pant",
"innerwear": "underwear", "loungewear and nightwear": "underwear",
"saree": "dress",
}
_EXTRA_KEYWORDS = [
("capri", "pant"), ("denim", "pant"), ("skinny", "pant"),
("boyfriend", "pant"), ("graphic", "top"), ("longsleeve", "top"),
("leather", "jacket"),
]
def normalize_hierarchy_label(raw_label: str) -> str:
    """Map any raw hierarchy string to a clean normalised label.

    Matching stages, in order: exact synonym table, project
    HierarchyExtractor, keyword fallback. Unmatched labels are returned
    as-is (lowercased/stripped).
    """
    cleaned = str(raw_label).strip().lower()
    if cleaned in _SYNONYMS:
        return _SYNONYMS[cleaned]
    extracted = _HIERARCHY_EXTRACTOR.extract_hierarchy(cleaned)
    if extracted:
        return extracted
    fallback = next(
        (category for keyword, category in _EXTRA_KEYWORDS if keyword in cleaned),
        None,
    )
    return fallback if fallback is not None else cleaned
# ============================================================================
# 1. Dataset utilities (hierarchy mapping matches sec52)
# ============================================================================
class KaggleHierarchyDataset(Dataset):
    """KAGL Marqo dataset returning (image, description, color, hierarchy).

    Images arrive either as raw bytes, as dicts with a ``bytes`` entry
    (HF-datasets style), or as already-decoded PIL images.
    """

    def __init__(self, dataframe, image_size=224):
        self.dataframe = dataframe.reset_index(drop=True)
        # Standard ImageNet normalisation, matching the other datasets here.
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        return self.dataframe.shape[0]

    def __getitem__(self, idx):
        record = self.dataframe.iloc[idx]
        raw = record["image"]
        # Decode whichever image representation this row carries.
        if isinstance(raw, dict) and "bytes" in raw:
            pil_img = Image.open(BytesIO(raw["bytes"])).convert("RGB")
        elif hasattr(raw, "convert"):
            pil_img = raw.convert("RGB")
        else:
            pil_img = Image.open(BytesIO(raw)).convert("RGB")
        tensor = self.transform(pil_img)
        return (
            tensor,
            str(record["text"]),
            str(record.get("baseColour", "unknown")).lower(),
            str(record["hierarchy"]),
        )
def load_kaggle_marqo_with_hierarchy(max_samples=10000, hierarchy_classes=None, raw_df=None):
    """Load KAGL Marqo dataset with hierarchy labels derived from category2.

    Mapping: exact match -> substring match -> fuzzy match (same as sec52).

    Args:
        max_samples: Cap on the number of rows kept (random sample, seed 42).
        hierarchy_classes: Optional canonical class list to map onto; rows
            that match none of the classes are dropped.
        raw_df: Optional pre-loaded DataFrame (skips the HF download).

    Returns:
        KaggleHierarchyDataset over the filtered/mapped DataFrame.
    """
    if raw_df is not None:
        df = raw_df.copy()
        print(f"Using cached KAGL DataFrame: {len(df)} samples")
    else:
        from datasets import load_dataset
        print("Loading KAGL Marqo dataset...")
        dataset = load_dataset("Marqo/KAGL")
        df = dataset["data"].to_pandas()
        print(f"Dataset loaded: {len(df)} samples, columns: {list(df.columns)}")
    hierarchy_col = 'category2'
    print(f"Using '{hierarchy_col}' as hierarchy source")
    df = df.dropna(subset=["text", "image", hierarchy_col])
    df["hierarchy"] = df[hierarchy_col].astype(str).str.strip()
    # Normalise every category2 value through the synonym/extractor pipeline
    df["hierarchy"] = df["hierarchy"].apply(normalize_hierarchy_label)
    if hierarchy_classes:
        hierarchy_classes_lower = [h.lower() for h in hierarchy_classes]

        def _match_class(kagl_type):
            # Exact match (after normalisation most will hit here)
            if kagl_type in hierarchy_classes_lower:
                return hierarchy_classes[hierarchy_classes_lower.index(kagl_type)]
            # Substring match
            for h_class in hierarchy_classes:
                h_lower = h_class.lower()
                if h_lower in kagl_type or kagl_type in h_lower:
                    return h_class
            # Fuzzy match
            close = difflib.get_close_matches(kagl_type, hierarchy_classes_lower, n=1, cutoff=0.6)
            if close:
                return hierarchy_classes[hierarchy_classes_lower.index(close[0])]
            return None

        # Perf: the original matched row-by-row with iterrows(), re-running the
        # substring/fuzzy search for every row. Labels repeat heavily, so match
        # each *unique* label once and broadcast with .map() — identical output.
        lowered = df["hierarchy"].str.lower()
        match_cache = {label: _match_class(label) for label in lowered.unique()}
        df["hierarchy"] = lowered.map(match_cache)
        df = df.dropna(subset=["hierarchy"])
        print(f"After hierarchy mapping: {len(df)} samples")
    # Normalise color column
    if "baseColour" in df.columns:
        df["baseColour"] = df["baseColour"].fillna("unknown").astype(str).str.lower().str.replace("grey", "gray")
    else:
        df["baseColour"] = "unknown"
    df = df.dropna(subset=["text", "image"])
    if len(df) > max_samples:
        df = df.sample(n=max_samples, random_state=42)
    print(f"Using {len(df)} samples, {df['hierarchy'].nunique()} hierarchy classes: "
          f"{sorted(df['hierarchy'].unique())}")
    return KaggleHierarchyDataset(df)
class LocalHierarchyDataset(Dataset):
    """Local validation dataset returning (image, description, color, hierarchy).

    Unreadable image paths fall back to a solid gray placeholder so a bad
    row never aborts a whole evaluation run.
    """

    def __init__(self, dataframe, image_size=224):
        self.dataframe = dataframe.reset_index(drop=True)
        # Standard ImageNet normalisation, matching the other datasets here.
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        return self.dataframe.shape[0]

    def __getitem__(self, idx):
        record = self.dataframe.iloc[idx]
        try:
            path = record[column_local_image_path]
            if not os.path.isabs(path):
                path = os.path.join(ROOT_DIR, path)
            pil_img = Image.open(path).convert("RGB")
        except Exception:
            # Best-effort placeholder for missing/corrupt files.
            pil_img = Image.new("RGB", (224, 224), color="gray")
        tensor = self.transform(pil_img)
        return (
            tensor,
            str(record["text"]),
            str(record.get("color", "unknown")),
            str(record["hierarchy"]),
        )
def load_local_validation_with_hierarchy(max_samples=10000, hierarchy_classes=None, raw_df=None):
    """Load internal validation dataset with hierarchy labels.

    Rows are normalised through ``normalize_hierarchy_label`` and, when a
    class list is given, restricted to those classes (case-insensitive).
    """
    if raw_df is None:
        print("Loading local validation dataset...")
        df = pd.read_csv(local_dataset_path)
        print(f"Dataset loaded: {len(df)} samples")
    else:
        df = raw_df.copy()
        print(f"Using cached local DataFrame: {len(df)} samples")
    df = df.dropna(subset=[column_local_image_path, "hierarchy"])
    df["hierarchy"] = df["hierarchy"].astype(str).str.strip()
    df = df[df["hierarchy"].str.len() > 0]
    # Normalise raw taxonomy strings to clean labels
    df["hierarchy"] = df["hierarchy"].apply(normalize_hierarchy_label)
    if hierarchy_classes:
        # Case-insensitive filter, then restore the caller's casing.
        lower_to_canonical = {h.lower(): h for h in hierarchy_classes}
        df["hierarchy_lower"] = df["hierarchy"].str.lower()
        df = df[df["hierarchy_lower"].isin(list(lower_to_canonical))]
        df["hierarchy"] = df["hierarchy_lower"].map(lower_to_canonical)
        df = df.drop(columns=["hierarchy_lower"])
        print(f"After filtering: {len(df)} samples, {df['hierarchy'].nunique()} classes")
    if len(df) > max_samples:
        df = df.sample(n=max_samples, random_state=42)
    print(f"Using {len(df)} samples, classes: {sorted(df['hierarchy'].unique())}")
    return LocalHierarchyDataset(df)
# ============================================================================
# 2. Evaluator
# ============================================================================
class ColorHierarchyEvaluator:
"""
Evaluates color and hierarchy NN classification accuracy for GAP-CLIP
and the baseline Fashion-CLIP on Fashion-MNIST, KAGL Marqo, and the
internal validation dataset.
"""
def __init__(self, device='mps', directory='main_model_analysis',
             gap_clip_model=None, gap_clip_processor=None,
             baseline_model=None, baseline_processor=None,
             hierarchy_classes=None,
             kaggle_raw_df=None, local_raw_df=None):
    """Set up output dir, embedding-slice dimensions, class lists and models.

    Args:
        device: torch device or device string (default 'mps').
        directory: Output directory for figures/CSVs (created if missing).
        gap_clip_model / gap_clip_processor: Optional pre-loaded GAP-CLIP;
            otherwise loaded from ``main_model_path``.
        baseline_model / baseline_processor: Optional pre-loaded baseline
            Fashion-CLIP; otherwise loaded via ``load_baseline_fashion_clip``.
        hierarchy_classes: Optional explicit class list; otherwise read
            from the hierarchy-model checkpoint at ``hierarchy_model_path``.
        kaggle_raw_df / local_raw_df: Optional cached DataFrames consumed
            by the dataset loaders to skip re-downloading/re-reading.

    Raises:
        FileNotFoundError: If no class list is given and the hierarchy
            checkpoint file does not exist.
    """
    self.device = torch.device(device) if isinstance(device, str) else device
    self.directory = directory
    self.kaggle_raw_df = kaggle_raw_df
    self.local_raw_df = local_raw_df
    # Embedding layout (from config): [0, color_emb_dim) is the colour
    # slice, [color_emb_dim, hierarchy_end_dim) the hierarchy slice.
    self.color_emb_dim = color_emb_dim
    self.hierarchy_emb_dim = hierarchy_emb_dim
    self.main_emb_dim = main_emb_dim
    self.hierarchy_end_dim = self.color_emb_dim + self.hierarchy_emb_dim
    os.makedirs(self.directory, exist_ok=True)
    # --- hierarchy classes ---
    if hierarchy_classes is not None:
        self.hierarchy_classes = hierarchy_classes
        print(f"Using provided hierarchy classes: {len(self.hierarchy_classes)} classes")
    else:
        print("Loading hierarchy classes from hierarchy model...")
        if not os.path.exists(hierarchy_model_path):
            raise FileNotFoundError(f"Hierarchy model file {hierarchy_model_path} not found")
        hierarchy_checkpoint = torch.load(hierarchy_model_path, map_location=self.device)
        self.hierarchy_classes = hierarchy_checkpoint.get('hierarchy_classes', [])
        print(f"Found {len(self.hierarchy_classes)} hierarchy classes: {sorted(self.hierarchy_classes)}")
    # Classes actually present in the local validation CSV; falls back to
    # the model's class list when the CSV cannot be read.
    self.validation_hierarchy_classes = self._load_validation_hierarchy_classes()
    if self.validation_hierarchy_classes:
        print(f"Validation dataset hierarchies ({len(self.validation_hierarchy_classes)} classes): "
              f"{sorted(self.validation_hierarchy_classes)}")
    else:
        print("Unable to load validation hierarchy classes, falling back to hierarchy model classes.")
        self.validation_hierarchy_classes = self.hierarchy_classes
    # --- load GAP-CLIP ---
    if gap_clip_model is not None and gap_clip_processor is not None:
        self.model = gap_clip_model
        self.processor = gap_clip_processor
        print("Using pre-loaded GAP-CLIP model")
    else:
        self.model, self.processor = load_gap_clip(main_model_path, self.device)
        print("GAP-CLIP model loaded successfully")
    # --- baseline Fashion-CLIP ---
    if baseline_model is not None and baseline_processor is not None:
        self.baseline_model = baseline_model
        self.baseline_processor = baseline_processor
        print("Using pre-loaded baseline Fashion-CLIP model")
    else:
        self.baseline_model, self.baseline_processor = load_baseline_fashion_clip(self.device)
        print("Baseline Fashion-CLIP model loaded successfully")
# ------------------------------------------------------------------
# helpers
# ------------------------------------------------------------------
def _load_validation_hierarchy_classes(self):
    """Load hierarchy classes from the local CSV, normalised to clean labels.

    Returns an empty list (caller falls back to the model's classes) when
    the CSV is missing, unreadable, or lacks a 'hierarchy' column.
    """
    if not os.path.exists(local_dataset_path):
        print(f"Validation dataset not found at {local_dataset_path}")
        return []
    try:
        df = pd.read_csv(local_dataset_path)
    except Exception as exc:
        print(f"Failed to read validation dataset: {exc}")
        return []
    if 'hierarchy' not in df.columns:
        print("Validation dataset does not contain 'hierarchy' column.")
        return []
    raw = df['hierarchy'].dropna().astype(str).str.strip()
    # Normalise, dedupe, then keep only labels from the known vocabulary.
    unique_normalised = {normalize_hierarchy_label(value) for value in raw if value}
    normalised = [h for h in sorted(unique_normalised) if h in NORMALIZED_HIERARCHY_CLASSES]
    print(f"Normalised validation hierarchy classes: {normalised}")
    return normalised
def prepare_shared_fashion_mnist(self, max_samples=10000, batch_size=8):
    """Build one shared Fashion-MNIST dataset/dataloader.

    Uses the validation classes (or NORMALIZED_HIERARCHY_CLASSES) so that
    Fashion-MNIST labels are mapped to clean short names
    (top, pant, shoes, sweater, coat, ...).

    Returns:
        (dataset, dataloader, per-class label counts dict).
    """
    target_classes = self.validation_hierarchy_classes or NORMALIZED_HIERARCHY_CLASSES
    fashion_dataset = load_fashion_mnist_dataset(max_samples, hierarchy_classes=target_classes)
    # Normalise whatever label_mapping produced (e.g. "Coat" -> "coat")
    if fashion_dataset.label_mapping:
        fashion_dataset.label_mapping = {
            label_id: normalize_hierarchy_label(name) if name else name
            for label_id, name in fashion_dataset.label_mapping.items()
        }
    dataloader = DataLoader(fashion_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
    # Per-class histogram used later as the expected label distribution.
    hierarchy_counts = defaultdict(int)
    mapping = fashion_dataset.label_mapping
    if len(fashion_dataset.dataframe) > 0 and mapping:
        for _, row in fashion_dataset.dataframe.iterrows():
            hierarchy_counts[mapping.get(int(row['label']), 'unknown')] += 1
    return fashion_dataset, dataloader, dict(hierarchy_counts)
@staticmethod
def _count_labels(labels):
counts = defaultdict(int)
for label in labels:
counts[label] += 1
return dict(counts)
def _validate_label_distribution(self, labels, expected_counts, context):
observed = self._count_labels(labels)
if observed != expected_counts:
raise ValueError(
f"Label distribution mismatch in {context}. "
f"Expected {expected_counts}, observed {observed}"
)
# ------------------------------------------------------------------
# embedding extraction
# ------------------------------------------------------------------
def extract_full_embeddings(self, dataloader, embedding_type='text', max_samples=10000):
    """Extract full 512D embeddings from GAP-CLIP for the given modality."""
    description = f"GAP-CLIP {embedding_type} embeddings"
    return extract_clip_embeddings(
        self.model,
        self.processor,
        dataloader,
        self.device,
        embedding_type=embedding_type,
        max_samples=max_samples,
        desc=description,
    )
def extract_baseline_embeddings_batch(self, dataloader, embedding_type='text', max_samples=10000):
    """Extract L2-normalised embeddings from baseline Fashion-CLIP."""
    description = f"Baseline {embedding_type} embeddings"
    return extract_clip_embeddings(
        self.baseline_model,
        self.baseline_processor,
        dataloader,
        self.device,
        embedding_type=embedding_type,
        max_samples=max_samples,
        desc=description,
    )
# ------------------------------------------------------------------
# prediction methods
# ------------------------------------------------------------------
def predict_labels_nearest_neighbor(self, embeddings, labels):
    """Predict labels using 1-NN on the same embedding set.

    Each sample is assigned the label of its most cosine-similar
    neighbour (excluding itself).

    Perf: the original copied each similarity row and ran np.argmax in a
    Python loop; here the diagonal is masked once and a single vectorised
    argmax over axis 1 produces identical predictions.
    """
    similarities = cosine_similarity(embeddings)
    # Mask self-similarity so a sample can never be its own neighbour.
    np.fill_diagonal(similarities, -1.0)
    nearest = np.argmax(similarities, axis=1)
    return [labels[int(i)] for i in nearest]
def _compute_img_centroids(self, embeddings, labels):
    """Compute one L2-normalised centroid per class from the embeddings."""
    emb_norm = normalize(embeddings, norm='l2')
    label_list = list(labels)
    centroids = {}
    for cls in sorted(set(label_list)):
        member_rows = [i for i, current in enumerate(label_list) if current == cls]
        mean_vec = emb_norm[member_rows].mean(axis=0)
        # Re-normalise the mean so centroids live on the unit sphere.
        centroids[cls] = normalize([mean_vec], norm='l2')[0]
    return centroids
def predict_labels_image_ensemble(self, img_embeddings, labels,
                                  text_protos, cls_names, alpha=0.5):
    """Combine image centroids (512D) with text prototypes (512D).

    Score per class = alpha * cos(sample, image centroid)
                    + (1 - alpha) * cos(sample, text prototype);
    the highest-scoring class name is predicted for each sample.

    Perf: the original called cosine_similarity once per sample inside a
    Python loop; here two matrix-level calls compute all similarities at
    once with identical results.
    """
    img_norm = normalize(img_embeddings, norm='l2')
    img_centroids = self._compute_img_centroids(img_norm, labels)
    centroid_mat = np.stack([img_centroids[c] for c in cls_names], axis=0)
    sim_img = cosine_similarity(img_norm, centroid_mat)
    sim_txt = cosine_similarity(img_norm, text_protos)
    scores = alpha * sim_img + (1 - alpha) * sim_txt
    best = np.argmax(scores, axis=1)
    return [cls_names[int(i)] for i in best]
# ------------------------------------------------------------------
# classification evaluation
# ------------------------------------------------------------------
def evaluate_classification_performance(self, embeddings, labels,
                                        embedding_type="Embeddings",
                                        label_type="Hierarchy",
                                        method="nn"):
    """Classify via 1-NN or centroid matching and collect accuracy,
    confusion matrix, per-class report and the plotted figure.

    Raises:
        ValueError: For an unrecognised ``method``.
    """
    if method == "nn":
        preds = self.predict_labels_nearest_neighbor(embeddings, labels)
    elif method == "centroid":
        preds = predict_labels_from_embeddings(embeddings, labels)
    else:
        raise ValueError(f"Unknown classification method: {method}")
    unique_labels = sorted(set(labels))
    acc = accuracy_score(labels, preds)
    title = f"{embedding_type} - {label_type} Classification ({method.upper()})"
    fig, _, cm = create_confusion_matrix(labels, preds, title, label_type)
    report = classification_report(
        labels, preds, labels=unique_labels,
        target_names=unique_labels, output_dict=True,
    )
    return {
        'accuracy': acc,
        'predictions': preds,
        'confusion_matrix': cm,
        'labels': unique_labels,
        'classification_report': report,
        'figure': fig,
    }
def save_confusion_matrix_table(self, cm, labels, output_csv_path):
cm_df = pd.DataFrame(cm, index=labels, columns=labels)
cm_df["row_total"] = cm_df.sum(axis=1)
cm_df.loc["column_total"] = list(cm_df[labels].sum(axis=0)) + [cm_df["row_total"].sum()]
cm_df.to_csv(output_csv_path)
# ==================================================================
# 3. GAP-CLIP evaluation on Fashion-MNIST (hierarchy only — no color)
# ==================================================================
def evaluate_gap_clip_fashion_mnist(self, max_samples=10000, dataloader=None, expected_counts=None):
    """Evaluate GAP-CLIP hierarchy NN accuracy on Fashion-MNIST.

    Fashion-MNIST carries no colour labels, so only the hierarchy slice
    and the full 512-D embedding are evaluated, for both text and image
    modalities. Also runs an image+text prototype ensemble (alpha=0.7)
    and saves confusion-matrix figures/CSVs to ``self.directory``.

    Args:
        max_samples: Sample cap forwarded to dataset building/extraction.
        dataloader: Optional shared dataloader; if given,
            ``expected_counts`` must be supplied too.
        expected_counts: Label histogram used to sanity-check extraction.

    Returns:
        dict with 'text_hierarchy' and 'image_hierarchy' metric dicts.

    Raises:
        ValueError: Custom dataloader without ``expected_counts``, or a
            label-distribution mismatch during extraction.
    """
    print(f"\n{'=' * 60}")
    print("Evaluating GAP-CLIP on Fashion-MNIST (Hierarchy only)")
    print(f" Hierarchy embeddings (dims {self.color_emb_dim}-{self.hierarchy_end_dim - 1})")
    print(f" Max samples: {max_samples}")
    print(f"{'=' * 60}")
    if dataloader is None:
        fashion_dataset, dataloader, dataset_counts = self.prepare_shared_fashion_mnist(max_samples=max_samples)
        expected_counts = expected_counts or dataset_counts
    else:
        if expected_counts is None:
            raise ValueError("expected_counts must be provided when using a custom dataloader.")
    results = {}
    # --- full 512D embeddings (text & image) ---
    print("\nExtracting full 512-dimensional GAP-CLIP embeddings...")
    text_full, _, text_hier = self.extract_full_embeddings(dataloader, 'text', max_samples)
    img_full, _, img_hier = self.extract_full_embeddings(dataloader, 'image', max_samples)
    # Guard against silent sample drops during extraction.
    self._validate_label_distribution(text_hier, expected_counts, "GAP-CLIP text")
    self._validate_label_distribution(img_hier, expected_counts, "GAP-CLIP image")
    print(f" Text shape: {text_full.shape} | Image shape: {img_full.shape}")
    # ===== HIERARCHY (dims 16-79) =====
    print(f"\n--- GAP-CLIP TEXT HIERARCHY (dims {self.color_emb_dim}-{self.hierarchy_end_dim - 1}) ---")
    text_hier_spec = text_full[:, self.color_emb_dim:self.hierarchy_end_dim]
    print(f" Specialized text hierarchy shape: {text_hier_spec.shape}")
    text_hier_metrics = compute_similarity_metrics(text_hier_spec, text_hier)
    text_hier_class = self.evaluate_classification_performance(
        text_hier_spec, text_hier, "GAP-CLIP Text Hierarchy (64D)", "Hierarchy", method="nn",
    )
    text_hier_metrics.update(text_hier_class)
    results['text_hierarchy'] = text_hier_metrics
    # IMAGE: 64D vs 512D — keep whichever slice scores higher.
    print(f"\n--- GAP-CLIP IMAGE HIERARCHY (64D vs 512D) ---")
    img_hier_spec = img_full[:, self.color_emb_dim:self.hierarchy_end_dim]
    print(f" Specialized image hierarchy shape: {img_hier_spec.shape}")
    print(" Testing specialized 64D...")
    spec_metrics = compute_similarity_metrics(img_hier_spec, img_hier)
    spec_class = self.evaluate_classification_performance(
        img_hier_spec, img_hier, "GAP-CLIP Image Hierarchy (64D)", "Hierarchy", method="nn",
    )
    print(" Testing full 512D...")
    full_metrics = compute_similarity_metrics(img_full, img_hier)
    full_class = self.evaluate_classification_performance(
        img_full, img_hier, "GAP-CLIP Image Hierarchy (512D full)", "Hierarchy", method="nn",
    )
    if full_class['accuracy'] >= spec_class['accuracy']:
        print(f" 512D wins: {full_class['accuracy'] * 100:.1f}% vs {spec_class['accuracy'] * 100:.1f}%")
        img_hier_metrics, img_hier_class = full_metrics, full_class
    else:
        print(f" 64D wins: {spec_class['accuracy'] * 100:.1f}% vs {full_class['accuracy'] * 100:.1f}%")
        img_hier_metrics, img_hier_class = spec_metrics, spec_class
    # ensemble image + text prototypes: embed one "a photo of a <class>"
    # prompt per class and blend with image centroids (alpha=0.7).
    print("\n Testing GAP-CLIP image + text ensemble (prototypes per class)...")
    cls_names = sorted(set(img_hier))
    prompts = [f"a photo of a {c}" for c in cls_names]
    text_inputs = self.processor(text=prompts, return_tensors="pt", padding=True, truncation=True)
    text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
    with torch.no_grad():
        txt_feats = self.model.get_text_features(**text_inputs)
        txt_feats = txt_feats / txt_feats.norm(dim=-1, keepdim=True)
    text_protos = txt_feats.cpu().numpy()
    ensemble_preds = self.predict_labels_image_ensemble(
        img_full, img_hier, text_protos, cls_names, alpha=0.7,
    )
    ensemble_acc = accuracy_score(img_hier, ensemble_preds)
    print(f" Ensemble accuracy (alpha=0.7): {ensemble_acc * 100:.2f}%")
    img_hier_metrics.update(img_hier_class)
    img_hier_metrics['ensemble_accuracy'] = ensemble_acc
    results['image_hierarchy'] = img_hier_metrics
    # --- save confusion matrix figures ---
    for key in ['text_hierarchy', 'image_hierarchy']:
        fig = results[key]['figure']
        fig.savefig(
            os.path.join(self.directory, f"gap_clip_{key}_confusion_matrix.png"),
            dpi=300, bbox_inches='tight',
        )
        self.save_confusion_matrix_table(
            results[key]['confusion_matrix'],
            results[key]['labels'],
            os.path.join(self.directory, f"gap_clip_{key}_confusion_matrix.csv"),
        )
        plt.close(fig)
    # Free the large embedding arrays before the next evaluation stage.
    del text_full, img_full, text_hier_spec, img_hier_spec
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return results
# ==================================================================
# 4. Baseline Fashion-CLIP evaluation on Fashion-MNIST (hierarchy only)
# ==================================================================
def evaluate_baseline_fashion_mnist(self, max_samples=10000, dataloader=None, expected_counts=None):
    """Evaluate baseline Fashion-CLIP hierarchy NN accuracy on Fashion-MNIST.

    Mirrors ``evaluate_gap_clip_fashion_mnist`` but uses the baseline
    model's full embeddings (no specialised slices, no ensemble).
    Confusion-matrix figures/CSVs are written to ``self.directory``.

    Returns:
        dict with 'text' and 'image' keys, each holding a 'hierarchy'
        metrics dict.

    Raises:
        ValueError: Custom dataloader without ``expected_counts``, or a
            label-distribution mismatch during extraction.
    """
    print(f"\n{'=' * 60}")
    print("Evaluating Baseline Fashion-CLIP on Fashion-MNIST (Hierarchy only)")
    print(f" Max samples: {max_samples}")
    print(f"{'=' * 60}")
    if dataloader is None:
        _, dataloader, dataset_counts = self.prepare_shared_fashion_mnist(max_samples=max_samples)
        expected_counts = expected_counts or dataset_counts
    elif expected_counts is None:
        raise ValueError("expected_counts must be provided when using a custom dataloader.")
    results = {}
    # --- text ---
    print("\nExtracting baseline text embeddings...")
    text_emb, _, text_hier = self.extract_baseline_embeddings_batch(dataloader, 'text', max_samples)
    self._validate_label_distribution(text_hier, expected_counts, "baseline text")
    print(f" Baseline text shape: {text_emb.shape}")
    text_metrics = compute_similarity_metrics(text_emb, text_hier)
    text_class = self.evaluate_classification_performance(
        text_emb, text_hier, "Baseline Text - Hierarchy", "Hierarchy", method="nn",
    )
    text_metrics.update(text_class)
    results['text'] = {'hierarchy': text_metrics}
    # Drop text embeddings before extracting image ones to cap peak memory.
    del text_emb
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # --- image ---
    print("\nExtracting baseline image embeddings...")
    img_emb, _, img_hier = self.extract_baseline_embeddings_batch(dataloader, 'image', max_samples)
    self._validate_label_distribution(img_hier, expected_counts, "baseline image")
    print(f" Baseline image shape: {img_emb.shape}")
    img_metrics = compute_similarity_metrics(img_emb, img_hier)
    img_class = self.evaluate_classification_performance(
        img_emb, img_hier, "Baseline Image - Hierarchy", "Hierarchy", method="nn",
    )
    img_metrics.update(img_class)
    results['image'] = {'hierarchy': img_metrics}
    del img_emb
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # Persist confusion-matrix artifacts for both modalities.
    for key in ['text', 'image']:
        fig = results[key]['hierarchy']['figure']
        fig.savefig(
            os.path.join(self.directory, f"baseline_{key}_hierarchy_confusion_matrix.png"),
            dpi=300, bbox_inches='tight',
        )
        self.save_confusion_matrix_table(
            results[key]['hierarchy']['confusion_matrix'],
            results[key]['hierarchy']['labels'],
            os.path.join(self.directory, f"baseline_{key}_hierarchy_confusion_matrix.csv"),
        )
        plt.close(fig)
    return results
# ==================================================================
# 5. Generic dataset evaluation (KAGL Marqo / Internal)
# ==================================================================
def evaluate_gap_clip_generic(self, dataloader, dataset_name, max_samples=10000):
    """Evaluate GAP-CLIP color + hierarchy performance on any dataset.

    For both text and image modalities, classifies with the colour slice
    (dims 0..color_emb_dim-1) and the hierarchy slice; for images also
    compares the 64D slice against the full 512D embedding and keeps the
    better one. Confusion-matrix PNGs/CSVs are written to
    ``self.directory`` with a dataset-name prefix.

    Returns:
        dict with 'text_color', 'text_hierarchy', 'image_color',
        'image_hierarchy' metric dicts.
    """
    print(f"\n{'=' * 60}")
    print(f"Evaluating GAP-CLIP on {dataset_name} (Color + Hierarchy)")
    print(f" Color (dims 0-{self.color_emb_dim - 1}) | "
          f"Hierarchy (dims {self.color_emb_dim}-{self.hierarchy_end_dim - 1})")
    print(f"{'=' * 60}")
    results = {}
    # --- text ---
    print("\nExtracting GAP-CLIP text embeddings...")
    text_full, text_colors, text_hier = self.extract_full_embeddings(dataloader, 'text', max_samples)
    print(f" Text shape: {text_full.shape}")
    # text color (specialised slice only)
    text_color_spec = text_full[:, :self.color_emb_dim]
    text_color_metrics = compute_similarity_metrics(text_color_spec, text_colors)
    text_color_class = self.evaluate_classification_performance(
        text_color_spec, text_colors,
        f"GAP-CLIP Text Color – {dataset_name}", "Color", method="nn",
    )
    text_color_metrics.update(text_color_class)
    results['text_color'] = text_color_metrics
    # text hierarchy (specialised slice only)
    text_hier_spec = text_full[:, self.color_emb_dim:self.hierarchy_end_dim]
    text_hier_metrics = compute_similarity_metrics(text_hier_spec, text_hier)
    text_hier_class = self.evaluate_classification_performance(
        text_hier_spec, text_hier,
        f"GAP-CLIP Text Hierarchy – {dataset_name}", "Hierarchy", method="nn",
    )
    text_hier_metrics.update(text_hier_class)
    results['text_hierarchy'] = text_hier_metrics
    # --- image ---
    print("\nExtracting GAP-CLIP image embeddings...")
    img_full, img_colors, img_hier = self.extract_full_embeddings(dataloader, 'image', max_samples)
    # image color
    img_color_spec = img_full[:, :self.color_emb_dim]
    img_color_metrics = compute_similarity_metrics(img_color_spec, img_colors)
    img_color_class = self.evaluate_classification_performance(
        img_color_spec, img_colors,
        f"GAP-CLIP Image Color – {dataset_name}", "Color", method="nn",
    )
    img_color_metrics.update(img_color_class)
    results['image_color'] = img_color_metrics
    # image hierarchy (best of 64D vs 512D)
    img_hier_spec = img_full[:, self.color_emb_dim:self.hierarchy_end_dim]
    spec_metrics = compute_similarity_metrics(img_hier_spec, img_hier)
    spec_class = self.evaluate_classification_performance(
        img_hier_spec, img_hier,
        f"GAP-CLIP Image Hierarchy (64D) – {dataset_name}", "Hierarchy", method="nn",
    )
    full_metrics = compute_similarity_metrics(img_full, img_hier)
    full_class = self.evaluate_classification_performance(
        img_full, img_hier,
        f"GAP-CLIP Image Hierarchy (512D) – {dataset_name}", "Hierarchy", method="nn",
    )
    if full_class['accuracy'] >= spec_class['accuracy']:
        print(f" 512D wins: {full_class['accuracy']*100:.1f}% vs {spec_class['accuracy']*100:.1f}%")
        img_hier_metrics, img_hier_class = full_metrics, full_class
    else:
        print(f" 64D wins: {spec_class['accuracy']*100:.1f}% vs {full_class['accuracy']*100:.1f}%")
        img_hier_metrics, img_hier_class = spec_metrics, spec_class
    img_hier_metrics.update(img_hier_class)
    results['image_hierarchy'] = img_hier_metrics
    # --- save confusion matrices ---
    prefix = dataset_name.lower().replace(" ", "_")
    for key in ['text_color', 'image_color', 'text_hierarchy', 'image_hierarchy']:
        fig = results[key]['figure']
        fig.savefig(
            os.path.join(self.directory, f"gap_clip_{prefix}_{key}_confusion_matrix.png"),
            dpi=300, bbox_inches='tight',
        )
        self.save_confusion_matrix_table(
            results[key]['confusion_matrix'], results[key]['labels'],
            os.path.join(self.directory, f"gap_clip_{prefix}_{key}_confusion_matrix.csv"),
        )
        plt.close(fig)
    # Free the large embedding arrays before the next evaluation stage.
    del text_full, img_full
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return results
def evaluate_baseline_generic(self, dataloader, dataset_name, max_samples=10000):
    """Evaluate baseline Fashion-CLIP color + hierarchy on any dataset.

    The baseline has no specialised slices, so its full embedding is used
    for both colour and hierarchy NN classification, per modality.
    Confusion-matrix PNGs/CSVs are written to ``self.directory`` with a
    dataset-name prefix.

    Returns:
        dict with 'text' and 'image' keys, each holding 'color' and
        'hierarchy' metric dicts.
    """
    print(f"\n{'=' * 60}")
    print(f"Evaluating Baseline Fashion-CLIP on {dataset_name} (Color + Hierarchy)")
    print(f"{'=' * 60}")
    results = {}
    # --- text ---
    print("\nExtracting baseline text embeddings...")
    text_emb, text_colors, text_hier = self.extract_baseline_embeddings_batch(dataloader, 'text', max_samples)
    print(f" Baseline text shape: {text_emb.shape}")
    text_color_metrics = compute_similarity_metrics(text_emb, text_colors)
    text_color_class = self.evaluate_classification_performance(
        text_emb, text_colors,
        f"Baseline Text Color – {dataset_name}", "Color", method="nn",
    )
    text_color_metrics.update(text_color_class)
    text_hier_metrics = compute_similarity_metrics(text_emb, text_hier)
    text_hier_class = self.evaluate_classification_performance(
        text_emb, text_hier,
        f"Baseline Text Hierarchy – {dataset_name}", "Hierarchy", method="nn",
    )
    text_hier_metrics.update(text_hier_class)
    results['text'] = {'color': text_color_metrics, 'hierarchy': text_hier_metrics}
    # Drop text embeddings before extracting image ones to cap peak memory.
    del text_emb
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # --- image ---
    print("\nExtracting baseline image embeddings...")
    img_emb, img_colors, img_hier = self.extract_baseline_embeddings_batch(dataloader, 'image', max_samples)
    print(f" Baseline image shape: {img_emb.shape}")
    img_color_metrics = compute_similarity_metrics(img_emb, img_colors)
    img_color_class = self.evaluate_classification_performance(
        img_emb, img_colors,
        f"Baseline Image Color – {dataset_name}", "Color", method="nn",
    )
    img_color_metrics.update(img_color_class)
    img_hier_metrics = compute_similarity_metrics(img_emb, img_hier)
    img_hier_class = self.evaluate_classification_performance(
        img_emb, img_hier,
        f"Baseline Image Hierarchy – {dataset_name}", "Hierarchy", method="nn",
    )
    img_hier_metrics.update(img_hier_class)
    results['image'] = {'color': img_color_metrics, 'hierarchy': img_hier_metrics}
    del img_emb
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # Persist confusion-matrix artifacts for every modality/attribute pair.
    prefix = dataset_name.lower().replace(" ", "_")
    for key in ['text', 'image']:
        for attr in ['color', 'hierarchy']:
            fig = results[key][attr]['figure']
            fig.savefig(
                os.path.join(self.directory, f"baseline_{prefix}_{key}_{attr}_confusion_matrix.png"),
                dpi=300, bbox_inches='tight',
            )
            self.save_confusion_matrix_table(
                results[key][attr]['confusion_matrix'],
                results[key][attr]['labels'],
                os.path.join(self.directory, f"baseline_{prefix}_{key}_{attr}_confusion_matrix.csv"),
            )
            plt.close(fig)
    return results
# ==================================================================
# 6. Full evaluation across all datasets
# ==================================================================
def run_full_evaluation(self, max_samples=10000, batch_size=8):
"""Run color + hierarchy evaluation on all 3 datasets for both models."""
all_results = {}
# --- Fashion-MNIST ---
shared_dataset, shared_dataloader, shared_counts = self.prepare_shared_fashion_mnist(
max_samples=max_samples, batch_size=batch_size,
)
all_results['fashion_mnist_gap'] = self.evaluate_gap_clip_fashion_mnist(
max_samples=max_samples, dataloader=shared_dataloader, expected_counts=shared_counts,
)
all_results['fashion_mnist_baseline'] = self.evaluate_baseline_fashion_mnist(
max_samples=max_samples, dataloader=shared_dataloader, expected_counts=shared_counts,
)
# --- KAGL Marqo ---
try:
kaggle_dataset = load_kaggle_marqo_with_hierarchy(
max_samples=max_samples,
hierarchy_classes=self.validation_hierarchy_classes or self.hierarchy_classes,
raw_df=self.kaggle_raw_df,
)
if kaggle_dataset is not None and len(kaggle_dataset) > 0:
kaggle_dataloader = DataLoader(kaggle_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
all_results['kaggle_gap'] = self.evaluate_gap_clip_generic(
kaggle_dataloader, "KAGL Marqo", max_samples,
)
all_results['kaggle_baseline'] = self.evaluate_baseline_generic(
kaggle_dataloader, "KAGL Marqo", max_samples,
)
else:
print("WARNING: KAGL Marqo dataset empty after hierarchy mapping, skipping.")
except Exception as e:
print(f"WARNING: Could not evaluate on KAGL Marqo: {e}")
# --- Internal (local validation) ---
try:
local_dataset = load_local_validation_with_hierarchy(
max_samples=max_samples,
hierarchy_classes=self.validation_hierarchy_classes or self.hierarchy_classes,
raw_df=self.local_raw_df,
)
if local_dataset is not None and len(local_dataset) > 0:
local_dataloader = DataLoader(local_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
all_results['local_gap'] = self.evaluate_gap_clip_generic(
local_dataloader, "Internal", max_samples,
)
all_results['local_baseline'] = self.evaluate_baseline_generic(
local_dataloader, "Internal", max_samples,
)
else:
print("WARNING: Local validation dataset empty after hierarchy filtering, skipping.")
except Exception as e:
print(f"WARNING: Could not evaluate on internal dataset: {e}")
# --- Print summary ---
print(f"\n{'=' * 70}")
print("COLOR + HIERARCHY NN ACCURACY EVALUATION SUMMARY (Table 3)")
print(f"{'=' * 70}")
for dataset_key, label in [
('fashion_mnist_gap', 'Fashion-MNIST (GAP-CLIP)'),
('fashion_mnist_baseline', 'Fashion-MNIST (Baseline)'),
('kaggle_gap', 'KAGL Marqo (GAP-CLIP)'),
('kaggle_baseline', 'KAGL Marqo (Baseline)'),
('local_gap', 'Internal (GAP-CLIP)'),
('local_baseline', 'Internal (Baseline)'),
]:
if dataset_key not in all_results:
continue
res = all_results[dataset_key]
print(f"\n{label}:")
# GAP-CLIP format
if 'text_color' in res:
tc = res['text_color']
ic = res['image_color']
print(f" Color – Text NN: {tc['accuracy']*100:.1f}% | Image NN: {ic['accuracy']*100:.1f}%")
if 'text_hierarchy' in res:
th = res['text_hierarchy']
ih = res['image_hierarchy']
print(f" Hierarchy – Text NN: {th['accuracy']*100:.1f}% | Image NN: {ih['accuracy']*100:.1f}%")
if 'ensemble_accuracy' in ih:
print(f" Hierarchy – Image Ensemble: {ih['ensemble_accuracy']*100:.1f}%")
# Baseline format
if 'text' in res and isinstance(res['text'], dict):
t = res['text']
i = res['image']
if 'color' in t:
print(f" Color – Text NN: {t['color']['accuracy']*100:.1f}% | Image NN: {i['color']['accuracy']*100:.1f}%")
if 'hierarchy' in t:
print(f" Hierarchy – Text NN: {t['hierarchy']['accuracy']*100:.1f}% | Image NN: {i['hierarchy']['accuracy']*100:.1f}%")
return all_results
# ============================================================================
# 7. Main
# ============================================================================
if __name__ == "__main__":
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")
directory = 'main_model_analysis'
max_samples = 10000
evaluator = ColorHierarchyEvaluator(device=device, directory=directory)
evaluator.run_full_evaluation(max_samples=max_samples, batch_size=8)