| """ |
| Section 5.2 — Category Model Evaluation (Table 2) |
| ================================================== |
| |
| Evaluates GAP-CLIP vs the Fashion-CLIP baseline on hierarchy (category) |
| classification using three datasets: |
| - Fashion-MNIST (10 categories) |
| - KAGL Marqo (external, real-world fashion e-commerce) |
| - Internal validation dataset |
| |
| Produces hierarchy confusion matrices (text + image) for both models on each |
| dataset. |
| |
| Metrics match Table 2 in the paper: |
| - Text/image embedding NN accuracy |
| - Text/image embedding separation score |
| |
| Run directly: |
| python sec52_category_model_eval.py |
| |
| Paper reference: Section 5.2, Table 2. |
| """ |
|
|
| import os |
| os.environ["TOKENIZERS_PARALLELISM"] = "false" |
|
|
| import torch |
| import pandas as pd |
| import numpy as np |
| import matplotlib.pyplot as plt |
| import difflib |
| from collections import defaultdict |
|
|
| from sklearn.metrics.pairwise import cosine_similarity |
| from sklearn.metrics import classification_report, accuracy_score |
| from sklearn.preprocessing import normalize |
|
|
| from torch.utils.data import Dataset, DataLoader |
| from torchvision import transforms |
| from PIL import Image |
| from io import BytesIO |
|
|
| import warnings |
| warnings.filterwarnings('ignore') |
|
|
| from config import ( |
| ROOT_DIR, |
| main_model_path, |
| main_emb_dim, |
| hierarchy_model_path, |
| color_emb_dim, |
| hierarchy_emb_dim, |
| local_dataset_path, |
| column_local_image_path, |
| ) |
|
|
| from utils.datasets import ( |
| load_fashion_mnist_dataset, |
| ) |
| from utils.embeddings import extract_clip_embeddings |
| from utils.metrics import ( |
| compute_similarity_metrics, |
| compute_embedding_accuracy, |
| compute_centroid_accuracy, |
| predict_labels_from_embeddings, |
| create_confusion_matrix, |
| ) |
| from utils.model_loader import load_gap_clip, load_baseline_fashion_clip |
|
|
|
|
| |
| |
| |
|
|
class KaggleHierarchyDataset(Dataset):
    """KAGL Marqo dataset returning (image, description, color, hierarchy)."""

    def __init__(self, dataframe, image_size=224):
        self.dataframe = dataframe.reset_index(drop=True)
        # Standard ImageNet-style preprocessing: resize, tensorize, normalize.
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        sample = self.dataframe.iloc[idx]
        raw = sample["image"]
        # The "image" column may hold a {"bytes": ...} dict, an already-decoded
        # PIL image, or raw encoded bytes — handle all three.
        if isinstance(raw, dict) and "bytes" in raw:
            pil_img = Image.open(BytesIO(raw["bytes"])).convert("RGB")
        elif hasattr(raw, "convert"):
            pil_img = raw.convert("RGB")
        else:
            pil_img = Image.open(BytesIO(raw)).convert("RGB")
        return (
            self.transform(pil_img),
            str(sample["text"]),
            str(sample.get("baseColour", "unknown")).lower(),
            str(sample["hierarchy"]),
        )
|
|
|
|
def load_kaggle_marqo_with_hierarchy(max_samples=10000, hierarchy_classes=None, raw_df=None):
    """Load KAGL Marqo dataset with hierarchy labels derived from articleType.

    Args:
        max_samples: Upper bound on rows kept (random sample, random_state=42).
        hierarchy_classes: Optional target class names. KAGL labels are mapped
            onto these via exact match, then substring match (either direction),
            then difflib fuzzy match (cutoff 0.6); unmatched rows are dropped.
        raw_df: Pre-downloaded DataFrame to skip the HuggingFace download.

    Returns:
        A KaggleHierarchyDataset, or None if the hierarchy column is missing.
    """
    if raw_df is not None:
        df = raw_df.copy()
        print(f"Using cached KAGL DataFrame for hierarchy evaluation: {len(df)} samples")
    else:
        from datasets import load_dataset
        print("Loading KAGL Marqo dataset for hierarchy evaluation...")
        dataset = load_dataset("Marqo/KAGL")
        df = dataset["data"].to_pandas()
        print(f"Dataset loaded: {len(df)} samples, columns: {list(df.columns)}")

    hierarchy_col = 'category2'

    # Guard against schema drift: bail out if the expected column is absent.
    # (The previous `is None` check was dead code — the name is hard-coded,
    # so the "missing column" warning could never fire.)
    if hierarchy_col not in df.columns:
        print("WARNING: No hierarchy column found in KAGL dataset")
        return None

    print(f"Using '{hierarchy_col}' as hierarchy source")
    df = df.dropna(subset=["text", "image", hierarchy_col])
    df["hierarchy"] = df[hierarchy_col].astype(str).str.strip()

    if hierarchy_classes:
        hierarchy_classes_lower = [h.lower() for h in hierarchy_classes]

        def _match(kagl_type):
            """Map one lowercased KAGL label to a target class name, or None."""
            # 1) exact match
            if kagl_type in hierarchy_classes_lower:
                return hierarchy_classes[hierarchy_classes_lower.index(kagl_type)]
            # 2) substring match in either direction
            for h_class in hierarchy_classes:
                h_lower = h_class.lower()
                if h_lower in kagl_type or kagl_type in h_lower:
                    return h_class
            # 3) fuzzy match as a last resort
            close = difflib.get_close_matches(kagl_type, hierarchy_classes_lower, n=1, cutoff=0.6)
            if close:
                return hierarchy_classes[hierarchy_classes_lower.index(close[0])]
            return None

        # Resolve each unique label once instead of per row — fuzzy matching
        # is expensive and labels repeat heavily in this dataset.
        lowered = df["hierarchy"].str.lower()
        unique_map = {value: _match(value) for value in lowered.unique()}
        df["hierarchy"] = lowered.map(unique_map)
        df = df.dropna(subset=["hierarchy"])
        print(f"After hierarchy mapping: {len(df)} samples")

    if len(df) > max_samples:
        df = df.sample(n=max_samples, random_state=42)

    print(f"Using {len(df)} samples, {df['hierarchy'].nunique()} hierarchy classes: "
          f"{sorted(df['hierarchy'].unique())}")
    return KaggleHierarchyDataset(df)
|
|
|
|
| |
| |
| |
|
|
class LocalHierarchyDataset(Dataset):
    """Local validation dataset returning (image, description, color, hierarchy)."""

    def __init__(self, dataframe, image_size=224):
        self.dataframe = dataframe.reset_index(drop=True)
        # Standard ImageNet-style preprocessing pipeline.
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        sample = self.dataframe.iloc[idx]
        try:
            path = sample[column_local_image_path]
            # Relative paths are resolved against the project root.
            if not os.path.isabs(path):
                path = os.path.join(ROOT_DIR, path)
            pil_img = Image.open(path).convert("RGB")
        except Exception:
            # Missing/unreadable files fall back to a neutral placeholder.
            pil_img = Image.new("RGB", (224, 224), color="gray")
        return (
            self.transform(pil_img),
            str(sample["text"]),
            str(sample.get("color", "unknown")),
            str(sample["hierarchy"]),
        )
|
|
|
|
def load_local_validation_with_hierarchy(max_samples=10000, hierarchy_classes=None, raw_df=None):
    """Load internal validation dataset with hierarchy labels.

    Args:
        max_samples: Cap on rows kept (random sample, random_state=42).
        hierarchy_classes: Optional whitelist; rows are filtered to these
            classes (case-insensitive) and re-labelled with canonical casing.
        raw_df: Pre-loaded DataFrame to skip CSV read.
    """
    if raw_df is None:
        print("Loading local validation dataset for hierarchy evaluation...")
        df = pd.read_csv(local_dataset_path)
        print(f"Dataset loaded: {len(df)} samples")
    else:
        df = raw_df.copy()
        print(f"Using cached local DataFrame for hierarchy evaluation: {len(df)} samples")

    # Keep only rows with an image path and a non-empty hierarchy label.
    df = df.dropna(subset=[column_local_image_path, "hierarchy"])
    df["hierarchy"] = df["hierarchy"].astype(str).str.strip()
    df = df[df["hierarchy"].str.len() > 0]

    if hierarchy_classes:
        lowered = [h.lower() for h in hierarchy_classes]
        df["hierarchy_lower"] = df["hierarchy"].str.lower()
        df = df[df["hierarchy_lower"].isin(lowered)]
        # Restore the canonical casing of the provided class names.
        canonical = {h.lower(): h for h in hierarchy_classes}
        df["hierarchy"] = df["hierarchy_lower"].map(canonical)
        df = df.drop(columns=["hierarchy_lower"])

    print(f"After filtering: {len(df)} samples, {df['hierarchy'].nunique()} classes")

    if len(df) > max_samples:
        df = df.sample(n=max_samples, random_state=42)

    print(f"Using {len(df)} samples, classes: {sorted(df['hierarchy'].unique())}")
    return LocalHierarchyDataset(df)
|
|
|
|
| |
| |
| |
|
|
class CategoryModelEvaluator:
    """
    Produces hierarchy confusion matrices for GAP-CLIP and the
    baseline Fashion-CLIP on Fashion-MNIST, KAGL Marqo, and internal datasets.

    All outputs (confusion-matrix PNGs and audit CSVs) are written into
    ``directory``. Per model/dataset, the reported numbers are 1-NN accuracy
    and the separation score from ``compute_similarity_metrics``
    (imported helper — exact definition lives in utils.metrics).
    """

    def __init__(self, device='mps', directory='gap_clip_confusion_matrices',
                 gap_clip_model=None, gap_clip_processor=None,
                 baseline_model=None, baseline_processor=None,
                 hierarchy_classes=None,
                 kaggle_raw_df=None, local_raw_df=None):
        """Resolve device, output directory, class list, and both models.

        Args:
            device: torch.device or device string (default Apple 'mps').
            directory: Output directory for confusion-matrix artifacts
                (created if missing).
            gap_clip_model / gap_clip_processor: Pre-loaded GAP-CLIP pair;
                loaded from ``main_model_path`` when either is None.
            baseline_model / baseline_processor: Pre-loaded Fashion-CLIP
                pair; loaded via ``load_baseline_fashion_clip`` when absent.
            hierarchy_classes: Explicit hierarchy class list; otherwise read
                from the checkpoint at ``hierarchy_model_path``.
            kaggle_raw_df / local_raw_df: Optional cached DataFrames that let
                ``run_full_evaluation`` skip dataset download / CSV read.
        """
        self.device = torch.device(device) if isinstance(device, str) else device
        self.directory = directory
        self.kaggle_raw_df = kaggle_raw_df
        self.local_raw_df = local_raw_df
        # Embedding layout constants from config: the hierarchy subspace is
        # the slice [color_emb_dim, color_emb_dim + hierarchy_emb_dim) of the
        # full main_emb_dim-dimensional embedding.
        self.color_emb_dim = color_emb_dim
        self.hierarchy_emb_dim = hierarchy_emb_dim
        self.main_emb_dim = main_emb_dim
        self.hierarchy_end_dim = self.color_emb_dim + self.hierarchy_emb_dim
        os.makedirs(self.directory, exist_ok=True)

        # Hierarchy classes: caller-provided, else recovered from the
        # hierarchy model checkpoint.
        if hierarchy_classes is not None:
            self.hierarchy_classes = hierarchy_classes
            print(f"Using provided hierarchy classes: {len(self.hierarchy_classes)} classes")
        else:
            print("Loading hierarchy classes from hierarchy model...")
            if not os.path.exists(hierarchy_model_path):
                raise FileNotFoundError(f"Hierarchy model file {hierarchy_model_path} not found")
            hierarchy_checkpoint = torch.load(hierarchy_model_path, map_location=self.device)
            self.hierarchy_classes = hierarchy_checkpoint.get('hierarchy_classes', [])
            print(f"Found {len(self.hierarchy_classes)} hierarchy classes: {sorted(self.hierarchy_classes)}")

        # Prefer the classes actually present in the validation CSV; fall
        # back to the model's classes if the CSV is unavailable.
        self.validation_hierarchy_classes = self._load_validation_hierarchy_classes()
        if self.validation_hierarchy_classes:
            print(f"Validation dataset hierarchies ({len(self.validation_hierarchy_classes)} classes): "
                  f"{sorted(self.validation_hierarchy_classes)}")
        else:
            print("Unable to load validation hierarchy classes, falling back to hierarchy model classes.")
            self.validation_hierarchy_classes = self.hierarchy_classes

        # GAP-CLIP: reuse injected model/processor or load from disk.
        if gap_clip_model is not None and gap_clip_processor is not None:
            self.model = gap_clip_model
            self.processor = gap_clip_processor
            print("Using pre-loaded GAP-CLIP model")
        else:
            self.model, self.processor = load_gap_clip(main_model_path, self.device)
            print("GAP-CLIP model loaded successfully")

        # Baseline Fashion-CLIP: same injection-or-load pattern.
        if baseline_model is not None and baseline_processor is not None:
            self.baseline_model = baseline_model
            self.baseline_processor = baseline_processor
            print("Using pre-loaded baseline Fashion-CLIP model")
        else:
            self.baseline_model, self.baseline_processor = load_baseline_fashion_clip(self.device)
            print("Baseline Fashion-CLIP model loaded successfully")

    def _load_validation_hierarchy_classes(self):
        """Return sorted unique non-empty hierarchy labels from the validation
        CSV, or [] if the file is missing/unreadable or lacks the column."""
        if not os.path.exists(local_dataset_path):
            print(f"Validation dataset not found at {local_dataset_path}")
            return []
        try:
            df = pd.read_csv(local_dataset_path)
        except Exception as exc:
            print(f"Failed to read validation dataset: {exc}")
            return []
        if 'hierarchy' not in df.columns:
            print("Validation dataset does not contain 'hierarchy' column.")
            return []
        hierarchies = df['hierarchy'].dropna().astype(str).str.strip()
        hierarchies = [h for h in hierarchies if h]
        return sorted(set(hierarchies))

    def prepare_shared_fashion_mnist(self, max_samples=10000, batch_size=8):
        """
        Build one shared Fashion-MNIST dataset/dataloader to ensure every model
        is evaluated on the exact same items.

        Returns:
            (dataset, dataloader, hierarchy_counts) where hierarchy_counts maps
            hierarchy name -> sample count, used later to validate that each
            extraction pass saw the identical label distribution.
        """
        target_classes = self.validation_hierarchy_classes or self.hierarchy_classes
        fashion_dataset = load_fashion_mnist_dataset(max_samples, hierarchy_classes=target_classes)
        # shuffle=False so every evaluation pass sees items in the same order.
        dataloader = DataLoader(fashion_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

        # Count samples per hierarchy by mapping numeric labels through the
        # dataset's label_mapping (assumed dict label-id -> hierarchy name —
        # TODO confirm against utils.datasets).
        hierarchy_counts = defaultdict(int)
        if len(fashion_dataset.dataframe) > 0 and fashion_dataset.label_mapping:
            for _, row in fashion_dataset.dataframe.iterrows():
                lid = int(row['label'])
                hierarchy_counts[fashion_dataset.label_mapping.get(lid, 'unknown')] += 1

        return fashion_dataset, dataloader, dict(hierarchy_counts)

    @staticmethod
    def _count_labels(labels):
        """Return {label: occurrence count} for an iterable of labels."""
        counts = defaultdict(int)
        for label in labels:
            counts[label] += 1
        return dict(counts)

    def _validate_label_distribution(self, labels, expected_counts, context):
        """Raise ValueError if the observed label counts differ from
        ``expected_counts`` — guards against dataloader drift between runs."""
        observed = self._count_labels(labels)
        if observed != expected_counts:
            raise ValueError(
                f"Label distribution mismatch in {context}. "
                f"Expected {expected_counts}, observed {observed}"
            )

    def extract_full_embeddings(self, dataloader, embedding_type='text', max_samples=10000):
        """Full 512D embeddings from GAP-CLIP (text or image).

        Delegates to ``extract_clip_embeddings``; returns its 3-tuple —
        callers here use (embeddings, <unused middle value>, hierarchy labels).
        NOTE(review): the middle element is discarded everywhere in this class;
        presumably color labels — confirm against utils.embeddings.
        """
        return extract_clip_embeddings(
            self.model, self.processor, dataloader, self.device,
            embedding_type=embedding_type, max_samples=max_samples,
            desc=f"GAP-CLIP {embedding_type} embeddings",
        )

    def extract_baseline_embeddings_batch(self, dataloader, embedding_type='text', max_samples=10000):
        """L2-normalised embeddings from baseline Fashion-CLIP.

        Same return convention as :meth:`extract_full_embeddings`.
        """
        return extract_clip_embeddings(
            self.baseline_model, self.baseline_processor, dataloader, self.device,
            embedding_type=embedding_type, max_samples=max_samples,
            desc=f"Baseline {embedding_type} embeddings",
        )

    def predict_labels_nearest_neighbor(self, embeddings, labels):
        """
        Predict labels using 1-NN on the same embedding set.
        This matches the accuracy logic used in the evaluation pipeline.

        Builds the full N×N cosine-similarity matrix (O(N^2) memory),
        masks the diagonal so an item cannot match itself, and takes each
        item's single nearest neighbour's label.
        """
        similarities = cosine_similarity(embeddings)
        preds = []
        for i in range(len(embeddings)):
            sims = similarities[i].copy()
            sims[i] = -1.0  # exclude self-match
            nearest_neighbor_idx = int(np.argmax(sims))
            preds.append(labels[nearest_neighbor_idx])
        return preds

    def _compute_img_centroids(self, embeddings, labels):
        """Return {label: L2-normalised mean embedding of that label's items}."""
        emb_norm = normalize(embeddings, norm='l2')
        centroids = {}
        for label in sorted(set(labels)):
            idx = [i for i, l in enumerate(labels) if l == label]
            centroids[label] = normalize([emb_norm[idx].mean(axis=0)], norm='l2')[0]
        return centroids

    def predict_labels_image_ensemble(self, img_embeddings, labels,
                                      text_protos, cls_names, alpha=0.5):
        """Combine image centroids (512D) with text prototypes (512D).

        Per item: score(class) = alpha * cos(item, image centroid)
                               + (1 - alpha) * cos(item, text prototype);
        predicts the argmax class. ``text_protos`` rows must align with
        ``cls_names`` order.
        """
        img_norm = normalize(img_embeddings, norm='l2')
        img_centroids = self._compute_img_centroids(img_norm, labels)
        centroid_mat = np.stack([img_centroids[c] for c in cls_names], axis=0)

        preds = []
        for i in range(len(img_norm)):
            v = img_norm[i:i + 1]
            sim_img = cosine_similarity(v, centroid_mat)[0]
            sim_txt = cosine_similarity(v, text_protos)[0]
            scores = alpha * sim_img + (1 - alpha) * sim_txt
            preds.append(cls_names[int(np.argmax(scores))])
        return preds

    def evaluate_classification_performance(self, embeddings, labels,
                                            embedding_type="Embeddings",
                                            label_type="Label",
                                            method="nn"):
        """Classify embeddings (1-NN or centroid) and package accuracy,
        predictions, confusion matrix (+figure), and a per-class report.

        Raises:
            ValueError: for an unrecognised ``method``.
        """
        if method == "nn":
            preds = self.predict_labels_nearest_neighbor(embeddings, labels)
        elif method == "centroid":
            preds = predict_labels_from_embeddings(embeddings, labels)
        else:
            raise ValueError(f"Unknown classification method: {method}")
        acc = accuracy_score(labels, preds)
        unique_labels = sorted(set(labels))
        fig, _, cm = create_confusion_matrix(
            labels, preds,
            f"{embedding_type} - {label_type} Classification ({method.upper()})",
            label_type,
        )
        report = classification_report(labels, preds, labels=unique_labels,
                                       target_names=unique_labels, output_dict=True)
        return {
            'accuracy': acc,
            'predictions': preds,
            'confusion_matrix': cm,
            'labels': unique_labels,
            'classification_report': report,
            'figure': fig,
        }

    def save_confusion_matrix_table(self, cm, labels, output_csv_path):
        """
        Save confusion matrix values with per-row totals to CSV for auditing.

        Appends a "row_total" column and a "column_total" row (whose last cell
        is the grand total).
        """
        cm_df = pd.DataFrame(cm, index=labels, columns=labels)
        cm_df["row_total"] = cm_df.sum(axis=1)
        cm_df.loc["column_total"] = list(cm_df[labels].sum(axis=0)) + [cm_df["row_total"].sum()]
        cm_df.to_csv(output_csv_path)

    def evaluate_gap_clip_fashion_mnist(self, max_samples=10000, dataloader=None, expected_counts=None):
        """Evaluate GAP-CLIP on Fashion-MNIST: text/image hierarchy metrics,
        64D-vs-512D image comparison, image+text ensemble, and saved
        confusion-matrix artifacts.

        Args:
            dataloader: Optional shared dataloader (see
                :meth:`prepare_shared_fashion_mnist`); when given,
                ``expected_counts`` is mandatory so label drift can be caught.

        Returns:
            dict with 'text_hierarchy' and 'image_hierarchy' metric dicts.
        """
        print(f"\n{'=' * 60}")
        print("Evaluating GAP-CLIP on Fashion-MNIST")
        # "dims 16-79" reflects the default config (color 16D + hierarchy 64D);
        # the actual slice below uses the configured dims.
        print(" Hierarchy embeddings (dims 16-79)")
        print(f" Max samples: {max_samples}")
        print(f"{'=' * 60}")

        if dataloader is None:
            fashion_dataset, dataloader, dataset_counts = self.prepare_shared_fashion_mnist(max_samples=max_samples)
            expected_counts = expected_counts or dataset_counts
        else:
            fashion_dataset = getattr(dataloader, "dataset", None)
            if expected_counts is None:
                raise ValueError("expected_counts must be provided when using a custom dataloader.")

        if fashion_dataset is not None and len(fashion_dataset.dataframe) > 0 and fashion_dataset.label_mapping:
            print(f"\nHierarchy distribution in dataset:")
            for h in sorted(expected_counts):
                print(f" {h}: {expected_counts[h]} samples")

        results = {}

        # One extraction pass per modality; verify both saw the expected
        # label distribution before computing metrics.
        print("\nExtracting full 512-dimensional GAP-CLIP embeddings...")
        text_full, _, text_hier = self.extract_full_embeddings(dataloader, 'text', max_samples)
        img_full, _, img_hier = self.extract_full_embeddings(dataloader, 'image', max_samples)
        self._validate_label_distribution(text_hier, expected_counts, "GAP-CLIP text")
        self._validate_label_distribution(img_hier, expected_counts, "GAP-CLIP image")
        print(f" Text shape: {text_full.shape} | Image shape: {img_full.shape}")

        # --- Text: evaluate only the hierarchy subspace. ---
        print("\n--- GAP-CLIP TEXT HIERARCHY (dims 16-79) ---")
        text_hier_spec = text_full[:, self.color_emb_dim:self.color_emb_dim + self.hierarchy_emb_dim]
        print(f" Specialized text hierarchy shape: {text_hier_spec.shape}")

        text_metrics = compute_similarity_metrics(text_hier_spec, text_hier)
        text_class = self.evaluate_classification_performance(
            text_hier_spec, text_hier,
            "GAP-CLIP Text Hierarchy (64D)", "Hierarchy",
            method="nn",
        )
        text_metrics.update(text_class)
        results['text_hierarchy'] = text_metrics

        # --- Image: compare hierarchy subspace vs full embedding and
        # report whichever classifies better. ---
        print("\n--- GAP-CLIP IMAGE HIERARCHY (64D vs 512D) ---")
        img_hier_spec = img_full[:, self.color_emb_dim:self.color_emb_dim + self.hierarchy_emb_dim]
        print(f" Specialized image hierarchy shape: {img_hier_spec.shape}")

        print(" Testing specialized 64D...")
        spec_metrics = compute_similarity_metrics(img_hier_spec, img_hier)
        spec_class = self.evaluate_classification_performance(
            img_hier_spec, img_hier,
            "GAP-CLIP Image Hierarchy (64D)", "Hierarchy",
            method="nn",
        )

        print(" Testing full 512D...")
        full_metrics = compute_similarity_metrics(img_full, img_hier)
        full_class = self.evaluate_classification_performance(
            img_full, img_hier,
            "GAP-CLIP Image Hierarchy (512D full)", "Hierarchy",
            method="nn",
        )

        # Ties go to the 512D variant.
        if full_class['accuracy'] >= spec_class['accuracy']:
            print(f" 512D wins: {full_class['accuracy'] * 100:.1f}% vs {spec_class['accuracy'] * 100:.1f}%")
            img_metrics, img_class = full_metrics, full_class
        else:
            print(f" 64D wins: {spec_class['accuracy'] * 100:.1f}% vs {full_class['accuracy'] * 100:.1f}%")
            img_metrics, img_class = spec_metrics, spec_class

        # Extra diagnostic: zero-shot text prototypes ("a photo of a <class>")
        # blended with image centroids.
        print("\n Testing GAP-CLIP image + text ensemble (prototypes per class)...")
        cls_names = sorted(set(img_hier))
        prompts = [f"a photo of a {c}" for c in cls_names]
        text_inputs = self.processor(text=prompts, return_tensors="pt", padding=True, truncation=True)
        text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
        with torch.no_grad():
            txt_feats = self.model.get_text_features(**text_inputs)
            txt_feats = txt_feats / txt_feats.norm(dim=-1, keepdim=True)
        text_protos = txt_feats.cpu().numpy()

        ensemble_preds = self.predict_labels_image_ensemble(
            img_full, img_hier, text_protos, cls_names, alpha=0.7,
        )
        ensemble_acc = accuracy_score(img_hier, ensemble_preds)
        print(f" Ensemble accuracy (alpha=0.7): {ensemble_acc * 100:.2f}%")

        img_metrics.update(img_class)
        img_metrics['ensemble_accuracy'] = ensemble_acc
        results['image_hierarchy'] = img_metrics

        # Persist confusion matrices (PNG + audit CSV), then release figures.
        for key in ['text_hierarchy', 'image_hierarchy']:
            fig = results[key]['figure']
            fig.savefig(
                os.path.join(self.directory, f"gap_clip_{key}_confusion_matrix.png"),
                dpi=300, bbox_inches='tight',
            )
            self.save_confusion_matrix_table(
                results[key]['confusion_matrix'],
                results[key]['labels'],
                os.path.join(self.directory, f"gap_clip_{key}_confusion_matrix.csv"),
            )
            plt.close(fig)

        # Free the large embedding arrays before the next evaluation.
        del text_full, img_full, text_hier_spec, img_hier_spec
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return results

    def evaluate_baseline_fashion_mnist(self, max_samples=10000, dataloader=None, expected_counts=None):
        """Evaluate baseline Fashion-CLIP on Fashion-MNIST (full embeddings,
        no subspace slicing) and save confusion-matrix artifacts.

        Returns:
            dict of shape {'text': {'hierarchy': ...}, 'image': {'hierarchy': ...}}.
        """
        print(f"\n{'=' * 60}")
        print("Evaluating Baseline Fashion-CLIP on Fashion-MNIST")
        print(f" Max samples: {max_samples}")
        print(f"{'=' * 60}")

        if dataloader is None:
            _, dataloader, dataset_counts = self.prepare_shared_fashion_mnist(max_samples=max_samples)
            expected_counts = expected_counts or dataset_counts
        elif expected_counts is None:
            raise ValueError("expected_counts must be provided when using a custom dataloader.")

        results = {}

        # Text pass (embeddings freed before the image pass to bound memory).
        print("\nExtracting baseline text embeddings...")
        text_emb, _, text_hier = self.extract_baseline_embeddings_batch(dataloader, 'text', max_samples)
        self._validate_label_distribution(text_hier, expected_counts, "baseline text")
        print(f" Baseline text shape: {text_emb.shape}")

        text_metrics = compute_similarity_metrics(text_emb, text_hier)
        text_class = self.evaluate_classification_performance(
            text_emb, text_hier,
            "Baseline Fashion-CLIP Text - Hierarchy", "Hierarchy",
            method="nn",
        )
        text_metrics.update(text_class)
        results['text'] = {'hierarchy': text_metrics}

        del text_emb
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Image pass.
        print("\nExtracting baseline image embeddings...")
        img_emb, _, img_hier = self.extract_baseline_embeddings_batch(dataloader, 'image', max_samples)
        self._validate_label_distribution(img_hier, expected_counts, "baseline image")
        print(f" Baseline image shape: {img_emb.shape}")

        img_metrics = compute_similarity_metrics(img_emb, img_hier)
        img_class = self.evaluate_classification_performance(
            img_emb, img_hier,
            "Baseline Fashion-CLIP Image - Hierarchy", "Hierarchy",
            method="nn",
        )
        img_metrics.update(img_class)
        results['image'] = {'hierarchy': img_metrics}

        del img_emb
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Persist confusion matrices, then release figures.
        for key in ['text', 'image']:
            fig = results[key]['hierarchy']['figure']
            fig.savefig(
                os.path.join(self.directory, f"baseline_{key}_hierarchy_confusion_matrix.png"),
                dpi=300, bbox_inches='tight',
            )
            self.save_confusion_matrix_table(
                results[key]['hierarchy']['confusion_matrix'],
                results[key]['hierarchy']['labels'],
                os.path.join(self.directory, f"baseline_{key}_hierarchy_confusion_matrix.csv"),
            )
            plt.close(fig)

        return results

    def evaluate_gap_clip_generic(self, dataset_name_dataloader=None, dataset_name=None, max_samples=10000):
        """Evaluate GAP-CLIP hierarchy performance on any dataset."""
        # NOTE: signature kept as in original below.

    def evaluate_gap_clip_generic(self, dataloader, dataset_name, max_samples=10000):
        """Evaluate GAP-CLIP hierarchy performance on any dataset.

        Same 64D-vs-512D comparison as the Fashion-MNIST variant, but without
        label-distribution validation or the text-prototype ensemble. Output
        filenames are prefixed with a slug of ``dataset_name``.
        """
        print(f"\n{'=' * 60}")
        print(f"Evaluating GAP-CLIP on {dataset_name}")
        print(f" Hierarchy embeddings (dims 16-79)")
        print(f"{'=' * 60}")

        results = {}

        # Text: hierarchy subspace only.
        print("\nExtracting GAP-CLIP text embeddings...")
        text_full, _, text_hier = self.extract_full_embeddings(dataloader, 'text', max_samples)
        text_hier_spec = text_full[:, self.color_emb_dim:self.color_emb_dim + self.hierarchy_emb_dim]
        print(f" Text shape: {text_full.shape}, hierarchy subspace: {text_hier_spec.shape}")

        text_metrics = compute_similarity_metrics(text_hier_spec, text_hier)
        text_class = self.evaluate_classification_performance(
            text_hier_spec, text_hier,
            f"GAP-CLIP Text Hierarchy – {dataset_name}", "Hierarchy", method="nn",
        )
        text_metrics.update(text_class)
        results['text_hierarchy'] = text_metrics

        # Image: subspace vs full embedding; keep the better performer.
        print("\nExtracting GAP-CLIP image embeddings...")
        img_full, _, img_hier = self.extract_full_embeddings(dataloader, 'image', max_samples)
        img_hier_spec = img_full[:, self.color_emb_dim:self.color_emb_dim + self.hierarchy_emb_dim]

        spec_metrics = compute_similarity_metrics(img_hier_spec, img_hier)
        spec_class = self.evaluate_classification_performance(
            img_hier_spec, img_hier,
            f"GAP-CLIP Image Hierarchy (64D) – {dataset_name}", "Hierarchy", method="nn",
        )

        full_metrics = compute_similarity_metrics(img_full, img_hier)
        full_class = self.evaluate_classification_performance(
            img_full, img_hier,
            f"GAP-CLIP Image Hierarchy (512D) – {dataset_name}", "Hierarchy", method="nn",
        )

        # Ties go to the 512D variant.
        if full_class['accuracy'] >= spec_class['accuracy']:
            print(f" 512D wins: {full_class['accuracy']*100:.1f}% vs {spec_class['accuracy']*100:.1f}%")
            img_metrics, img_class = full_metrics, full_class
        else:
            print(f" 64D wins: {spec_class['accuracy']*100:.1f}% vs {full_class['accuracy']*100:.1f}%")
            img_metrics, img_class = spec_metrics, spec_class

        img_metrics.update(img_class)
        results['image_hierarchy'] = img_metrics

        # Persist artifacts under a dataset-name prefix, then release figures.
        prefix = dataset_name.lower().replace(" ", "_")
        for key in ['text_hierarchy', 'image_hierarchy']:
            fig = results[key]['figure']
            fig.savefig(
                os.path.join(self.directory, f"gap_clip_{prefix}_{key}_confusion_matrix.png"),
                dpi=300, bbox_inches='tight',
            )
            self.save_confusion_matrix_table(
                results[key]['confusion_matrix'], results[key]['labels'],
                os.path.join(self.directory, f"gap_clip_{prefix}_{key}_confusion_matrix.csv"),
            )
            plt.close(fig)

        del text_full, img_full, text_hier_spec, img_hier_spec
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return results

    def evaluate_baseline_generic(self, dataloader, dataset_name, max_samples=10000):
        """Evaluate baseline Fashion-CLIP hierarchy performance on any dataset.

        Full embeddings for both modalities; artifacts are saved under a
        slug of ``dataset_name``.
        """
        print(f"\n{'=' * 60}")
        print(f"Evaluating Baseline Fashion-CLIP on {dataset_name}")
        print(f"{'=' * 60}")

        results = {}

        # Text pass (freed before image pass to bound memory).
        print("\nExtracting baseline text embeddings...")
        text_emb, _, text_hier = self.extract_baseline_embeddings_batch(dataloader, 'text', max_samples)
        print(f" Baseline text shape: {text_emb.shape}")

        text_metrics = compute_similarity_metrics(text_emb, text_hier)
        text_class = self.evaluate_classification_performance(
            text_emb, text_hier,
            f"Baseline Text Hierarchy – {dataset_name}", "Hierarchy", method="nn",
        )
        text_metrics.update(text_class)
        results['text'] = {'hierarchy': text_metrics}

        del text_emb
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Image pass.
        print("\nExtracting baseline image embeddings...")
        img_emb, _, img_hier = self.extract_baseline_embeddings_batch(dataloader, 'image', max_samples)
        print(f" Baseline image shape: {img_emb.shape}")

        img_metrics = compute_similarity_metrics(img_emb, img_hier)
        img_class = self.evaluate_classification_performance(
            img_emb, img_hier,
            f"Baseline Image Hierarchy – {dataset_name}", "Hierarchy", method="nn",
        )
        img_metrics.update(img_class)
        results['image'] = {'hierarchy': img_metrics}

        del img_emb
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        prefix = dataset_name.lower().replace(" ", "_")
        for key in ['text', 'image']:
            fig = results[key]['hierarchy']['figure']
            fig.savefig(
                os.path.join(self.directory, f"baseline_{prefix}_{key}_hierarchy_confusion_matrix.png"),
                dpi=300, bbox_inches='tight',
            )
            self.save_confusion_matrix_table(
                results[key]['hierarchy']['confusion_matrix'],
                results[key]['hierarchy']['labels'],
                os.path.join(self.directory, f"baseline_{prefix}_{key}_hierarchy_confusion_matrix.csv"),
            )
            plt.close(fig)

        return results

    def run_full_evaluation(self, max_samples=10000, batch_size=8):
        """Run hierarchy evaluation on all 3 datasets for both models.

        Fashion-MNIST is mandatory (shared dataloader for both models);
        KAGL Marqo and the internal dataset are best-effort — failures are
        logged as warnings and skipped so the remaining datasets still run.
        Ends with a printed Table-2-style summary.

        Returns:
            dict keyed by '<dataset>_<model>' with each evaluation's results.
        """
        all_results = {}

        # Fashion-MNIST: one shared dataset so both models see identical items.
        shared_dataset, shared_dataloader, shared_counts = self.prepare_shared_fashion_mnist(
            max_samples=max_samples, batch_size=batch_size,
        )
        all_results['fashion_mnist_gap'] = self.evaluate_gap_clip_fashion_mnist(
            max_samples=max_samples, dataloader=shared_dataloader, expected_counts=shared_counts,
        )
        all_results['fashion_mnist_baseline'] = self.evaluate_baseline_fashion_mnist(
            max_samples=max_samples, dataloader=shared_dataloader, expected_counts=shared_counts,
        )

        # KAGL Marqo (best-effort).
        try:
            kaggle_dataset = load_kaggle_marqo_with_hierarchy(
                max_samples=max_samples,
                hierarchy_classes=self.validation_hierarchy_classes or self.hierarchy_classes,
                raw_df=self.kaggle_raw_df,
            )
            if kaggle_dataset is not None and len(kaggle_dataset) > 0:
                kaggle_dataloader = DataLoader(kaggle_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
                all_results['kaggle_gap'] = self.evaluate_gap_clip_generic(
                    kaggle_dataloader, "KAGL Marqo", max_samples,
                )
                all_results['kaggle_baseline'] = self.evaluate_baseline_generic(
                    kaggle_dataloader, "KAGL Marqo", max_samples,
                )
            else:
                print("WARNING: KAGL Marqo dataset empty after hierarchy mapping, skipping.")
        except Exception as e:
            print(f"WARNING: Could not evaluate on KAGL Marqo: {e}")

        # Internal validation dataset (best-effort).
        try:
            local_dataset = load_local_validation_with_hierarchy(
                max_samples=max_samples,
                hierarchy_classes=self.validation_hierarchy_classes or self.hierarchy_classes,
                raw_df=self.local_raw_df,
            )
            if local_dataset is not None and len(local_dataset) > 0:
                local_dataloader = DataLoader(local_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
                all_results['local_gap'] = self.evaluate_gap_clip_generic(
                    local_dataloader, "Internal", max_samples,
                )
                all_results['local_baseline'] = self.evaluate_baseline_generic(
                    local_dataloader, "Internal", max_samples,
                )
            else:
                print("WARNING: Local validation dataset empty after hierarchy filtering, skipping.")
        except Exception as e:
            print(f"WARNING: Could not evaluate on internal dataset: {e}")

        # Summary: GAP results use 'text_hierarchy'/'image_hierarchy' keys,
        # baseline results nest under 'text'/'image' -> 'hierarchy'.
        print(f"\n{'=' * 70}")
        print("CATEGORY MODEL EVALUATION SUMMARY")
        print(f"{'=' * 70}")
        for dataset_key, label in [
            ('fashion_mnist_gap', 'Fashion-MNIST (GAP-CLIP)'),
            ('fashion_mnist_baseline', 'Fashion-MNIST (Baseline)'),
            ('kaggle_gap', 'KAGL Marqo (GAP-CLIP)'),
            ('kaggle_baseline', 'KAGL Marqo (Baseline)'),
            ('local_gap', 'Internal (GAP-CLIP)'),
            ('local_baseline', 'Internal (Baseline)'),
        ]:
            if dataset_key not in all_results:
                continue
            res = all_results[dataset_key]
            print(f"\n{label}:")
            if 'text_hierarchy' in res:
                t = res['text_hierarchy']
                i = res['image_hierarchy']
                print(f" Text NN Acc: {t['accuracy']*100:.1f}% | Separation: {t['separation_score']:.4f}")
                print(f" Image NN Acc: {i['accuracy']*100:.1f}% | Separation: {i['separation_score']:.4f}")
            elif 'text' in res:
                t = res['text']['hierarchy']
                i = res['image']['hierarchy']
                print(f" Text NN Acc: {t['accuracy']*100:.1f}% | Separation: {t['separation_score']:.4f}")
                print(f" Image NN Acc: {i['accuracy']*100:.1f}% | Separation: {i['separation_score']:.4f}")

        return all_results
|
|
|
|
| |
| |
| |
|
|
if __name__ == "__main__":
    # Prefer the Apple-silicon GPU when available; otherwise run on CPU.
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    print(f"Using device: {device}")

    output_dir = 'gap_clip_confusion_matrices'
    sample_cap = 10000

    CategoryModelEvaluator(device=device, directory=output_dir).run_full_evaluation(
        max_samples=sample_cap, batch_size=8,
    )
|
|