""" Section 5.3.3 Nearest-Neighbour Classification Accuracy (Table 3) ================================================================== Evaluates the full GAP-CLIP embedding on three datasets and compares with the patrickjohncyh/fashion-clip baseline — **color and hierarchy**. - Fashion-MNIST (public benchmark, 10 clothing categories) - KAGL Marqo HuggingFace dataset (diverse fashion, colour + category labels) - Internal local validation set (50 k images) For each dataset the ``ColorHierarchyEvaluator`` class extracts: * **Color slice** (dims 0–15): nearest-neighbour accuracy per colour class. * **Hierarchy slice** (dims 16–79): nearest-neighbour accuracy per category, plus 64-D vs 512-D comparison and image + text-prototype ensemble. Results feed directly into **Table 3** of the paper. The hierarchy mapping for Kaggle uses the same approach as in ``sec52_category_model_eval.py`` (exact match -> substring -> fuzzy on ``category2``). See also: - Section 5.1 (``sec51_color_model_eval.py``) – standalone colour model - Section 5.2 (``sec52_category_model_eval.py``) – confusion-matrix analysis - Section 5.3.6 (``sec536_embedding_structure.py``) – embedding-structure validation """ import os os.environ["TOKENIZERS_PARALLELISM"] = "false" import difflib import warnings import matplotlib.pyplot as plt import numpy as np import pandas as pd import torch from collections import defaultdict from io import BytesIO from PIL import Image from sklearn.metrics import accuracy_score, classification_report from sklearn.metrics.pairwise import cosine_similarity from sklearn.preprocessing import normalize from torch.utils.data import DataLoader, Dataset from torchvision import transforms warnings.filterwarnings('ignore') from config import ( ROOT_DIR, color_emb_dim, column_local_image_path, hierarchy_emb_dim, hierarchy_model_path, local_dataset_path, main_emb_dim, main_model_path, ) from utils.datasets import ( FashionMNISTDataset, LocalDataset, collate_fn_filter_none, 
load_fashion_mnist_dataset, load_local_validation_dataset, ) from utils.embeddings import extract_clip_embeddings from utils.metrics import ( compute_similarity_metrics, compute_centroid_accuracy, predict_labels_from_embeddings, create_confusion_matrix, ) from utils.model_loader import load_gap_clip, load_baseline_fashion_clip from training.hierarchy_model import HierarchyExtractor # --------------------------------------------------------------------------- # Hierarchy label normalisation (same as sec536_embedding_structure.py) # Maps long internal taxonomy strings -> clean labels like "top", "pant", etc. # --------------------------------------------------------------------------- NORMALIZED_HIERARCHY_CLASSES = [ "accessories", "bodysuits", "bras", "coat", "dress", "jacket", "legging", "pant", "polo", "shirt", "shoes", "short", "skirt", "socks", "sweater", "swimwear", "top", "underwear", ] _HIERARCHY_EXTRACTOR = HierarchyExtractor(NORMALIZED_HIERARCHY_CLASSES, verbose=False) _SYNONYMS = { "t-shirt/top": "top", "top": "top", "tee": "top", "t-shirt": "top", "shirt": "shirt", "shirts": "shirt", "pullover": "sweater", "sweater": "sweater", "coat": "coat", "jacket": "jacket", "outerwear": "coat", "outer": "coat", "trouser": "pant", "trousers": "pant", "pants": "pant", "pant": "pant", "jeans": "pant", "dress": "dress", "skirt": "skirt", "shorts": "short", "short": "short", "sandal": "shoes", "sneaker": "shoes", "ankle boot": "shoes", "shoe": "shoes", "shoes": "shoes", "flip flops": "shoes", "footwear": "shoes", "shoe accessories": "shoes", "boots": "shoes", "bag": "accessories", "bags": "accessories", "accessory": "accessories", "accessories": "accessories", "belts": "accessories", "eyewear": "accessories", "jewellery": "accessories", "jewelry": "accessories", "headwear": "accessories", "wallets": "accessories", "watches": "accessories", "mufflers": "accessories", "scarves": "accessories", "stoles": "accessories", "ties": "accessories", "sunglasses": "accessories", 
"scarf & tie": "accessories", "scarf/tie": "accessories", "belt": "accessories", "topwear": "top", "bottomwear": "pant", "innerwear": "underwear", "loungewear and nightwear": "underwear", "saree": "dress", } _EXTRA_KEYWORDS = [ ("capri", "pant"), ("denim", "pant"), ("skinny", "pant"), ("boyfriend", "pant"), ("graphic", "top"), ("longsleeve", "top"), ("leather", "jacket"), ] def normalize_hierarchy_label(raw_label: str) -> str: """Map any hierarchy string to a clean normalised label.""" label = str(raw_label).strip().lower() exact = _SYNONYMS.get(label) if exact is not None: return exact result = _HIERARCHY_EXTRACTOR.extract_hierarchy(label) if result: return result for keyword, category in _EXTRA_KEYWORDS: if keyword in label: return category return label # ============================================================================ # 1. Dataset utilities (hierarchy mapping matches sec52) # ============================================================================ class KaggleHierarchyDataset(Dataset): """KAGL Marqo dataset returning (image, description, color, hierarchy).""" def __init__(self, dataframe, image_size=224): self.dataframe = dataframe.reset_index(drop=True) self.transform = transforms.Compose([ transforms.Resize((image_size, image_size)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) def __len__(self): return len(self.dataframe) def __getitem__(self, idx): row = self.dataframe.iloc[idx] image_data = row["image"] if isinstance(image_data, dict) and "bytes" in image_data: image = Image.open(BytesIO(image_data["bytes"])).convert("RGB") elif hasattr(image_data, "convert"): image = image_data.convert("RGB") else: image = Image.open(BytesIO(image_data)).convert("RGB") image = self.transform(image) description = str(row["text"]) color = str(row.get("baseColour", "unknown")).lower() hierarchy = str(row["hierarchy"]) return image, description, color, hierarchy def 
load_kaggle_marqo_with_hierarchy(max_samples=10000, hierarchy_classes=None, raw_df=None): """Load KAGL Marqo dataset with hierarchy labels derived from category2. Mapping: exact match -> substring match -> fuzzy match (same as sec52). """ if raw_df is not None: df = raw_df.copy() print(f"Using cached KAGL DataFrame: {len(df)} samples") else: from datasets import load_dataset print("Loading KAGL Marqo dataset...") dataset = load_dataset("Marqo/KAGL") df = dataset["data"].to_pandas() print(f"Dataset loaded: {len(df)} samples, columns: {list(df.columns)}") hierarchy_col = 'category2' print(f"Using '{hierarchy_col}' as hierarchy source") df = df.dropna(subset=["text", "image", hierarchy_col]) df["hierarchy"] = df[hierarchy_col].astype(str).str.strip() # Normalise every category2 value through the synonym/extractor pipeline df["hierarchy"] = df["hierarchy"].apply(normalize_hierarchy_label) if hierarchy_classes: hierarchy_classes_lower = [h.lower() for h in hierarchy_classes] mapped = [] for _, row in df.iterrows(): kagl_type = row["hierarchy"].lower() matched = None # Exact match (after normalisation most will hit here) if kagl_type in hierarchy_classes_lower: matched = hierarchy_classes[hierarchy_classes_lower.index(kagl_type)] else: # Substring match for h_class in hierarchy_classes: h_lower = h_class.lower() if h_lower in kagl_type or kagl_type in h_lower: matched = h_class break if matched is None: close = difflib.get_close_matches(kagl_type, hierarchy_classes_lower, n=1, cutoff=0.6) if close: matched = hierarchy_classes[hierarchy_classes_lower.index(close[0])] mapped.append(matched) df["hierarchy"] = mapped df = df.dropna(subset=["hierarchy"]) print(f"After hierarchy mapping: {len(df)} samples") # Normalise color column if "baseColour" in df.columns: df["baseColour"] = df["baseColour"].fillna("unknown").astype(str).str.lower().str.replace("grey", "gray") else: df["baseColour"] = "unknown" df = df.dropna(subset=["text", "image"]) if len(df) > max_samples: df = 
df.sample(n=max_samples, random_state=42) print(f"Using {len(df)} samples, {df['hierarchy'].nunique()} hierarchy classes: " f"{sorted(df['hierarchy'].unique())}") return KaggleHierarchyDataset(df) class LocalHierarchyDataset(Dataset): """Local validation dataset returning (image, description, color, hierarchy).""" def __init__(self, dataframe, image_size=224): self.dataframe = dataframe.reset_index(drop=True) self.transform = transforms.Compose([ transforms.Resize((image_size, image_size)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) def __len__(self): return len(self.dataframe) def __getitem__(self, idx): row = self.dataframe.iloc[idx] try: img_path = row[column_local_image_path] if not os.path.isabs(img_path): img_path = os.path.join(ROOT_DIR, img_path) image = Image.open(img_path).convert("RGB") except Exception: image = Image.new("RGB", (224, 224), color="gray") image = self.transform(image) description = str(row["text"]) color = str(row.get("color", "unknown")) hierarchy = str(row["hierarchy"]) return image, description, color, hierarchy def load_local_validation_with_hierarchy(max_samples=10000, hierarchy_classes=None, raw_df=None): """Load internal validation dataset with hierarchy labels.""" if raw_df is not None: df = raw_df.copy() print(f"Using cached local DataFrame: {len(df)} samples") else: print("Loading local validation dataset...") df = pd.read_csv(local_dataset_path) print(f"Dataset loaded: {len(df)} samples") df = df.dropna(subset=[column_local_image_path, "hierarchy"]) df["hierarchy"] = df["hierarchy"].astype(str).str.strip() df = df[df["hierarchy"].str.len() > 0] # Normalise raw taxonomy strings to clean labels df["hierarchy"] = df["hierarchy"].apply(normalize_hierarchy_label) if hierarchy_classes: hierarchy_classes_lower = [h.lower() for h in hierarchy_classes] df["hierarchy_lower"] = df["hierarchy"].str.lower() df = df[df["hierarchy_lower"].isin(hierarchy_classes_lower)] case_map = 
{h.lower(): h for h in hierarchy_classes} df["hierarchy"] = df["hierarchy_lower"].map(case_map) df = df.drop(columns=["hierarchy_lower"]) print(f"After filtering: {len(df)} samples, {df['hierarchy'].nunique()} classes") if len(df) > max_samples: df = df.sample(n=max_samples, random_state=42) print(f"Using {len(df)} samples, classes: {sorted(df['hierarchy'].unique())}") return LocalHierarchyDataset(df) # ============================================================================ # 2. Evaluator # ============================================================================ class ColorHierarchyEvaluator: """ Evaluates color and hierarchy NN classification accuracy for GAP-CLIP and the baseline Fashion-CLIP on Fashion-MNIST, KAGL Marqo, and the internal validation dataset. """ def __init__(self, device='mps', directory='main_model_analysis', gap_clip_model=None, gap_clip_processor=None, baseline_model=None, baseline_processor=None, hierarchy_classes=None, kaggle_raw_df=None, local_raw_df=None): self.device = torch.device(device) if isinstance(device, str) else device self.directory = directory self.kaggle_raw_df = kaggle_raw_df self.local_raw_df = local_raw_df self.color_emb_dim = color_emb_dim self.hierarchy_emb_dim = hierarchy_emb_dim self.main_emb_dim = main_emb_dim self.hierarchy_end_dim = self.color_emb_dim + self.hierarchy_emb_dim os.makedirs(self.directory, exist_ok=True) # --- hierarchy classes --- if hierarchy_classes is not None: self.hierarchy_classes = hierarchy_classes print(f"Using provided hierarchy classes: {len(self.hierarchy_classes)} classes") else: print("Loading hierarchy classes from hierarchy model...") if not os.path.exists(hierarchy_model_path): raise FileNotFoundError(f"Hierarchy model file {hierarchy_model_path} not found") hierarchy_checkpoint = torch.load(hierarchy_model_path, map_location=self.device) self.hierarchy_classes = hierarchy_checkpoint.get('hierarchy_classes', []) print(f"Found {len(self.hierarchy_classes)} hierarchy classes: 
{sorted(self.hierarchy_classes)}") self.validation_hierarchy_classes = self._load_validation_hierarchy_classes() if self.validation_hierarchy_classes: print(f"Validation dataset hierarchies ({len(self.validation_hierarchy_classes)} classes): " f"{sorted(self.validation_hierarchy_classes)}") else: print("Unable to load validation hierarchy classes, falling back to hierarchy model classes.") self.validation_hierarchy_classes = self.hierarchy_classes # --- load GAP-CLIP --- if gap_clip_model is not None and gap_clip_processor is not None: self.model = gap_clip_model self.processor = gap_clip_processor print("Using pre-loaded GAP-CLIP model") else: self.model, self.processor = load_gap_clip(main_model_path, self.device) print("GAP-CLIP model loaded successfully") # --- baseline Fashion-CLIP --- if baseline_model is not None and baseline_processor is not None: self.baseline_model = baseline_model self.baseline_processor = baseline_processor print("Using pre-loaded baseline Fashion-CLIP model") else: self.baseline_model, self.baseline_processor = load_baseline_fashion_clip(self.device) print("Baseline Fashion-CLIP model loaded successfully") # ------------------------------------------------------------------ # helpers # ------------------------------------------------------------------ def _load_validation_hierarchy_classes(self): """Load hierarchy classes from local CSV, normalised to clean labels.""" if not os.path.exists(local_dataset_path): print(f"Validation dataset not found at {local_dataset_path}") return [] try: df = pd.read_csv(local_dataset_path) except Exception as exc: print(f"Failed to read validation dataset: {exc}") return [] if 'hierarchy' not in df.columns: print("Validation dataset does not contain 'hierarchy' column.") return [] raw = df['hierarchy'].dropna().astype(str).str.strip() normalised = sorted(set( normalize_hierarchy_label(h) for h in raw if h )) # Keep only labels that belong to the known set normalised = [h for h in normalised if h in 
NORMALIZED_HIERARCHY_CLASSES] print(f"Normalised validation hierarchy classes: {normalised}") return normalised def prepare_shared_fashion_mnist(self, max_samples=10000, batch_size=8): """Build one shared Fashion-MNIST dataset/dataloader. Uses NORMALIZED_HIERARCHY_CLASSES so that Fashion-MNIST labels are mapped to clean short names (top, pant, shoes, sweater, coat, …). """ target_classes = self.validation_hierarchy_classes or NORMALIZED_HIERARCHY_CLASSES fashion_dataset = load_fashion_mnist_dataset(max_samples, hierarchy_classes=target_classes) # Normalise whatever label_mapping produced (e.g. "Coat" -> "coat") if fashion_dataset.label_mapping: fashion_dataset.label_mapping = { k: normalize_hierarchy_label(v) if v else v for k, v in fashion_dataset.label_mapping.items() } dataloader = DataLoader(fashion_dataset, batch_size=batch_size, shuffle=False, num_workers=0) hierarchy_counts = defaultdict(int) if len(fashion_dataset.dataframe) > 0 and fashion_dataset.label_mapping: for _, row in fashion_dataset.dataframe.iterrows(): lid = int(row['label']) hierarchy_counts[fashion_dataset.label_mapping.get(lid, 'unknown')] += 1 return fashion_dataset, dataloader, dict(hierarchy_counts) @staticmethod def _count_labels(labels): counts = defaultdict(int) for label in labels: counts[label] += 1 return dict(counts) def _validate_label_distribution(self, labels, expected_counts, context): observed = self._count_labels(labels) if observed != expected_counts: raise ValueError( f"Label distribution mismatch in {context}. 
" f"Expected {expected_counts}, observed {observed}" ) # ------------------------------------------------------------------ # embedding extraction # ------------------------------------------------------------------ def extract_full_embeddings(self, dataloader, embedding_type='text', max_samples=10000): """Full 512D embeddings from GAP-CLIP.""" return extract_clip_embeddings( self.model, self.processor, dataloader, self.device, embedding_type=embedding_type, max_samples=max_samples, desc=f"GAP-CLIP {embedding_type} embeddings", ) def extract_baseline_embeddings_batch(self, dataloader, embedding_type='text', max_samples=10000): """L2-normalised embeddings from baseline Fashion-CLIP.""" return extract_clip_embeddings( self.baseline_model, self.baseline_processor, dataloader, self.device, embedding_type=embedding_type, max_samples=max_samples, desc=f"Baseline {embedding_type} embeddings", ) # ------------------------------------------------------------------ # prediction methods # ------------------------------------------------------------------ def predict_labels_nearest_neighbor(self, embeddings, labels): """Predict labels using 1-NN on the same embedding set.""" similarities = cosine_similarity(embeddings) preds = [] for i in range(len(embeddings)): sims = similarities[i].copy() sims[i] = -1.0 nearest_neighbor_idx = int(np.argmax(sims)) preds.append(labels[nearest_neighbor_idx]) return preds def _compute_img_centroids(self, embeddings, labels): emb_norm = normalize(embeddings, norm='l2') centroids = {} for label in sorted(set(labels)): idx = [i for i, l in enumerate(labels) if l == label] centroids[label] = normalize([emb_norm[idx].mean(axis=0)], norm='l2')[0] return centroids def predict_labels_image_ensemble(self, img_embeddings, labels, text_protos, cls_names, alpha=0.5): """Combine image centroids (512D) with text prototypes (512D).""" img_norm = normalize(img_embeddings, norm='l2') img_centroids = self._compute_img_centroids(img_norm, labels) centroid_mat = 
np.stack([img_centroids[c] for c in cls_names], axis=0) preds = [] for i in range(len(img_norm)): v = img_norm[i:i + 1] sim_img = cosine_similarity(v, centroid_mat)[0] sim_txt = cosine_similarity(v, text_protos)[0] scores = alpha * sim_img + (1 - alpha) * sim_txt preds.append(cls_names[int(np.argmax(scores))]) return preds # ------------------------------------------------------------------ # classification evaluation # ------------------------------------------------------------------ def evaluate_classification_performance(self, embeddings, labels, embedding_type="Embeddings", label_type="Hierarchy", method="nn"): if method == "nn": preds = self.predict_labels_nearest_neighbor(embeddings, labels) elif method == "centroid": preds = predict_labels_from_embeddings(embeddings, labels) else: raise ValueError(f"Unknown classification method: {method}") acc = accuracy_score(labels, preds) unique_labels = sorted(set(labels)) fig, _, cm = create_confusion_matrix( labels, preds, f"{embedding_type} - {label_type} Classification ({method.upper()})", label_type, ) report = classification_report(labels, preds, labels=unique_labels, target_names=unique_labels, output_dict=True) return { 'accuracy': acc, 'predictions': preds, 'confusion_matrix': cm, 'labels': unique_labels, 'classification_report': report, 'figure': fig, } def save_confusion_matrix_table(self, cm, labels, output_csv_path): cm_df = pd.DataFrame(cm, index=labels, columns=labels) cm_df["row_total"] = cm_df.sum(axis=1) cm_df.loc["column_total"] = list(cm_df[labels].sum(axis=0)) + [cm_df["row_total"].sum()] cm_df.to_csv(output_csv_path) # ================================================================== # 3. 
GAP-CLIP evaluation on Fashion-MNIST (hierarchy only — no color) # ================================================================== def evaluate_gap_clip_fashion_mnist(self, max_samples=10000, dataloader=None, expected_counts=None): print(f"\n{'=' * 60}") print("Evaluating GAP-CLIP on Fashion-MNIST (Hierarchy only)") print(f" Hierarchy embeddings (dims {self.color_emb_dim}-{self.hierarchy_end_dim - 1})") print(f" Max samples: {max_samples}") print(f"{'=' * 60}") if dataloader is None: fashion_dataset, dataloader, dataset_counts = self.prepare_shared_fashion_mnist(max_samples=max_samples) expected_counts = expected_counts or dataset_counts else: if expected_counts is None: raise ValueError("expected_counts must be provided when using a custom dataloader.") results = {} # --- full 512D embeddings (text & image) --- print("\nExtracting full 512-dimensional GAP-CLIP embeddings...") text_full, _, text_hier = self.extract_full_embeddings(dataloader, 'text', max_samples) img_full, _, img_hier = self.extract_full_embeddings(dataloader, 'image', max_samples) self._validate_label_distribution(text_hier, expected_counts, "GAP-CLIP text") self._validate_label_distribution(img_hier, expected_counts, "GAP-CLIP image") print(f" Text shape: {text_full.shape} | Image shape: {img_full.shape}") # ===== HIERARCHY (dims 16-79) ===== print(f"\n--- GAP-CLIP TEXT HIERARCHY (dims {self.color_emb_dim}-{self.hierarchy_end_dim - 1}) ---") text_hier_spec = text_full[:, self.color_emb_dim:self.hierarchy_end_dim] print(f" Specialized text hierarchy shape: {text_hier_spec.shape}") text_hier_metrics = compute_similarity_metrics(text_hier_spec, text_hier) text_hier_class = self.evaluate_classification_performance( text_hier_spec, text_hier, "GAP-CLIP Text Hierarchy (64D)", "Hierarchy", method="nn", ) text_hier_metrics.update(text_hier_class) results['text_hierarchy'] = text_hier_metrics # IMAGE: 64D vs 512D print(f"\n--- GAP-CLIP IMAGE HIERARCHY (64D vs 512D) ---") img_hier_spec = img_full[:, 
self.color_emb_dim:self.hierarchy_end_dim] print(f" Specialized image hierarchy shape: {img_hier_spec.shape}") print(" Testing specialized 64D...") spec_metrics = compute_similarity_metrics(img_hier_spec, img_hier) spec_class = self.evaluate_classification_performance( img_hier_spec, img_hier, "GAP-CLIP Image Hierarchy (64D)", "Hierarchy", method="nn", ) print(" Testing full 512D...") full_metrics = compute_similarity_metrics(img_full, img_hier) full_class = self.evaluate_classification_performance( img_full, img_hier, "GAP-CLIP Image Hierarchy (512D full)", "Hierarchy", method="nn", ) if full_class['accuracy'] >= spec_class['accuracy']: print(f" 512D wins: {full_class['accuracy'] * 100:.1f}% vs {spec_class['accuracy'] * 100:.1f}%") img_hier_metrics, img_hier_class = full_metrics, full_class else: print(f" 64D wins: {spec_class['accuracy'] * 100:.1f}% vs {full_class['accuracy'] * 100:.1f}%") img_hier_metrics, img_hier_class = spec_metrics, spec_class # ensemble image + text prototypes print("\n Testing GAP-CLIP image + text ensemble (prototypes per class)...") cls_names = sorted(set(img_hier)) prompts = [f"a photo of a {c}" for c in cls_names] text_inputs = self.processor(text=prompts, return_tensors="pt", padding=True, truncation=True) text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()} with torch.no_grad(): txt_feats = self.model.get_text_features(**text_inputs) txt_feats = txt_feats / txt_feats.norm(dim=-1, keepdim=True) text_protos = txt_feats.cpu().numpy() ensemble_preds = self.predict_labels_image_ensemble( img_full, img_hier, text_protos, cls_names, alpha=0.7, ) ensemble_acc = accuracy_score(img_hier, ensemble_preds) print(f" Ensemble accuracy (alpha=0.7): {ensemble_acc * 100:.2f}%") img_hier_metrics.update(img_hier_class) img_hier_metrics['ensemble_accuracy'] = ensemble_acc results['image_hierarchy'] = img_hier_metrics # --- save confusion matrix figures --- for key in ['text_hierarchy', 'image_hierarchy']: fig = results[key]['figure'] 
fig.savefig( os.path.join(self.directory, f"gap_clip_{key}_confusion_matrix.png"), dpi=300, bbox_inches='tight', ) self.save_confusion_matrix_table( results[key]['confusion_matrix'], results[key]['labels'], os.path.join(self.directory, f"gap_clip_{key}_confusion_matrix.csv"), ) plt.close(fig) del text_full, img_full, text_hier_spec, img_hier_spec if torch.cuda.is_available(): torch.cuda.empty_cache() return results # ================================================================== # 4. Baseline Fashion-CLIP evaluation on Fashion-MNIST (hierarchy only) # ================================================================== def evaluate_baseline_fashion_mnist(self, max_samples=10000, dataloader=None, expected_counts=None): print(f"\n{'=' * 60}") print("Evaluating Baseline Fashion-CLIP on Fashion-MNIST (Hierarchy only)") print(f" Max samples: {max_samples}") print(f"{'=' * 60}") if dataloader is None: _, dataloader, dataset_counts = self.prepare_shared_fashion_mnist(max_samples=max_samples) expected_counts = expected_counts or dataset_counts elif expected_counts is None: raise ValueError("expected_counts must be provided when using a custom dataloader.") results = {} # --- text --- print("\nExtracting baseline text embeddings...") text_emb, _, text_hier = self.extract_baseline_embeddings_batch(dataloader, 'text', max_samples) self._validate_label_distribution(text_hier, expected_counts, "baseline text") print(f" Baseline text shape: {text_emb.shape}") text_metrics = compute_similarity_metrics(text_emb, text_hier) text_class = self.evaluate_classification_performance( text_emb, text_hier, "Baseline Text - Hierarchy", "Hierarchy", method="nn", ) text_metrics.update(text_class) results['text'] = {'hierarchy': text_metrics} del text_emb if torch.cuda.is_available(): torch.cuda.empty_cache() # --- image --- print("\nExtracting baseline image embeddings...") img_emb, _, img_hier = self.extract_baseline_embeddings_batch(dataloader, 'image', max_samples) 
self._validate_label_distribution(img_hier, expected_counts, "baseline image") print(f" Baseline image shape: {img_emb.shape}") img_metrics = compute_similarity_metrics(img_emb, img_hier) img_class = self.evaluate_classification_performance( img_emb, img_hier, "Baseline Image - Hierarchy", "Hierarchy", method="nn", ) img_metrics.update(img_class) results['image'] = {'hierarchy': img_metrics} del img_emb if torch.cuda.is_available(): torch.cuda.empty_cache() for key in ['text', 'image']: fig = results[key]['hierarchy']['figure'] fig.savefig( os.path.join(self.directory, f"baseline_{key}_hierarchy_confusion_matrix.png"), dpi=300, bbox_inches='tight', ) self.save_confusion_matrix_table( results[key]['hierarchy']['confusion_matrix'], results[key]['hierarchy']['labels'], os.path.join(self.directory, f"baseline_{key}_hierarchy_confusion_matrix.csv"), ) plt.close(fig) return results # ================================================================== # 5. Generic dataset evaluation (KAGL Marqo / Internal) # ================================================================== def evaluate_gap_clip_generic(self, dataloader, dataset_name, max_samples=10000): """Evaluate GAP-CLIP color + hierarchy performance on any dataset.""" print(f"\n{'=' * 60}") print(f"Evaluating GAP-CLIP on {dataset_name} (Color + Hierarchy)") print(f" Color (dims 0-{self.color_emb_dim - 1}) | " f"Hierarchy (dims {self.color_emb_dim}-{self.hierarchy_end_dim - 1})") print(f"{'=' * 60}") results = {} # --- text --- print("\nExtracting GAP-CLIP text embeddings...") text_full, text_colors, text_hier = self.extract_full_embeddings(dataloader, 'text', max_samples) print(f" Text shape: {text_full.shape}") # text color text_color_spec = text_full[:, :self.color_emb_dim] text_color_metrics = compute_similarity_metrics(text_color_spec, text_colors) text_color_class = self.evaluate_classification_performance( text_color_spec, text_colors, f"GAP-CLIP Text Color – {dataset_name}", "Color", method="nn", ) 
text_color_metrics.update(text_color_class) results['text_color'] = text_color_metrics # text hierarchy text_hier_spec = text_full[:, self.color_emb_dim:self.hierarchy_end_dim] text_hier_metrics = compute_similarity_metrics(text_hier_spec, text_hier) text_hier_class = self.evaluate_classification_performance( text_hier_spec, text_hier, f"GAP-CLIP Text Hierarchy – {dataset_name}", "Hierarchy", method="nn", ) text_hier_metrics.update(text_hier_class) results['text_hierarchy'] = text_hier_metrics # --- image --- print("\nExtracting GAP-CLIP image embeddings...") img_full, img_colors, img_hier = self.extract_full_embeddings(dataloader, 'image', max_samples) # image color img_color_spec = img_full[:, :self.color_emb_dim] img_color_metrics = compute_similarity_metrics(img_color_spec, img_colors) img_color_class = self.evaluate_classification_performance( img_color_spec, img_colors, f"GAP-CLIP Image Color – {dataset_name}", "Color", method="nn", ) img_color_metrics.update(img_color_class) results['image_color'] = img_color_metrics # image hierarchy (best of 64D vs 512D) img_hier_spec = img_full[:, self.color_emb_dim:self.hierarchy_end_dim] spec_metrics = compute_similarity_metrics(img_hier_spec, img_hier) spec_class = self.evaluate_classification_performance( img_hier_spec, img_hier, f"GAP-CLIP Image Hierarchy (64D) – {dataset_name}", "Hierarchy", method="nn", ) full_metrics = compute_similarity_metrics(img_full, img_hier) full_class = self.evaluate_classification_performance( img_full, img_hier, f"GAP-CLIP Image Hierarchy (512D) – {dataset_name}", "Hierarchy", method="nn", ) if full_class['accuracy'] >= spec_class['accuracy']: print(f" 512D wins: {full_class['accuracy']*100:.1f}% vs {spec_class['accuracy']*100:.1f}%") img_hier_metrics, img_hier_class = full_metrics, full_class else: print(f" 64D wins: {spec_class['accuracy']*100:.1f}% vs {full_class['accuracy']*100:.1f}%") img_hier_metrics, img_hier_class = spec_metrics, spec_class img_hier_metrics.update(img_hier_class) 
results['image_hierarchy'] = img_hier_metrics # --- save confusion matrices --- prefix = dataset_name.lower().replace(" ", "_") for key in ['text_color', 'image_color', 'text_hierarchy', 'image_hierarchy']: fig = results[key]['figure'] fig.savefig( os.path.join(self.directory, f"gap_clip_{prefix}_{key}_confusion_matrix.png"), dpi=300, bbox_inches='tight', ) self.save_confusion_matrix_table( results[key]['confusion_matrix'], results[key]['labels'], os.path.join(self.directory, f"gap_clip_{prefix}_{key}_confusion_matrix.csv"), ) plt.close(fig) del text_full, img_full if torch.cuda.is_available(): torch.cuda.empty_cache() return results def evaluate_baseline_generic(self, dataloader, dataset_name, max_samples=10000): """Evaluate baseline Fashion-CLIP color + hierarchy on any dataset.""" print(f"\n{'=' * 60}") print(f"Evaluating Baseline Fashion-CLIP on {dataset_name} (Color + Hierarchy)") print(f"{'=' * 60}") results = {} # --- text --- print("\nExtracting baseline text embeddings...") text_emb, text_colors, text_hier = self.extract_baseline_embeddings_batch(dataloader, 'text', max_samples) print(f" Baseline text shape: {text_emb.shape}") text_color_metrics = compute_similarity_metrics(text_emb, text_colors) text_color_class = self.evaluate_classification_performance( text_emb, text_colors, f"Baseline Text Color – {dataset_name}", "Color", method="nn", ) text_color_metrics.update(text_color_class) text_hier_metrics = compute_similarity_metrics(text_emb, text_hier) text_hier_class = self.evaluate_classification_performance( text_emb, text_hier, f"Baseline Text Hierarchy – {dataset_name}", "Hierarchy", method="nn", ) text_hier_metrics.update(text_hier_class) results['text'] = {'color': text_color_metrics, 'hierarchy': text_hier_metrics} del text_emb if torch.cuda.is_available(): torch.cuda.empty_cache() # --- image --- print("\nExtracting baseline image embeddings...") img_emb, img_colors, img_hier = self.extract_baseline_embeddings_batch(dataloader, 'image', 
max_samples) print(f" Baseline image shape: {img_emb.shape}") img_color_metrics = compute_similarity_metrics(img_emb, img_colors) img_color_class = self.evaluate_classification_performance( img_emb, img_colors, f"Baseline Image Color – {dataset_name}", "Color", method="nn", ) img_color_metrics.update(img_color_class) img_hier_metrics = compute_similarity_metrics(img_emb, img_hier) img_hier_class = self.evaluate_classification_performance( img_emb, img_hier, f"Baseline Image Hierarchy – {dataset_name}", "Hierarchy", method="nn", ) img_hier_metrics.update(img_hier_class) results['image'] = {'color': img_color_metrics, 'hierarchy': img_hier_metrics} del img_emb if torch.cuda.is_available(): torch.cuda.empty_cache() prefix = dataset_name.lower().replace(" ", "_") for key in ['text', 'image']: for attr in ['color', 'hierarchy']: fig = results[key][attr]['figure'] fig.savefig( os.path.join(self.directory, f"baseline_{prefix}_{key}_{attr}_confusion_matrix.png"), dpi=300, bbox_inches='tight', ) self.save_confusion_matrix_table( results[key][attr]['confusion_matrix'], results[key][attr]['labels'], os.path.join(self.directory, f"baseline_{prefix}_{key}_{attr}_confusion_matrix.csv"), ) plt.close(fig) return results # ================================================================== # 6. 
    # ==================================================================
    # 6. Full evaluation across all datasets
    # ==================================================================
    def run_full_evaluation(self, max_samples=10000, batch_size=8):
        """Run color + hierarchy evaluation on all 3 datasets for both models.

        Datasets: Fashion-MNIST (shared dataloader reused for GAP-CLIP and
        baseline), KAGL Marqo, and the internal local validation set. The
        latter two are each wrapped in a try/except so a failure on one
        dataset does not abort the whole run. Finally prints the Table 3
        summary and returns the per-dataset result dicts keyed by
        ``'<dataset>_gap'`` / ``'<dataset>_baseline'``.
        """
        all_results = {}

        # --- Fashion-MNIST ---
        # Build the dataset/dataloader once and reuse it for both models so
        # both see exactly the same samples.
        shared_dataset, shared_dataloader, shared_counts = self.prepare_shared_fashion_mnist(
            max_samples=max_samples,
            batch_size=batch_size,
        )
        all_results['fashion_mnist_gap'] = self.evaluate_gap_clip_fashion_mnist(
            max_samples=max_samples,
            dataloader=shared_dataloader,
            expected_counts=shared_counts,
        )
        all_results['fashion_mnist_baseline'] = self.evaluate_baseline_fashion_mnist(
            max_samples=max_samples,
            dataloader=shared_dataloader,
            expected_counts=shared_counts,
        )

        # --- KAGL Marqo ---
        try:
            kaggle_dataset = load_kaggle_marqo_with_hierarchy(
                max_samples=max_samples,
                # Prefer the validation-time class list when available.
                hierarchy_classes=self.validation_hierarchy_classes or self.hierarchy_classes,
                raw_df=self.kaggle_raw_df,
            )
            if kaggle_dataset is not None and len(kaggle_dataset) > 0:
                kaggle_dataloader = DataLoader(kaggle_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
                all_results['kaggle_gap'] = self.evaluate_gap_clip_generic(
                    kaggle_dataloader, "KAGL Marqo", max_samples,
                )
                all_results['kaggle_baseline'] = self.evaluate_baseline_generic(
                    kaggle_dataloader, "KAGL Marqo", max_samples,
                )
            else:
                print("WARNING: KAGL Marqo dataset empty after hierarchy mapping, skipping.")
        except Exception as e:
            # Best-effort: log and continue with the remaining datasets.
            print(f"WARNING: Could not evaluate on KAGL Marqo: {e}")

        # --- Internal (local validation) ---
        try:
            local_dataset = load_local_validation_with_hierarchy(
                max_samples=max_samples,
                hierarchy_classes=self.validation_hierarchy_classes or self.hierarchy_classes,
                raw_df=self.local_raw_df,
            )
            if local_dataset is not None and len(local_dataset) > 0:
                local_dataloader = DataLoader(local_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
                all_results['local_gap'] = self.evaluate_gap_clip_generic(
                    local_dataloader, "Internal", max_samples,
                )
                all_results['local_baseline'] = self.evaluate_baseline_generic(
                    local_dataloader, "Internal", max_samples,
                )
            else:
                print("WARNING: Local validation dataset empty after hierarchy filtering, skipping.")
        except Exception as e:
            print(f"WARNING: Could not evaluate on internal dataset: {e}")

        # --- Print summary ---
        print(f"\n{'=' * 70}")
        print("COLOR + HIERARCHY NN ACCURACY EVALUATION SUMMARY (Table 3)")
        print(f"{'=' * 70}")
        for dataset_key, label in [
            ('fashion_mnist_gap', 'Fashion-MNIST (GAP-CLIP)'),
            ('fashion_mnist_baseline', 'Fashion-MNIST (Baseline)'),
            ('kaggle_gap', 'KAGL Marqo (GAP-CLIP)'),
            ('kaggle_baseline', 'KAGL Marqo (Baseline)'),
            ('local_gap', 'Internal (GAP-CLIP)'),
            ('local_baseline', 'Internal (Baseline)'),
        ]:
            if dataset_key not in all_results:
                continue
            res = all_results[dataset_key]
            print(f"\n{label}:")
            # GAP-CLIP format: flat keys like 'text_color' / 'image_hierarchy'.
            if 'text_color' in res:
                tc = res['text_color']
                ic = res['image_color']
                print(f" Color – Text NN: {tc['accuracy']*100:.1f}% | Image NN: {ic['accuracy']*100:.1f}%")
            if 'text_hierarchy' in res:
                th = res['text_hierarchy']
                ih = res['image_hierarchy']
                print(f" Hierarchy – Text NN: {th['accuracy']*100:.1f}% | Image NN: {ih['accuracy']*100:.1f}%")
                if 'ensemble_accuracy' in ih:
                    print(f" Hierarchy – Image Ensemble: {ih['ensemble_accuracy']*100:.1f}%")
            # Baseline format: nested dicts under 'text' / 'image'.
            if 'text' in res and isinstance(res['text'], dict):
                t = res['text']
                i = res['image']
                if 'color' in t:
                    print(f" Color – Text NN: {t['color']['accuracy']*100:.1f}% | Image NN: {i['color']['accuracy']*100:.1f}%")
                if 'hierarchy' in t:
                    print(f" Hierarchy – Text NN: {t['hierarchy']['accuracy']*100:.1f}% | Image NN: {i['hierarchy']['accuracy']*100:.1f}%")
        return all_results
# ============================================================================
# 7. Main
# ============================================================================
if __name__ == "__main__":
    # Device selection: prefer CUDA, then Apple MPS, then CPU.
    # The original only checked MPS, so CUDA machines silently fell back to
    # CPU even though the evaluator code guards torch.cuda.empty_cache()
    # throughout — this makes the entry point consistent with that handling.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    print(f"Using device: {device}")

    # Output directory for all figures/CSVs and the per-dataset sample cap.
    directory = 'main_model_analysis'
    max_samples = 10000

    evaluator = ColorHierarchyEvaluator(device=device, directory=directory)
    evaluator.run_full_evaluation(max_samples=max_samples, batch_size=8)