diff --git "a/evaluation/sec533_clip_nn_accuracy.py" "b/evaluation/sec533_clip_nn_accuracy.py" --- "a/evaluation/sec533_clip_nn_accuracy.py" +++ "b/evaluation/sec533_clip_nn_accuracy.py" @@ -1,9 +1,9 @@ """ -§5.3.3 Nearest-Neighbour Classification Accuracy (Table 3) -============================================================ +Section 5.3.3 Nearest-Neighbour Classification Accuracy (Table 3) +================================================================== Evaluates the full GAP-CLIP embedding on three datasets and compares with the -patrickjohncyh/fashion-clip baseline: +patrickjohncyh/fashion-clip baseline — **color and hierarchy**. - Fashion-MNIST (public benchmark, 10 clothing categories) - KAGL Marqo HuggingFace dataset (diverse fashion, colour + category labels) @@ -11,16 +11,20 @@ patrickjohncyh/fashion-clip baseline: For each dataset the ``ColorHierarchyEvaluator`` class extracts: -* **Color slice** (dims 0–15): nearest-neighbour and centroid accuracy per colour class. -* **Hierarchy slice** (dims 16–79): nearest-neighbour and centroid accuracy per category. -* **Ensemble mode** (Kaggle/MNIST): sliced dims combined with full 512-D embedding. +* **Color slice** (dims 0–15): nearest-neighbour accuracy per colour class. +* **Hierarchy slice** (dims 16–79): nearest-neighbour accuracy per category, + plus 64-D vs 512-D comparison and image + text-prototype ensemble. Results feed directly into **Table 3** of the paper. +The hierarchy mapping for Kaggle uses the same approach as in +``sec52_category_model_eval.py`` (exact match -> substring -> fuzzy on +``category2``). + See also: - - §5.1 (``sec51_color_model_eval.py``) – standalone colour model - - §5.2 (``sec52_category_model_eval.py``) – confusion-matrix analysis - - §5.3.4–5 (``sec5354_separation_semantic.py``) – separation scores + - Section 5.1 (``sec51_color_model_eval.py``) – standalone colour model + - Section 5.2 (``sec52_category_model_eval.py``) – confusion-matrix analysis + - Section 5.3.6 (``sec536_embedding_structure.py``) – embedding-structure validation """ import os os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -31,1449 +35,940 @@ import warnings import matplotlib.pyplot as plt import numpy as np import pandas as pd -import seaborn as sns import torch from collections import defaultdict from io import BytesIO from PIL import Image -from sklearn.metrics import accuracy_score, classification_report, confusion_matrix +from sklearn.metrics import accuracy_score, classification_report from sklearn.metrics.pairwise import cosine_similarity +from sklearn.preprocessing import normalize from torch.utils.data import DataLoader, Dataset from torchvision import transforms -from tqdm import tqdm -from transformers import CLIPModel as CLIPModel_transformers, CLIPProcessor warnings.filterwarnings('ignore') from config import ( + ROOT_DIR, color_emb_dim, column_local_image_path, hierarchy_emb_dim, hierarchy_model_path, local_dataset_path, + main_emb_dim, main_model_path, ) from utils.datasets import ( FashionMNISTDataset, LocalDataset, + collate_fn_filter_none, load_fashion_mnist_dataset, load_local_validation_dataset, ) +from utils.embeddings import extract_clip_embeddings +from utils.metrics import ( + compute_similarity_metrics, + compute_centroid_accuracy, + predict_labels_from_embeddings, + create_confusion_matrix, +) +from utils.model_loader import load_gap_clip, load_baseline_fashion_clip +from training.hierarchy_model import HierarchyExtractor + +# 
--------------------------------------------------------------------------- +# Hierarchy label normalisation (same as sec536_embedding_structure.py) +# Maps long internal taxonomy strings -> clean labels like "top", "pant", etc. +# --------------------------------------------------------------------------- +NORMALIZED_HIERARCHY_CLASSES = [ + "accessories", "bodysuits", "bras", "coat", "dress", "jacket", + "legging", "pant", "polo", "shirt", "shoes", "short", "skirt", + "socks", "sweater", "swimwear", "top", "underwear", +] + +_HIERARCHY_EXTRACTOR = HierarchyExtractor(NORMALIZED_HIERARCHY_CLASSES, verbose=False) + +_SYNONYMS = { + "t-shirt/top": "top", "top": "top", "tee": "top", "t-shirt": "top", + "shirt": "shirt", "shirts": "shirt", + "pullover": "sweater", "sweater": "sweater", + "coat": "coat", "jacket": "jacket", "outerwear": "coat", "outer": "coat", + "trouser": "pant", "trousers": "pant", "pants": "pant", "pant": "pant", "jeans": "pant", + "dress": "dress", "skirt": "skirt", + "shorts": "short", "short": "short", + "sandal": "shoes", "sneaker": "shoes", "ankle boot": "shoes", + "shoe": "shoes", "shoes": "shoes", "flip flops": "shoes", + "footwear": "shoes", "shoe accessories": "shoes", "boots": "shoes", + "bag": "accessories", "bags": "accessories", + "accessory": "accessories", "accessories": "accessories", + "belts": "accessories", "eyewear": "accessories", + "jewellery": "accessories", "jewelry": "accessories", + "headwear": "accessories", "wallets": "accessories", + "watches": "accessories", "mufflers": "accessories", + "scarves": "accessories", "stoles": "accessories", + "ties": "accessories", "sunglasses": "accessories", + "scarf & tie": "accessories", "scarf/tie": "accessories", "belt": "accessories", + "topwear": "top", "bottomwear": "pant", + "innerwear": "underwear", "loungewear and nightwear": "underwear", + "saree": "dress", +} + +_EXTRA_KEYWORDS = [ + ("capri", "pant"), ("denim", "pant"), ("skinny", "pant"), + ("boyfriend", "pant"), ("graphic", "top"), ("longsleeve", "top"), + ("leather", "jacket"), +] + + +def normalize_hierarchy_label(raw_label: str) -> str: + """Map any hierarchy string to a clean normalised label.""" + label = str(raw_label).strip().lower() + exact = _SYNONYMS.get(label) + if exact is not None: + return exact + result = _HIERARCHY_EXTRACTOR.extract_hierarchy(label) + if result: + return result + for keyword, category in _EXTRA_KEYWORDS: + if keyword in label: + return category + return label + + +# ============================================================================ +# 1. 
Dataset utilities (hierarchy mapping matches sec52) +# ============================================================================ + +class KaggleHierarchyDataset(Dataset): + """KAGL Marqo dataset returning (image, description, color, hierarchy).""" + def __init__(self, dataframe, image_size=224): + self.dataframe = dataframe.reset_index(drop=True) + self.transform = transforms.Compose([ + transforms.Resize((image_size, image_size)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ]) + def __len__(self): + return len(self.dataframe) + def __getitem__(self, idx): + row = self.dataframe.iloc[idx] + image_data = row["image"] + if isinstance(image_data, dict) and "bytes" in image_data: + image = Image.open(BytesIO(image_data["bytes"])).convert("RGB") + elif hasattr(image_data, "convert"): + image = image_data.convert("RGB") + else: + image = Image.open(BytesIO(image_data)).convert("RGB") + image = self.transform(image) + description = str(row["text"]) + color = str(row.get("baseColour", "unknown")).lower() + hierarchy = str(row["hierarchy"]) + return image, description, color, hierarchy -def create_kaggle_marqo_to_hierarchy_mapping(kaggle_labels, hierarchy_classes): - """Create mapping from Kaggle Marqo categories to hierarchy classes""" - hierarchy_classes = list(hierarchy_classes) - hierarchy_classes_lower = [h.lower() for h in hierarchy_classes] - - synonyms = { - 'topwear': 'top', - 'tops': 'top', - 'tee': 'top', - 'tees': 'top', - 't-shirt': 'top', - 'tshirt': 'top', - 'tshirts': 'top', - 'shirt': 'shirt', - 'shirts': 'shirt', - 'sweater': 'sweater', - 'sweaters': 'sweater', - 'outerwear': 'coat', - 'outer': 'coat', - 'coat': 'coat', - 'coats': 'coat', - 'jacket': 'coat', - 'jackets': 'coat', - 'blazer': 'coat', - 'blazers': 'coat', - 'hoodie': 'hoodie', - 'hoodies': 'hoodie', - 'bottomwear': 'bottom', - 'bottoms': 'bottom', - 'pants': 'bottom', - 'pant': 'bottom', - 'trouser': 'bottom', - 'trousers': 'bottom', - 'jeans': 'jeans', - 'denim': 'jeans', - 'short': 'shorts', - 'shorts': 'shorts', - 'skirt': 'skirt', - 'skirts': 'skirt', - 'dress': 'dress', - 'dresses': 'dress', - 'saree': 'saree', - 'lehenga': 'lehenga', - 'shoe': 'shoes', - 'shoes': 'shoes', - 'sandal': 'shoes', - 'sandals': 'shoes', - 'sneaker': 'shoes', - 'sneakers': 'shoes', - 'boot': 'shoes', - 'boots': 'shoes', - 'heel': 'shoes', - 'heels': 'shoes', - 'flip flops': 'shoes', - 'flip-flops': 'shoes', - 'loafer': 'shoes', - 'loafers': 'shoes', - 'bag': 'bag', - 'bags': 'bag', - 'backpack': 'bag', - 'backpacks': 'bag', - 'handbag': 'bag', - 'handbags': 'bag', - 'accessory': 'accessories', - 'accessories': 'accessories', - 'belt': 'belt', - 'belts': 'belt', - 'scarf': 'scarf', - 'scarves': 'scarf', - 'cap': 'cap', - 'caps': 'cap', - 'hat': 'cap', - 'hats': 'cap', - 'watch': 'watch', - 'watches': 'watch', - } - - def match_candidate(candidate): - if candidate in hierarchy_classes_lower: - return hierarchy_classes[hierarchy_classes_lower.index(candidate)] - return None - - mapping = {} - - for label in sorted(set(kaggle_labels)): - label_str = str(label) if not pd.isna(label) else '' - label_lower = label_str.strip().lower() - matched_hierarchy = None - - if not label_lower: - mapping[label_lower] = None - continue - - # Direct match or synonym substitution - candidate = synonyms.get(label_lower, label_lower) - matched_hierarchy = match_candidate(candidate) - - # Partial match with hierarchy classes - if matched_hierarchy is None: - for idx, h_lower in 
enumerate(hierarchy_classes_lower): - if h_lower in candidate or candidate in h_lower: - matched_hierarchy = hierarchy_classes[idx] - break - - # Token-based match (split on spaces, hyphens, slashes) - if matched_hierarchy is None: - tokens = set(candidate.replace('-', ' ').replace('/', ' ').split()) - for token in tokens: - token_candidate = synonyms.get(token, token) - matched_hierarchy = match_candidate(token_candidate) - if matched_hierarchy: - break - - # Synonym containment checks - if matched_hierarchy is None: - for synonym_key, synonym_value in synonyms.items(): - if synonym_key in candidate: - matched_hierarchy = match_candidate(synonym_value) - if matched_hierarchy: + +def load_kaggle_marqo_with_hierarchy(max_samples=10000, hierarchy_classes=None, raw_df=None): + """Load KAGL Marqo dataset with hierarchy labels derived from category2. + + Mapping: exact match -> substring match -> fuzzy match (same as sec52). + """ + if raw_df is not None: + df = raw_df.copy() + print(f"Using cached KAGL DataFrame: {len(df)} samples") + else: + from datasets import load_dataset + print("Loading KAGL Marqo dataset...") + dataset = load_dataset("Marqo/KAGL") + df = dataset["data"].to_pandas() + print(f"Dataset loaded: {len(df)} samples, columns: {list(df.columns)}") + + hierarchy_col = 'category2' + print(f"Using '{hierarchy_col}' as hierarchy source") + df = df.dropna(subset=["text", "image", hierarchy_col]) + df["hierarchy"] = df[hierarchy_col].astype(str).str.strip() + + # Normalise every category2 value through the synonym/extractor pipeline + df["hierarchy"] = df["hierarchy"].apply(normalize_hierarchy_label) + + if hierarchy_classes: + hierarchy_classes_lower = [h.lower() for h in hierarchy_classes] + mapped = [] + for _, row in df.iterrows(): + kagl_type = row["hierarchy"].lower() + matched = None + # Exact match (after normalisation most will hit here) + if kagl_type in hierarchy_classes_lower: + matched = hierarchy_classes[hierarchy_classes_lower.index(kagl_type)] + else: + # Substring match + for h_class in hierarchy_classes: + h_lower = h_class.lower() + if h_lower in kagl_type or kagl_type in h_lower: + matched = h_class break - - # Fallback to fuzzy matching - if matched_hierarchy is None: - close_matches = difflib.get_close_matches(candidate, hierarchy_classes_lower, n=1, cutoff=0.6) - if close_matches: - matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index(close_matches[0])] - - mapping[label_lower] = matched_hierarchy - - if matched_hierarchy: - print(f" {label_str} -> {matched_hierarchy}") - else: - print(f" ⚠️ {label_str} -> NO MATCH (will be filtered out)") - - return mapping + if matched is None: + close = difflib.get_close_matches(kagl_type, hierarchy_classes_lower, n=1, cutoff=0.6) + if close: + matched = hierarchy_classes[hierarchy_classes_lower.index(close[0])] + mapped.append(matched) + df["hierarchy"] = mapped + df = df.dropna(subset=["hierarchy"]) + print(f"After hierarchy mapping: {len(df)} samples") + + # Normalise color column + if "baseColour" in df.columns: + df["baseColour"] = df["baseColour"].fillna("unknown").astype(str).str.lower().str.replace("grey", "gray") + else: + df["baseColour"] = "unknown" + + df = df.dropna(subset=["text", "image"]) + + if len(df) > max_samples: + df = df.sample(n=max_samples, random_state=42) + + print(f"Using {len(df)} samples, {df['hierarchy'].nunique()} hierarchy classes: " + f"{sorted(df['hierarchy'].unique())}") + return KaggleHierarchyDataset(df) + +class LocalHierarchyDataset(Dataset): + """Local validation dataset 
returning (image, description, color, hierarchy).""" -class KaggleDataset(Dataset): - """Dataset class for KAGL Marqo dataset""" def __init__(self, dataframe, image_size=224): - self.dataframe = dataframe - self.image_size = image_size - - # Transforms for validation (no augmentation) - self.val_transform = transforms.Compose([ + self.dataframe = dataframe.reset_index(drop=True) + self.transform = transforms.Compose([ transforms.Resize((image_size, image_size)), transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) - + def __len__(self): return len(self.dataframe) def __getitem__(self, idx): row = self.dataframe.iloc[idx] - - # Handle image - it should be in row['image_url'] and contain the image data as bytes - image_data = row['image_url'] - - # Check if image_data has 'bytes' key or is already PIL Image - if isinstance(image_data, dict) and 'bytes' in image_data: - image = Image.open(BytesIO(image_data['bytes'])).convert("RGB") - elif hasattr(image_data, 'convert'): # Already a PIL Image - image = image_data.convert("RGB") - else: - # Assume it's raw bytes - image = Image.open(BytesIO(image_data)).convert("RGB") - - # Apply validation transform - image = self.val_transform(image) + try: + img_path = row[column_local_image_path] + if not os.path.isabs(img_path): + img_path = os.path.join(ROOT_DIR, img_path) + image = Image.open(img_path).convert("RGB") + except Exception: + image = Image.new("RGB", (224, 224), color="gray") + image = self.transform(image) + description = str(row["text"]) + color = str(row.get("color", "unknown")) + hierarchy = str(row["hierarchy"]) + return image, description, color, hierarchy - # Get text and labels - description = row['text'] - color = row.get('color', 'unknown') - hierarchy = row['hierarchy'] - return image, description, color, hierarchy +def load_local_validation_with_hierarchy(max_samples=10000, hierarchy_classes=None, raw_df=None): + """Load internal validation dataset with hierarchy labels.""" + if raw_df is not None: + df = raw_df.copy() + print(f"Using cached local DataFrame: {len(df)} samples") + else: + print("Loading local validation dataset...") + df = pd.read_csv(local_dataset_path) + print(f"Dataset loaded: {len(df)} samples") + df = df.dropna(subset=[column_local_image_path, "hierarchy"]) + df["hierarchy"] = df["hierarchy"].astype(str).str.strip() + df = df[df["hierarchy"].str.len() > 0] -def load_kaggle_marqo_dataset(evaluator, max_samples=10000): - """Load and prepare Kaggle KAGL dataset with memory optimization""" - from datasets import load_dataset - print("📊 Loading Kaggle KAGL dataset...") - - # Load the dataset - dataset = load_dataset("Marqo/KAGL") - df = dataset["data"].to_pandas() - print(f"✅ Dataset Kaggle loaded") - print(f" Before filtering: {len(df)} samples") - print(f" Available columns: {list(df.columns)}") - - # Check available categories and create mapping to validation hierarchies - available_categories = sorted(df['category2'].dropna().unique()) - print(f"🎨 Available categories: {available_categories}") - - validation_hierarchies = evaluator.validation_hierarchy_classes or evaluator.hierarchy_classes - print(f"📚 Validation hierarchies: {sorted(validation_hierarchies)}") - - print("\n🔗 Creating mapping from Kaggle categories to validation hierarchies:") - category_mapping = create_kaggle_marqo_to_hierarchy_mapping(available_categories, validation_hierarchies) - - total_categories = 
{str(cat).strip().lower() for cat in df['category2'].dropna()} - unmapped_categories = sorted(cat for cat in total_categories if category_mapping.get(cat) is None) - if unmapped_categories: - print(f"⚠️ Categories without mapping (will be dropped): {unmapped_categories}") - - df['hierarchy'] = df['category2'].apply( - lambda cat: category_mapping.get(str(cat).strip().lower()) if pd.notna(cat) else None - ) - - before_mapping_len = len(df) - df = df[df['hierarchy'].notna()] - print(f" After mapping to validation hierarchies: {len(df)} samples (from {before_mapping_len})") - - if len(df) == 0: - print("❌ No samples left after hierarchy mapping.") - return None - - # Ensure we have text and image data - df = df.dropna(subset=['text', 'image']) - print(f" After removing missing text/image: {len(df)} samples") - - # Show sample of text data to verify quality - print(f"📝 Sample texts:") - for i, (text, hierarchy) in enumerate(zip(df['text'].head(3), df['hierarchy'].head(3))): - print(f" {i+1}. [{hierarchy}] {text[:100]}...") - - df_test = df.copy() - - # Limit to max_samples - if len(df_test) > max_samples: - df_test = df_test.head(max_samples) - - print(f"📊 After sampling: {len(df_test)} samples") - print(f" Samples per hierarchy:") - for hierarchy in sorted(df_test['hierarchy'].unique()): - count = len(df_test[df_test['hierarchy'] == hierarchy]) - print(f" {hierarchy}: {count} samples") - - # Create formatted dataset with proper column names - kaggle_formatted = pd.DataFrame({ - 'image_url': df_test['image'], # This contains image data as bytes - 'text': df_test['text'], - 'hierarchy': df_test['hierarchy'], - 'color': df_test['baseColour'].str.lower().str.replace("grey", "gray") # Use actual colors - }) - - print(f" Final dataset size: {len(kaggle_formatted)} samples") - return KaggleDataset(kaggle_formatted) + # Normalise raw taxonomy strings to clean labels + df["hierarchy"] = df["hierarchy"].apply(normalize_hierarchy_label) + if hierarchy_classes: + hierarchy_classes_lower = [h.lower() for h in hierarchy_classes] + df["hierarchy_lower"] = df["hierarchy"].str.lower() + df = df[df["hierarchy_lower"].isin(hierarchy_classes_lower)] + case_map = {h.lower(): h for h in hierarchy_classes} + df["hierarchy"] = df["hierarchy_lower"].map(case_map) + df = df.drop(columns=["hierarchy_lower"]) + print(f"After filtering: {len(df)} samples, {df['hierarchy'].nunique()} classes") + if len(df) > max_samples: + df = df.sample(n=max_samples, random_state=42) -class ColorHierarchyEvaluator: - """Evaluate color (dims 0-15) and hierarchy (dims 16-79) embeddings on Fashion-MNIST""" + print(f"Using {len(df)} samples, classes: {sorted(df['hierarchy'].unique())}") + return LocalHierarchyDataset(df) + + +# ============================================================================ +# 2. Evaluator +# ============================================================================ - def __init__(self, device='mps', directory='fashion_mnist_color_hierarchy_analysis'): - self.device = torch.device(device) +class ColorHierarchyEvaluator: + """ + Evaluates color and hierarchy NN classification accuracy for GAP-CLIP + and the baseline Fashion-CLIP on Fashion-MNIST, KAGL Marqo, and the + internal validation dataset. 
+ """ + + def __init__(self, device='mps', directory='main_model_analysis', + gap_clip_model=None, gap_clip_processor=None, + baseline_model=None, baseline_processor=None, + hierarchy_classes=None, + kaggle_raw_df=None, local_raw_df=None): + self.device = torch.device(device) if isinstance(device, str) else device self.directory = directory + self.kaggle_raw_df = kaggle_raw_df + self.local_raw_df = local_raw_df self.color_emb_dim = color_emb_dim self.hierarchy_emb_dim = hierarchy_emb_dim + self.main_emb_dim = main_emb_dim + self.hierarchy_end_dim = self.color_emb_dim + self.hierarchy_emb_dim os.makedirs(self.directory, exist_ok=True) - print(f"🚀 Loading main model from {main_model_path}") - if not os.path.exists(main_model_path): - raise FileNotFoundError(f"Main model file {main_model_path} not found") - - # Load hierarchy classes from hierarchy model checkpoint - print("📋 Loading hierarchy classes from hierarchy model...") - if not os.path.exists(hierarchy_model_path): - raise FileNotFoundError(f"Hierarchy model file {hierarchy_model_path} not found") - - hierarchy_checkpoint = torch.load(hierarchy_model_path, map_location=self.device) - self.hierarchy_classes = hierarchy_checkpoint.get('hierarchy_classes', []) - print(f"✅ Found {len(self.hierarchy_classes)} hierarchy classes: {sorted(self.hierarchy_classes)}") + # --- hierarchy classes --- + if hierarchy_classes is not None: + self.hierarchy_classes = hierarchy_classes + print(f"Using provided hierarchy classes: {len(self.hierarchy_classes)} classes") + else: + print("Loading hierarchy classes from hierarchy model...") + if not os.path.exists(hierarchy_model_path): + raise FileNotFoundError(f"Hierarchy model file {hierarchy_model_path} not found") + hierarchy_checkpoint = torch.load(hierarchy_model_path, map_location=self.device) + self.hierarchy_classes = hierarchy_checkpoint.get('hierarchy_classes', []) + print(f"Found {len(self.hierarchy_classes)} hierarchy classes: {sorted(self.hierarchy_classes)}") self.validation_hierarchy_classes = self._load_validation_hierarchy_classes() if self.validation_hierarchy_classes: - print(f"✅ Validation dataset hierarchies ({len(self.validation_hierarchy_classes)} classes): {sorted(self.validation_hierarchy_classes)}") + print(f"Validation dataset hierarchies ({len(self.validation_hierarchy_classes)} classes): " + f"{sorted(self.validation_hierarchy_classes)}") else: - print("⚠️ Unable to load validation hierarchy classes, falling back to hierarchy model classes.") + print("Unable to load validation hierarchy classes, falling back to hierarchy model classes.") self.validation_hierarchy_classes = self.hierarchy_classes - checkpoint = torch.load(main_model_path, map_location=self.device) - self.processor = CLIPProcessor.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K') - self.model = CLIPModel_transformers.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K') - self.model.load_state_dict(checkpoint['model_state_dict']) - self.model.to(self.device) - self.model.eval() - print("✅ Main model loaded successfully") - - # Load baseline Fashion CLIP model - print("📦 Loading baseline Fashion CLIP model...") - patrick_model_name = "patrickjohncyh/fashion-clip" - self.baseline_processor = CLIPProcessor.from_pretrained(patrick_model_name) - self.baseline_model = CLIPModel_transformers.from_pretrained(patrick_model_name).to(self.device) - self.baseline_model.eval() - print("✅ Baseline Fashion CLIP model loaded successfully") + # --- load GAP-CLIP --- + if gap_clip_model is not None and gap_clip_processor 
is not None: + self.model = gap_clip_model + self.processor = gap_clip_processor + print("Using pre-loaded GAP-CLIP model") + else: + self.model, self.processor = load_gap_clip(main_model_path, self.device) + print("GAP-CLIP model loaded successfully") + + # --- baseline Fashion-CLIP --- + if baseline_model is not None and baseline_processor is not None: + self.baseline_model = baseline_model + self.baseline_processor = baseline_processor + print("Using pre-loaded baseline Fashion-CLIP model") + else: + self.baseline_model, self.baseline_processor = load_baseline_fashion_clip(self.device) + print("Baseline Fashion-CLIP model loaded successfully") + # ------------------------------------------------------------------ + # helpers + # ------------------------------------------------------------------ def _load_validation_hierarchy_classes(self): - """Load hierarchy classes present in the validation dataset""" + """Load hierarchy classes from local CSV, normalised to clean labels.""" if not os.path.exists(local_dataset_path): - print(f"⚠️ Validation dataset not found at {local_dataset_path}") + print(f"Validation dataset not found at {local_dataset_path}") return [] - try: df = pd.read_csv(local_dataset_path) except Exception as exc: - print(f"⚠️ Failed to read validation dataset: {exc}") + print(f"Failed to read validation dataset: {exc}") return [] - if 'hierarchy' not in df.columns: - print("⚠️ Validation dataset does not contain 'hierarchy' column.") + print("Validation dataset does not contain 'hierarchy' column.") return [] - - hierarchies = ( - df['hierarchy'] - .dropna() - .astype(str) - .str.strip() - ) - hierarchies = [h for h in hierarchies if h] - return sorted(set(hierarchies)) - - def extract_color_embeddings(self, dataloader, embedding_type='text', max_samples=10000): - """Extract color embeddings from dims 0-15 (16 dimensions)""" - all_embeddings = [] - all_colors = [] - all_hierarchies = [] - - sample_count = 0 - with torch.no_grad(): - for batch in tqdm(dataloader, desc=f"Extracting {embedding_type} color embeddings (dims 0-15)"): - if sample_count >= max_samples: - break - - images, texts, colors, hierarchies = batch - images = images.to(self.device) - images = images.expand(-1, 3, -1, -1) - - text_inputs = self.processor(text=texts, padding=True, return_tensors="pt") - text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()} - - outputs = self.model(**text_inputs, pixel_values=images) - - if embedding_type == 'text': - embeddings = outputs.text_embeds - elif embedding_type == 'image': - embeddings = outputs.image_embeds - else: - embeddings = outputs.text_embeds - - # Extract only color embeddings (dims 0-15, i.e., first 16 dimensions) - # color_embeddings = embeddings[:, :self.color_emb_dim] - - color_embeddings = embeddings - all_embeddings.append(color_embeddings.cpu().numpy()) - all_colors.extend(colors) - all_hierarchies.extend(hierarchies) - - sample_count += len(images) - - del images, text_inputs, outputs, embeddings, color_embeddings - torch.cuda.empty_cache() if torch.cuda.is_available() else None - - return np.vstack(all_embeddings), all_colors, all_hierarchies - - def extract_hierarchy_embeddings(self, dataloader, embedding_type='text', max_samples=10000): - """Extract hierarchy embeddings from dims 16-79 (indices 16:79)""" - all_embeddings = [] - all_colors = [] - all_hierarchies = [] - - sample_count = 0 - with torch.no_grad(): - for batch in tqdm(dataloader, desc=f"Extracting {embedding_type} hierarchy embeddings (dims 16-79)"): - if sample_count 
>= max_samples:
-                    break
-
-                images, texts, colors, hierarchies = batch
-                images = images.to(self.device)
-                images = images.expand(-1, 3, -1, -1)
-
-                text_inputs = self.processor(text=texts, padding=True, return_tensors="pt")
-                text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
-
-                outputs = self.model(**text_inputs, pixel_values=images)
-
-                if embedding_type == 'text':
-                    embeddings = outputs.text_embeds
-                elif embedding_type == 'image':
-                    embeddings = outputs.image_embeds
-                else:
-                    embeddings = outputs.text_embeds
-
-                # Extract hierarchy embeddings (dims 17-79 -> indices 16:79)
-                # hierarchy_embeddings = embeddings[:, 16:79]
-
-                hierarchy_embeddings = embeddings
-                all_embeddings.append(hierarchy_embeddings.cpu().numpy())
-                all_colors.extend(colors)
-                all_hierarchies.extend(hierarchies)
-
-                sample_count += len(images)
-
-                del images, text_inputs, outputs, embeddings, hierarchy_embeddings
-                torch.cuda.empty_cache() if torch.cuda.is_available() else None
-
-        return np.vstack(all_embeddings), all_colors, all_hierarchies
+        raw = df['hierarchy'].dropna().astype(str).str.strip()
+        normalised = sorted(set(
+            normalize_hierarchy_label(h) for h in raw if h
+        ))
+        # Keep only labels that belong to the known set
+        normalised = [h for h in normalised if h in NORMALIZED_HIERARCHY_CLASSES]
+        print(f"Normalised validation hierarchy classes: {normalised}")
+        return normalised
+
+    def prepare_shared_fashion_mnist(self, max_samples=10000, batch_size=8):
+        """Build one shared Fashion-MNIST dataset/dataloader.
+
+        Prefers the normalised validation hierarchy classes, falling back to
+        NORMALIZED_HIERARCHY_CLASSES, so that Fashion-MNIST labels are mapped
+        to clean short names (top, pant, shoes, sweater, coat, …).
+        """
+        target_classes = self.validation_hierarchy_classes or NORMALIZED_HIERARCHY_CLASSES
+        fashion_dataset = load_fashion_mnist_dataset(max_samples, hierarchy_classes=target_classes)
+
+        # Normalise whatever label_mapping produced (e.g. "Coat" -> "coat")
+        if fashion_dataset.label_mapping:
+            fashion_dataset.label_mapping = {
+                k: normalize_hierarchy_label(v) if v else v
+                for k, v in fashion_dataset.label_mapping.items()
+            }
+
+        dataloader = DataLoader(fashion_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
+
+        hierarchy_counts = defaultdict(int)
+        if len(fashion_dataset.dataframe) > 0 and fashion_dataset.label_mapping:
+            for _, row in fashion_dataset.dataframe.iterrows():
+                lid = int(row['label'])
+                hierarchy_counts[fashion_dataset.label_mapping.get(lid, 'unknown')] += 1
+
+        return fashion_dataset, dataloader, dict(hierarchy_counts)
+
+    @staticmethod
+    def _count_labels(labels):
+        counts = defaultdict(int)
+        for label in labels:
+            counts[label] += 1
+        return dict(counts)
+
+    def _validate_label_distribution(self, labels, expected_counts, context):
+        observed = self._count_labels(labels)
+        if observed != expected_counts:
+            raise ValueError(
+                f"Label distribution mismatch in {context}. 
" + f"Expected {expected_counts}, observed {observed}" + ) + # ------------------------------------------------------------------ + # embedding extraction + # ------------------------------------------------------------------ def extract_full_embeddings(self, dataloader, embedding_type='text', max_samples=10000): - """Extract full 512-dimensional embeddings (all dimensions)""" - all_embeddings = [] - all_colors = [] - all_hierarchies = [] - - sample_count = 0 - with torch.no_grad(): - for batch in tqdm(dataloader, desc=f"Extracting {embedding_type} FULL embeddings (all dims)"): - if sample_count >= max_samples: - break - - images, texts, colors, hierarchies = batch - images = images.to(self.device) - images = images.expand(-1, 3, -1, -1) - - text_inputs = self.processor(text=texts, padding=True, return_tensors="pt") - text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()} - - outputs = self.model(**text_inputs, pixel_values=images) - - if embedding_type == 'text': - embeddings = outputs.text_embeds - elif embedding_type == 'image': - embeddings = outputs.image_embeds - else: - embeddings = outputs.text_embeds - - # Use all 512 dimensions - all_embeddings.append(embeddings.cpu().numpy()) - all_colors.extend(colors) - all_hierarchies.extend(hierarchies) - - sample_count += len(images) - - del images, text_inputs, outputs, embeddings - torch.cuda.empty_cache() if torch.cuda.is_available() else None - - return np.vstack(all_embeddings), all_colors, all_hierarchies + """Full 512D embeddings from GAP-CLIP.""" + return extract_clip_embeddings( + self.model, self.processor, dataloader, self.device, + embedding_type=embedding_type, max_samples=max_samples, + desc=f"GAP-CLIP {embedding_type} embeddings", + ) def extract_baseline_embeddings_batch(self, dataloader, embedding_type='text', max_samples=10000): - """ - Extract embeddings from baseline Fashion CLIP model. - - This method properly processes images and text through the Fashion-CLIP processor - and applies L2 normalization to embeddings, matching the evaluation in evaluate_color_embeddings.py - """ - all_embeddings = [] - all_colors = [] - all_hierarchies = [] - - sample_count = 0 - - with torch.no_grad(): - for batch in tqdm(dataloader, desc=f"Extracting baseline {embedding_type} embeddings"): - if sample_count >= max_samples: - break - - images, texts, colors, hierarchies = batch - - # Extract embeddings based on type - if embedding_type == 'text': - # Process text through Fashion-CLIP processor - text_inputs = self.baseline_processor(text=texts, return_tensors="pt", padding=True, truncation=True, max_length=77) - text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()} - - # Get text features using the dedicated method - text_features = self.baseline_model.get_text_features(**text_inputs) - - # Apply L2 normalization (critical for CLIP!) 
- text_features = text_features / text_features.norm(dim=-1, keepdim=True) - embeddings = text_features - - elif embedding_type == 'image': - # Convert tensor images back to PIL Images for proper processing - pil_images = [] - for i in range(images.shape[0]): - img_tensor = images[i] - - # Denormalize if the images were normalized (undo ImageNet normalization) - # Check if images are normalized (values outside [0,1]) - if img_tensor.min() < 0 or img_tensor.max() > 1: - # Undo ImageNet normalization - mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1) - std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1) - img_tensor = img_tensor * std + mean - img_tensor = torch.clamp(img_tensor, 0, 1) - - # Convert to PIL Image - img_pil = transforms.ToPILImage()(img_tensor) - pil_images.append(img_pil) - - # Process images through Fashion-CLIP processor (will apply its own normalization) - image_inputs = self.baseline_processor(images=pil_images, return_tensors="pt") - image_inputs = {k: v.to(self.device) for k, v in image_inputs.items()} - - # Get image features using the dedicated method - image_features = self.baseline_model.get_image_features(**image_inputs) - - # Apply L2 normalization (critical for CLIP!) - image_features = image_features / image_features.norm(dim=-1, keepdim=True) - embeddings = image_features - - else: - # Default to text - text_inputs = self.baseline_processor(text=texts, return_tensors="pt", padding=True, truncation=True, max_length=77) - text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()} - text_features = self.baseline_model.get_text_features(**text_inputs) - text_features = text_features / text_features.norm(dim=-1, keepdim=True) - embeddings = text_features - - all_embeddings.append(embeddings.cpu().numpy()) - all_colors.extend(colors) - all_hierarchies.extend(hierarchies) - - sample_count += len(images) - - # Clear GPU memory - del embeddings - if embedding_type == 'image': - del pil_images, image_inputs - else: - del text_inputs - torch.cuda.empty_cache() if torch.cuda.is_available() else None - - return np.vstack(all_embeddings), all_colors, all_hierarchies - - def compute_similarity_metrics(self, embeddings, labels): - """Compute intra-class and inter-class similarities""" - max_samples = min(5000, len(embeddings)) - if len(embeddings) > max_samples: - indices = np.random.choice(len(embeddings), max_samples, replace=False) - embeddings = embeddings[indices] - labels = [labels[i] for i in indices] + """L2-normalised embeddings from baseline Fashion-CLIP.""" + return extract_clip_embeddings( + self.baseline_model, self.baseline_processor, dataloader, self.device, + embedding_type=embedding_type, max_samples=max_samples, + desc=f"Baseline {embedding_type} embeddings", + ) + # ------------------------------------------------------------------ + # prediction methods + # ------------------------------------------------------------------ + def predict_labels_nearest_neighbor(self, embeddings, labels): + """Predict labels using 1-NN on the same embedding set.""" similarities = cosine_similarity(embeddings) - - label_groups = defaultdict(list) - for i, label in enumerate(labels): - label_groups[label].append(i) - - intra_class_similarities = [] - for label, indices in label_groups.items(): - if len(indices) > 1: - for i in range(len(indices)): - for j in range(i + 1, len(indices)): - sim = similarities[indices[i], indices[j]] - intra_class_similarities.append(sim) - - inter_class_similarities = [] - labels_list = list(label_groups.keys()) - for i in 
range(len(labels_list)): - for j in range(i + 1, len(labels_list)): - label1_indices = label_groups[labels_list[i]] - label2_indices = label_groups[labels_list[j]] - for idx1 in label1_indices: - for idx2 in label2_indices: - sim = similarities[idx1, idx2] - inter_class_similarities.append(sim) - - nn_accuracy = self.compute_embedding_accuracy(embeddings, labels, similarities) - centroid_accuracy = self.compute_centroid_accuracy(embeddings, labels) - - return { - 'intra_class_similarities': intra_class_similarities, - 'inter_class_similarities': inter_class_similarities, - 'intra_class_mean': float(np.mean(intra_class_similarities)) if intra_class_similarities else 0.0, - 'inter_class_mean': float(np.mean(inter_class_similarities)) if inter_class_similarities else 0.0, - 'separation_score': float(np.mean(intra_class_similarities) - np.mean(inter_class_similarities)) if intra_class_similarities and inter_class_similarities else 0.0, - 'accuracy': nn_accuracy, - 'centroid_accuracy': centroid_accuracy, - } - - def compute_embedding_accuracy(self, embeddings, labels, similarities): - """Compute classification accuracy using nearest neighbor""" - correct_predictions = 0 - total_predictions = len(labels) + preds = [] for i in range(len(embeddings)): - true_label = labels[i] - similarities_row = similarities[i].copy() - similarities_row[i] = -1 - nearest_neighbor_idx = int(np.argmax(similarities_row)) - predicted_label = labels[nearest_neighbor_idx] - if predicted_label == true_label: - correct_predictions += 1 - return correct_predictions / total_predictions if total_predictions > 0 else 0.0 - - def compute_centroid_accuracy(self, embeddings, labels): - """Compute classification accuracy using centroids""" - unique_labels = list(set(labels)) + sims = similarities[i].copy() + sims[i] = -1.0 + nearest_neighbor_idx = int(np.argmax(sims)) + preds.append(labels[nearest_neighbor_idx]) + return preds + + def _compute_img_centroids(self, embeddings, labels): + emb_norm = normalize(embeddings, norm='l2') centroids = {} - for label in unique_labels: - label_indices = [i for i, l in enumerate(labels) if l == label] - centroids[label] = np.mean(embeddings[label_indices], axis=0) - - correct_predictions = 0 - total_predictions = len(labels) - for i, embedding in enumerate(embeddings): - true_label = labels[i] - best_similarity = -1 - predicted_label = None - for label, centroid in centroids.items(): - similarity = cosine_similarity([embedding], [centroid])[0][0] - if similarity > best_similarity: - best_similarity = similarity - predicted_label = label - if predicted_label == true_label: - correct_predictions += 1 - return correct_predictions / total_predictions if total_predictions > 0 else 0.0 - - def predict_labels_from_embeddings(self, embeddings, labels): - """Predict labels from embeddings using centroid-based classification""" - unique_labels = list(set(labels)) - centroids = {} - for label in unique_labels: - label_indices = [i for i, l in enumerate(labels) if l == label] - centroids[label] = np.mean(embeddings[label_indices], axis=0) - - predictions = [] - for i, embedding in enumerate(embeddings): - best_similarity = -1 - predicted_label = None - for label, centroid in centroids.items(): - similarity = cosine_similarity([embedding], [centroid])[0][0] - if similarity > best_similarity: - best_similarity = similarity - predicted_label = label - predictions.append(predicted_label) - return predictions - - def predict_labels_ensemble(self, specialized_embeddings, full_embeddings, labels, - 
specialized_weight=0.5): - """ - Ensemble prediction combining specialized (16/64 dims) and full (512 dims) embeddings. - - Args: - specialized_embeddings: Embeddings from specialized dimensions (e.g., dims 0-15 for color) - full_embeddings: Full 512-dimensional embeddings - labels: True labels for computing centroids - specialized_weight: Weight for specialized embeddings (0.0 = only full, 1.0 = only specialized) - - Returns: - List of predicted labels using weighted ensemble - """ - unique_labels = list(set(labels)) - - # Compute centroids for both specialized and full embeddings - specialized_centroids = {} - full_centroids = {} - - for label in unique_labels: - label_indices = [i for i, l in enumerate(labels) if l == label] - specialized_centroids[label] = np.mean(specialized_embeddings[label_indices], axis=0) - full_centroids[label] = np.mean(full_embeddings[label_indices], axis=0) - - predictions = [] - for i in range(len(specialized_embeddings)): - best_combined_score = -np.inf - predicted_label = None - - for label in unique_labels: - # Compute similarity scores for both specialized and full - spec_sim = cosine_similarity([specialized_embeddings[i]], [specialized_centroids[label]])[0][0] - full_sim = cosine_similarity([full_embeddings[i]], [full_centroids[label]])[0][0] - - # Weighted combination - combined_score = specialized_weight * spec_sim + (1 - specialized_weight) * full_sim - - if combined_score > best_combined_score: - best_combined_score = combined_score - predicted_label = label - - predictions.append(predicted_label) - - return predictions - - def create_confusion_matrix(self, true_labels, predicted_labels, title="Confusion Matrix", label_type="Label"): - """Create and plot confusion matrix""" - unique_labels = sorted(list(set(true_labels + predicted_labels))) - cm = confusion_matrix(true_labels, predicted_labels, labels=unique_labels) - accuracy = accuracy_score(true_labels, predicted_labels) - plt.figure(figsize=(12, 10)) - sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=unique_labels, yticklabels=unique_labels) - plt.title(f'{title}\nAccuracy: {accuracy:.3f} ({accuracy*100:.1f}%)') - plt.ylabel(f'True {label_type}') - plt.xlabel(f'Predicted {label_type}') - plt.xticks(rotation=45) - plt.yticks(rotation=0) - plt.tight_layout() - return plt.gcf(), accuracy, cm - - - def evaluate_classification_performance(self, embeddings, labels, embedding_type="Embeddings", label_type="Label", - full_embeddings=None, ensemble_weight=0.5): - """ - Evaluate classification performance and create confusion matrix. 
- - Args: - embeddings: Specialized embeddings (e.g., dims 0-15 for color or dims 16-79 for hierarchy) - labels: True labels - embedding_type: Type of embeddings for display - label_type: Type of labels (Color/Hierarchy) - full_embeddings: Optional full 512-dim embeddings for ensemble (if None, uses only specialized) - ensemble_weight: Weight for specialized embeddings in ensemble (0.0 = only full, 1.0 = only specialized) - """ - if full_embeddings is not None: - # Use ensemble prediction - predictions = self.predict_labels_ensemble(embeddings, full_embeddings, labels, ensemble_weight) + for label in sorted(set(labels)): + idx = [i for i, l in enumerate(labels) if l == label] + centroids[label] = normalize([emb_norm[idx].mean(axis=0)], norm='l2')[0] + return centroids + + def predict_labels_image_ensemble(self, img_embeddings, labels, + text_protos, cls_names, alpha=0.5): + """Combine image centroids (512D) with text prototypes (512D).""" + img_norm = normalize(img_embeddings, norm='l2') + img_centroids = self._compute_img_centroids(img_norm, labels) + centroid_mat = np.stack([img_centroids[c] for c in cls_names], axis=0) + + preds = [] + for i in range(len(img_norm)): + v = img_norm[i:i + 1] + sim_img = cosine_similarity(v, centroid_mat)[0] + sim_txt = cosine_similarity(v, text_protos)[0] + scores = alpha * sim_img + (1 - alpha) * sim_txt + preds.append(cls_names[int(np.argmax(scores))]) + return preds + + # ------------------------------------------------------------------ + # classification evaluation + # ------------------------------------------------------------------ + def evaluate_classification_performance(self, embeddings, labels, + embedding_type="Embeddings", + label_type="Hierarchy", + method="nn"): + if method == "nn": + preds = self.predict_labels_nearest_neighbor(embeddings, labels) + elif method == "centroid": + preds = predict_labels_from_embeddings(embeddings, labels) else: - # Use only specialized embeddings - predictions = self.predict_labels_from_embeddings(embeddings, labels) - - accuracy = accuracy_score(labels, predictions) - fig, acc, cm = self.create_confusion_matrix( - labels, predictions, - f"{label_type} Classification", - label_type + raise ValueError(f"Unknown classification method: {method}") + acc = accuracy_score(labels, preds) + unique_labels = sorted(set(labels)) + fig, _, cm = create_confusion_matrix( + labels, preds, + f"{embedding_type} - {label_type} Classification ({method.upper()})", + label_type, ) - unique_labels = sorted(list(set(labels))) - report = classification_report(labels, predictions, labels=unique_labels, target_names=unique_labels, output_dict=True) + report = classification_report(labels, preds, labels=unique_labels, + target_names=unique_labels, output_dict=True) return { - 'accuracy': accuracy, - 'predictions': predictions, + 'accuracy': acc, + 'predictions': preds, 'confusion_matrix': cm, + 'labels': unique_labels, 'classification_report': report, 'figure': fig, } - def evaluate_fashion_mnist(self, max_samples): - """Evaluate both color and hierarchy embeddings on Fashion-MNIST""" - print(f"\n{'='*60}") - print("Evaluating Fashion-MNIST") - print(" Color embeddings: dims 0-15") - print(" Hierarchy embeddings: dims 16-79") - print(f"Max samples: {max_samples}") - print(f"{'='*60}") - - target_hierarchy_classes = self.validation_hierarchy_classes or self.hierarchy_classes - fashion_dataset = load_fashion_mnist_dataset(max_samples, hierarchy_classes=target_hierarchy_classes) - dataloader = DataLoader(fashion_dataset, batch_size=8, 
shuffle=False, num_workers=0) - - # Check hierarchy distribution after mapping - if len(fashion_dataset.dataframe) > 0: - print(f"\n📊 Hierarchy distribution in dataset:") - if fashion_dataset.label_mapping: - hierarchy_counts = {} - for _, row in fashion_dataset.dataframe.iterrows(): - label_id = int(row['label']) - hierarchy = fashion_dataset.label_mapping.get(label_id, 'unknown') - hierarchy_counts[hierarchy] = hierarchy_counts.get(hierarchy, 0) + 1 - - for hierarchy, count in sorted(hierarchy_counts.items()): - print(f" {hierarchy}: {count} samples") + def save_confusion_matrix_table(self, cm, labels, output_csv_path): + cm_df = pd.DataFrame(cm, index=labels, columns=labels) + cm_df["row_total"] = cm_df.sum(axis=1) + cm_df.loc["column_total"] = list(cm_df[labels].sum(axis=0)) + [cm_df["row_total"].sum()] + cm_df.to_csv(output_csv_path) + + # ================================================================== + # 3. GAP-CLIP evaluation on Fashion-MNIST (hierarchy only — no color) + # ================================================================== + def evaluate_gap_clip_fashion_mnist(self, max_samples=10000, dataloader=None, expected_counts=None): + print(f"\n{'=' * 60}") + print("Evaluating GAP-CLIP on Fashion-MNIST (Hierarchy only)") + print(f" Hierarchy embeddings (dims {self.color_emb_dim}-{self.hierarchy_end_dim - 1})") + print(f" Max samples: {max_samples}") + print(f"{'=' * 60}") + + if dataloader is None: + fashion_dataset, dataloader, dataset_counts = self.prepare_shared_fashion_mnist(max_samples=max_samples) + expected_counts = expected_counts or dataset_counts + else: + if expected_counts is None: + raise ValueError("expected_counts must be provided when using a custom dataloader.") results = {} - # ========== EXTRACT FULL EMBEDDINGS FOR ENSEMBLE ========== - print("\n📦 Extracting full 512-dimensional embeddings for ensemble...") - text_full_embeddings, text_colors_full, text_hierarchies_full = self.extract_full_embeddings(dataloader, 'text', max_samples) - image_full_embeddings, image_colors_full, image_hierarchies_full = self.extract_full_embeddings(dataloader, 'image', max_samples) - print(f" Text full embeddings shape: {text_full_embeddings.shape}") - print(f" Image full embeddings shape: {image_full_embeddings.shape}") - - # ========== HIERARCHY EVALUATION (DIMS 16-79) WITH ENSEMBLE ========== - print("\n📋 HIERARCHY EVALUATION (dims 16-79) - Using Ensemble") - print("=" * 50) - - # Extract specialized hierarchy embeddings (dims 16-79) - print("\n📝 Extracting specialized text hierarchy embeddings (dims 16-79)...") - text_hierarchy_embeddings_spec = text_full_embeddings[:, self.color_emb_dim:self.color_emb_dim+self.hierarchy_emb_dim] - print(f" Specialized text hierarchy embeddings shape: {text_hierarchy_embeddings_spec.shape}") - text_hierarchy_metrics = self.compute_similarity_metrics(text_hierarchy_embeddings_spec, text_hierarchies_full) - # Use ensemble: combine specialized (64D) + full (512D) - text_hierarchy_class = self.evaluate_classification_performance( - text_hierarchy_embeddings_spec, text_hierarchies_full, - "Text Hierarchy Embeddings (Ensemble)", "Hierarchy", - full_embeddings=text_full_embeddings, ensemble_weight=1 + # --- full 512D embeddings (text & image) --- + print("\nExtracting full 512-dimensional GAP-CLIP embeddings...") + text_full, _, text_hier = self.extract_full_embeddings(dataloader, 'text', max_samples) + img_full, _, img_hier = self.extract_full_embeddings(dataloader, 'image', max_samples) + self._validate_label_distribution(text_hier, 
expected_counts, "GAP-CLIP text") + self._validate_label_distribution(img_hier, expected_counts, "GAP-CLIP image") + print(f" Text shape: {text_full.shape} | Image shape: {img_full.shape}") + + # ===== HIERARCHY (dims 16-79) ===== + print(f"\n--- GAP-CLIP TEXT HIERARCHY (dims {self.color_emb_dim}-{self.hierarchy_end_dim - 1}) ---") + text_hier_spec = text_full[:, self.color_emb_dim:self.hierarchy_end_dim] + print(f" Specialized text hierarchy shape: {text_hier_spec.shape}") + + text_hier_metrics = compute_similarity_metrics(text_hier_spec, text_hier) + text_hier_class = self.evaluate_classification_performance( + text_hier_spec, text_hier, "GAP-CLIP Text Hierarchy (64D)", "Hierarchy", method="nn", ) - text_hierarchy_metrics.update(text_hierarchy_class) - results['text_hierarchy'] = text_hierarchy_metrics - - # Image hierarchy embeddings with ensemble - print("\n🖼️ Extracting specialized image hierarchy embeddings (dims 16-79)...") - image_hierarchy_embeddings_spec = image_full_embeddings[:, self.color_emb_dim:self.color_emb_dim+self.hierarchy_emb_dim] - print(f" Specialized image hierarchy embeddings shape: {image_hierarchy_embeddings_spec.shape}") - image_hierarchy_metrics = self.compute_similarity_metrics(image_hierarchy_embeddings_spec, image_hierarchies_full) - image_hierarchy_class = self.evaluate_classification_performance( - image_hierarchy_embeddings_spec, image_hierarchies_full, - "Image Hierarchy Embeddings (Ensemble)", "Hierarchy", - full_embeddings=image_full_embeddings, ensemble_weight=1 + text_hier_metrics.update(text_hier_class) + results['text_hierarchy'] = text_hier_metrics + + # IMAGE: 64D vs 512D + print(f"\n--- GAP-CLIP IMAGE HIERARCHY (64D vs 512D) ---") + img_hier_spec = img_full[:, self.color_emb_dim:self.hierarchy_end_dim] + print(f" Specialized image hierarchy shape: {img_hier_spec.shape}") + + print(" Testing specialized 64D...") + spec_metrics = compute_similarity_metrics(img_hier_spec, img_hier) + spec_class = self.evaluate_classification_performance( + img_hier_spec, img_hier, "GAP-CLIP Image Hierarchy (64D)", "Hierarchy", method="nn", ) - image_hierarchy_metrics.update(image_hierarchy_class) - results['image_hierarchy'] = image_hierarchy_metrics - # Cleanup - del text_full_embeddings, image_full_embeddings - del text_hierarchy_embeddings_spec, image_hierarchy_embeddings_spec - torch.cuda.empty_cache() if torch.cuda.is_available() else None + print(" Testing full 512D...") + full_metrics = compute_similarity_metrics(img_full, img_hier) + full_class = self.evaluate_classification_performance( + img_full, img_hier, "GAP-CLIP Image Hierarchy (512D full)", "Hierarchy", method="nn", + ) - # ========== SAVE VISUALIZATIONS ========== - os.makedirs(self.directory, exist_ok=True) + if full_class['accuracy'] >= spec_class['accuracy']: + print(f" 512D wins: {full_class['accuracy'] * 100:.1f}% vs {spec_class['accuracy'] * 100:.1f}%") + img_hier_metrics, img_hier_class = full_metrics, full_class + else: + print(f" 64D wins: {spec_class['accuracy'] * 100:.1f}% vs {full_class['accuracy'] * 100:.1f}%") + img_hier_metrics, img_hier_class = spec_metrics, spec_class + + # ensemble image + text prototypes + print("\n Testing GAP-CLIP image + text ensemble (prototypes per class)...") + cls_names = sorted(set(img_hier)) + prompts = [f"a photo of a {c}" for c in cls_names] + text_inputs = self.processor(text=prompts, return_tensors="pt", padding=True, truncation=True) + text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()} + with torch.no_grad(): + txt_feats = 
self.model.get_text_features(**text_inputs) + txt_feats = txt_feats / txt_feats.norm(dim=-1, keepdim=True) + text_protos = txt_feats.cpu().numpy() + + ensemble_preds = self.predict_labels_image_ensemble( + img_full, img_hier, text_protos, cls_names, alpha=0.7, + ) + ensemble_acc = accuracy_score(img_hier, ensemble_preds) + print(f" Ensemble accuracy (alpha=0.7): {ensemble_acc * 100:.2f}%") + + img_hier_metrics.update(img_hier_class) + img_hier_metrics['ensemble_accuracy'] = ensemble_acc + results['image_hierarchy'] = img_hier_metrics + + # --- save confusion matrix figures --- for key in ['text_hierarchy', 'image_hierarchy']: - results[key]['figure'].savefig( - f"{self.directory}/fashion_{key.replace('_', '_')}_confusion_matrix.png", - dpi=300, - bbox_inches='tight', + fig = results[key]['figure'] + fig.savefig( + os.path.join(self.directory, f"gap_clip_{key}_confusion_matrix.png"), + dpi=300, bbox_inches='tight', ) - plt.close(results[key]['figure']) + self.save_confusion_matrix_table( + results[key]['confusion_matrix'], + results[key]['labels'], + os.path.join(self.directory, f"gap_clip_{key}_confusion_matrix.csv"), + ) + plt.close(fig) + + del text_full, img_full, text_hier_spec, img_hier_spec + if torch.cuda.is_available(): + torch.cuda.empty_cache() return results - def evaluate_kaggle_marqo(self, max_samples): - """Evaluate both color and hierarchy embeddings on KAGL Marqo dataset""" - print(f"\n{'='*60}") - print("Evaluating KAGL Marqo Dataset") - print(" Color embeddings: dims 0-15") - print(" Hierarchy embeddings: dims 16-79") - print(f"Max samples: {max_samples}") - print(f"{'='*60}") - - kaggle_dataset = load_kaggle_marqo_dataset(self, max_samples) - if kaggle_dataset is None: - print("❌ Failed to load KAGL dataset") - return None - - dataloader = DataLoader(kaggle_dataset, batch_size=8, shuffle=False, num_workers=0) - - # Check hierarchy distribution - if len(kaggle_dataset.dataframe) > 0: - print(f"\n📊 Hierarchy distribution in dataset:") - hierarchy_counts = {} - for _, row in kaggle_dataset.dataframe.iterrows(): - hierarchy = row['hierarchy'] - hierarchy_counts[hierarchy] = hierarchy_counts.get(hierarchy, 0) + 1 - - for hierarchy, count in sorted(hierarchy_counts.items()): - print(f" {hierarchy}: {count} samples") + # ================================================================== + # 4. 
Baseline Fashion-CLIP evaluation on Fashion-MNIST (hierarchy only) + # ================================================================== + def evaluate_baseline_fashion_mnist(self, max_samples=10000, dataloader=None, expected_counts=None): + print(f"\n{'=' * 60}") + print("Evaluating Baseline Fashion-CLIP on Fashion-MNIST (Hierarchy only)") + print(f" Max samples: {max_samples}") + print(f"{'=' * 60}") + + if dataloader is None: + _, dataloader, dataset_counts = self.prepare_shared_fashion_mnist(max_samples=max_samples) + expected_counts = expected_counts or dataset_counts + elif expected_counts is None: + raise ValueError("expected_counts must be provided when using a custom dataloader.") results = {} - # ========== EXTRACT FULL EMBEDDINGS FOR ENSEMBLE ========== - print("\n📦 Extracting full 512-dimensional embeddings for ensemble...") - text_full_embeddings, text_colors_full, text_hierarchies_full = self.extract_full_embeddings(dataloader, 'text', max_samples) - image_full_embeddings, image_colors_full, image_hierarchies_full = self.extract_full_embeddings(dataloader, 'image', max_samples) - print(f" Text full embeddings shape: {text_full_embeddings.shape}") - print(f" Image full embeddings shape: {image_full_embeddings.shape}") - - # ========== COLOR EVALUATION (DIMS 0-15) WITH ENSEMBLE ========== - print("\n🎨 COLOR EVALUATION (dims 0-15) - Using Ensemble") - print("=" * 50) - - # Extract specialized color embeddings (dims 0-15) - print("\n📝 Extracting specialized text color embeddings (dims 0-15)...") - text_color_embeddings_spec = text_full_embeddings[:, :self.color_emb_dim] # First 16 dims - print(f" Specialized text color embeddings shape: {text_color_embeddings_spec.shape}") - text_color_metrics = self.compute_similarity_metrics(text_color_embeddings_spec, text_colors_full) - # Use ensemble: combine specialized (16D) + full (512D) - text_color_class = self.evaluate_classification_performance( - text_color_embeddings_spec, text_colors_full, - "Text Color Embeddings (Ensemble)", "Color", - full_embeddings=text_full_embeddings, ensemble_weight=1 - ) - text_color_metrics.update(text_color_class) - results['text_color'] = text_color_metrics + # --- text --- + print("\nExtracting baseline text embeddings...") + text_emb, _, text_hier = self.extract_baseline_embeddings_batch(dataloader, 'text', max_samples) + self._validate_label_distribution(text_hier, expected_counts, "baseline text") + print(f" Baseline text shape: {text_emb.shape}") - # Image color embeddings with ensemble - print("\n🖼️ Extracting specialized image color embeddings (dims 0-15)...") - image_color_embeddings_spec = image_full_embeddings[:, :self.color_emb_dim] # First 16 dims - print(f" Specialized image color embeddings shape: {image_color_embeddings_spec.shape}") - image_color_metrics = self.compute_similarity_metrics(image_color_embeddings_spec, image_colors_full) - image_color_class = self.evaluate_classification_performance( - image_color_embeddings_spec, image_colors_full, - "Image Color Embeddings (Ensemble)", "Color", - full_embeddings=image_full_embeddings, ensemble_weight=1 # 40% specialized, 60% full - ) - image_color_metrics.update(image_color_class) - results['image_color'] = image_color_metrics - - # ========== HIERARCHY EVALUATION (DIMS 16-79) WITH ENSEMBLE ========== - print("\n📋 HIERARCHY EVALUATION (dims 16-79) - Using Ensemble") - print("=" * 50) - - # Extract specialized hierarchy embeddings (dims 16-79) - print("\n📝 Extracting specialized text hierarchy embeddings (dims 16-79)...") - 
text_hierarchy_embeddings_spec = text_full_embeddings[:, self.color_emb_dim:self.color_emb_dim+self.hierarchy_emb_dim] # dims 16-79 - print(f" Specialized text hierarchy embeddings shape: {text_hierarchy_embeddings_spec.shape}") - text_hierarchy_metrics = self.compute_similarity_metrics(text_hierarchy_embeddings_spec, text_hierarchies_full) - # Use ensemble: combine specialized (64D) + full (512D) - text_hierarchy_class = self.evaluate_classification_performance( - text_hierarchy_embeddings_spec, text_hierarchies_full, - "Text Hierarchy Embeddings (Ensemble)", "Hierarchy", - full_embeddings=text_full_embeddings, ensemble_weight=0.4 + text_metrics = compute_similarity_metrics(text_emb, text_hier) + text_class = self.evaluate_classification_performance( + text_emb, text_hier, "Baseline Text - Hierarchy", "Hierarchy", method="nn", ) - text_hierarchy_metrics.update(text_hierarchy_class) - results['text_hierarchy'] = text_hierarchy_metrics - - # Image hierarchy embeddings with ensemble - print("\n🖼️ Extracting specialized image hierarchy embeddings (dims 16-79)...") - image_hierarchy_embeddings_spec = image_full_embeddings[:, self.color_emb_dim:self.color_emb_dim+self.hierarchy_emb_dim] # dims 16-79 - print(f" Specialized image hierarchy embeddings shape: {image_hierarchy_embeddings_spec.shape}") - image_hierarchy_metrics = self.compute_similarity_metrics(image_hierarchy_embeddings_spec, image_hierarchies_full) - image_hierarchy_class = self.evaluate_classification_performance( - image_hierarchy_embeddings_spec, image_hierarchies_full, - "Image Hierarchy Embeddings (Ensemble)", "Hierarchy", - full_embeddings=image_full_embeddings, ensemble_weight=0.4 + text_metrics.update(text_class) + results['text'] = {'hierarchy': text_metrics} + + del text_emb + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # --- image --- + print("\nExtracting baseline image embeddings...") + img_emb, _, img_hier = self.extract_baseline_embeddings_batch(dataloader, 'image', max_samples) + self._validate_label_distribution(img_hier, expected_counts, "baseline image") + print(f" Baseline image shape: {img_emb.shape}") + + img_metrics = compute_similarity_metrics(img_emb, img_hier) + img_class = self.evaluate_classification_performance( + img_emb, img_hier, "Baseline Image - Hierarchy", "Hierarchy", method="nn", ) - image_hierarchy_metrics.update(image_hierarchy_class) - results['image_hierarchy'] = image_hierarchy_metrics + img_metrics.update(img_class) + results['image'] = {'hierarchy': img_metrics} - # Cleanup - del text_full_embeddings, image_full_embeddings - del text_color_embeddings_spec, image_color_embeddings_spec - del text_hierarchy_embeddings_spec, image_hierarchy_embeddings_spec - torch.cuda.empty_cache() if torch.cuda.is_available() else None + del img_emb + if torch.cuda.is_available(): + torch.cuda.empty_cache() - # ========== SAVE VISUALIZATIONS ========== - os.makedirs(self.directory, exist_ok=True) - for key in ['text_color', 'image_color', 'text_hierarchy', 'image_hierarchy']: - results[key]['figure'].savefig( - f"{self.directory}/kaggle_{key.replace('_', '_')}_confusion_matrix.png", - dpi=300, - bbox_inches='tight', + for key in ['text', 'image']: + fig = results[key]['hierarchy']['figure'] + fig.savefig( + os.path.join(self.directory, f"baseline_{key}_hierarchy_confusion_matrix.png"), + dpi=300, bbox_inches='tight', ) - plt.close(results[key]['figure']) + self.save_confusion_matrix_table( + results[key]['hierarchy']['confusion_matrix'], + results[key]['hierarchy']['labels'], + 
os.path.join(self.directory, f"baseline_{key}_hierarchy_confusion_matrix.csv"), + ) + plt.close(fig) return results - def evaluate_local_validation(self, max_samples): - """Evaluate both color and hierarchy embeddings on local validation dataset (NO ENSEMBLE - only specialized embeddings)""" - print(f"\n{'='*60}") - print("Evaluating Local Validation Dataset") - print(" Color embeddings: dims 0-15 (specialized only, no ensemble)") - print(" Hierarchy embeddings: dims 16-79 (specialized only, no ensemble)") - print(f"Max samples: {max_samples}") - print(f"{'='*60}") - - local_dataset = load_local_validation_dataset(max_samples) - if local_dataset is None: - print("❌ Failed to load local validation dataset") - return None - - # Filter to only include hierarchies that exist in our model - if len(local_dataset.dataframe) > 0: - valid_df = local_dataset.dataframe[local_dataset.dataframe['hierarchy'].isin(self.hierarchy_classes)] - if len(valid_df) == 0: - print("❌ No samples left after hierarchy filtering.") - return None - if len(valid_df) < len(local_dataset.dataframe): - print(f"📊 Filtered to model hierarchies: {len(valid_df)} samples (from {len(local_dataset.dataframe)})") - local_dataset = LocalDataset(valid_df) - - dataloader = DataLoader(local_dataset, batch_size=8, shuffle=False, num_workers=0) - - # Check hierarchy distribution - if len(local_dataset.dataframe) > 0: - print(f"\n📊 Hierarchy distribution in dataset:") - hierarchy_counts = {} - for _, row in local_dataset.dataframe.iterrows(): - hierarchy = row['hierarchy'] - hierarchy_counts[hierarchy] = hierarchy_counts.get(hierarchy, 0) + 1 - - for hierarchy, count in sorted(hierarchy_counts.items()): - print(f" {hierarchy}: {count} samples") + # ================================================================== + # 5. 
Generic dataset evaluation (KAGL Marqo / Internal) + # ================================================================== + def evaluate_gap_clip_generic(self, dataloader, dataset_name, max_samples=10000): + """Evaluate GAP-CLIP color + hierarchy performance on any dataset.""" + print(f"\n{'=' * 60}") + print(f"Evaluating GAP-CLIP on {dataset_name} (Color + Hierarchy)") + print(f" Color (dims 0-{self.color_emb_dim - 1}) | " + f"Hierarchy (dims {self.color_emb_dim}-{self.hierarchy_end_dim - 1})") + print(f"{'=' * 60}") results = {} - # ========== COLOR EVALUATION (DIMS 0-15) - SPECIALIZED ONLY ========== - print("\n🎨 COLOR EVALUATION (dims 0-15) - Specialized embeddings only") - print("=" * 50) - - # Text color embeddings - print("\n📝 Extracting text color embeddings...") - text_color_embeddings, text_colors, _ = self.extract_color_embeddings(dataloader, 'text', max_samples) - print(f" Text color embeddings shape: {text_color_embeddings.shape}") - text_color_metrics = self.compute_similarity_metrics(text_color_embeddings, text_colors) + # --- text --- + print("\nExtracting GAP-CLIP text embeddings...") + text_full, text_colors, text_hier = self.extract_full_embeddings(dataloader, 'text', max_samples) + print(f" Text shape: {text_full.shape}") + + # text color + text_color_spec = text_full[:, :self.color_emb_dim] + text_color_metrics = compute_similarity_metrics(text_color_spec, text_colors) text_color_class = self.evaluate_classification_performance( - text_color_embeddings, text_colors, "Text Color Embeddings (16D)", "Color" + text_color_spec, text_colors, + f"GAP-CLIP Text Color – {dataset_name}", "Color", method="nn", ) text_color_metrics.update(text_color_class) results['text_color'] = text_color_metrics - del text_color_embeddings - torch.cuda.empty_cache() if torch.cuda.is_available() else None - - # Image color embeddings - print("\n🖼️ Extracting image color embeddings...") - image_color_embeddings, image_colors, _ = self.extract_color_embeddings(dataloader, 'image', max_samples) - print(f" Image color embeddings shape: {image_color_embeddings.shape}") - image_color_metrics = self.compute_similarity_metrics(image_color_embeddings, image_colors) - image_color_class = self.evaluate_classification_performance( - image_color_embeddings, image_colors, "Image Color Embeddings (16D)", "Color" + # text hierarchy + text_hier_spec = text_full[:, self.color_emb_dim:self.hierarchy_end_dim] + text_hier_metrics = compute_similarity_metrics(text_hier_spec, text_hier) + text_hier_class = self.evaluate_classification_performance( + text_hier_spec, text_hier, + f"GAP-CLIP Text Hierarchy – {dataset_name}", "Hierarchy", method="nn", ) - image_color_metrics.update(image_color_class) - results['image_color'] = image_color_metrics - - del image_color_embeddings - torch.cuda.empty_cache() if torch.cuda.is_available() else None - - # ========== HIERARCHY EVALUATION (DIMS 16-79) - SPECIALIZED ONLY ========== - print("\n📋 HIERARCHY EVALUATION (dims 16-79) - Specialized embeddings only") - print("=" * 50) - - # Text hierarchy embeddings - print("\n📝 Extracting text hierarchy embeddings...") - text_hierarchy_embeddings, _, text_hierarchies = self.extract_hierarchy_embeddings(dataloader, 'text', max_samples) - print(f" Text hierarchy embeddings shape: {text_hierarchy_embeddings.shape}") - text_hierarchy_metrics = self.compute_similarity_metrics(text_hierarchy_embeddings, text_hierarchies) - text_hierarchy_class = self.evaluate_classification_performance( - text_hierarchy_embeddings, text_hierarchies, "Text 
Hierarchy Embeddings (64D)", "Hierarchy"
+        text_hier_metrics.update(text_hier_class)
+        results['text_hierarchy'] = text_hier_metrics
+
+        # --- image ---
+        print("\nExtracting GAP-CLIP image embeddings...")
+        img_full, img_colors, img_hier = self.extract_full_embeddings(dataloader, 'image', max_samples)
+
+        # image color
+        img_color_spec = img_full[:, :self.color_emb_dim]
+        img_color_metrics = compute_similarity_metrics(img_color_spec, img_colors)
+        img_color_class = self.evaluate_classification_performance(
+            img_color_spec, img_colors,
+            f"GAP-CLIP Image Color – {dataset_name}", "Color", method="nn",
         )
-        text_hierarchy_metrics.update(text_hierarchy_class)
-        results['text_hierarchy'] = text_hierarchy_metrics
-
-        del text_hierarchy_embeddings
-        torch.cuda.empty_cache() if torch.cuda.is_available() else None
-
-        # Image hierarchy embeddings
-        print("\n🖼️ Extracting image hierarchy embeddings...")
-        image_hierarchy_embeddings, _, image_hierarchies = self.extract_hierarchy_embeddings(dataloader, 'image', max_samples)
-        print(f"   Image hierarchy embeddings shape: {image_hierarchy_embeddings.shape}")
-        image_hierarchy_metrics = self.compute_similarity_metrics(image_hierarchy_embeddings, image_hierarchies)
-        image_hierarchy_class = self.evaluate_classification_performance(
-            image_hierarchy_embeddings, image_hierarchies, "Image Hierarchy Embeddings (64D)", "Hierarchy"
+        img_color_metrics.update(img_color_class)
+        results['image_color'] = img_color_metrics
+
+        # image hierarchy (best of 64D vs 512D)
+        img_hier_spec = img_full[:, self.color_emb_dim:self.hierarchy_end_dim]
+
+        spec_metrics = compute_similarity_metrics(img_hier_spec, img_hier)
+        spec_class = self.evaluate_classification_performance(
+            img_hier_spec, img_hier,
+            f"GAP-CLIP Image Hierarchy (64D) – {dataset_name}", "Hierarchy", method="nn",
         )
-        image_hierarchy_metrics.update(image_hierarchy_class)
-        results['image_hierarchy'] = image_hierarchy_metrics

-        del image_hierarchy_embeddings
-        torch.cuda.empty_cache() if torch.cuda.is_available() else None
+        full_metrics = compute_similarity_metrics(img_full, img_hier)
+        full_class = self.evaluate_classification_performance(
+            img_full, img_hier,
+            f"GAP-CLIP Image Hierarchy (512D) – {dataset_name}", "Hierarchy", method="nn",
+        )

-        # ========== SAVE VISUALIZATIONS ==========
-        os.makedirs(self.directory, exist_ok=True)
+        if full_class['accuracy'] >= spec_class['accuracy']:
+            print(f"  512D wins: {full_class['accuracy']*100:.1f}% vs {spec_class['accuracy']*100:.1f}%")
+            img_hier_metrics, img_hier_class = full_metrics, full_class
+        else:
+            print(f"  64D wins: {spec_class['accuracy']*100:.1f}% vs {full_class['accuracy']*100:.1f}%")
+            img_hier_metrics, img_hier_class = spec_metrics, spec_class
+
+        img_hier_metrics.update(img_hier_class)
+        results['image_hierarchy'] = img_hier_metrics
+
+        # --- save confusion matrices ---
+        prefix = dataset_name.lower().replace(" ", "_")
         for key in ['text_color', 'image_color', 'text_hierarchy', 'image_hierarchy']:
-            results[key]['figure'].savefig(
-                f"{self.directory}/local_{key.replace('_', '_')}_confusion_matrix.png",
-                dpi=300,
-                bbox_inches='tight',
+            fig = results[key]['figure']
+            fig.savefig(
+                os.path.join(self.directory, f"gap_clip_{prefix}_{key}_confusion_matrix.png"),
+                dpi=300, bbox_inches='tight',
             )
-            plt.close(results[key]['figure'])
+            self.save_confusion_matrix_table(
+                results[key]['confusion_matrix'], results[key]['labels'],
+                os.path.join(self.directory, f"gap_clip_{prefix}_{key}_confusion_matrix.csv"),
+            )
+            plt.close(fig)

-        return results
+        del 
text_full, img_full + if torch.cuda.is_available(): + torch.cuda.empty_cache() - def evaluate_baseline_fashion_mnist(self, max_samples=10000): - """Evaluate baseline Fashion CLIP model on Fashion-MNIST""" - print(f"\n{'='*60}") - print("Evaluating Baseline Fashion CLIP on Fashion-MNIST") - print(f"Max samples: {max_samples}") - print(f"{'='*60}") - - # Load Fashion-MNIST dataset - target_hierarchy_classes = self.validation_hierarchy_classes or self.hierarchy_classes - fashion_dataset = load_fashion_mnist_dataset(max_samples, hierarchy_classes=target_hierarchy_classes) - - # Create dataloader for Fashion-MNIST - dataloader = DataLoader( - fashion_dataset, - batch_size=8, - shuffle=False, - num_workers=0 - ) - - results = {} - - # Evaluate text embeddings - print("\n📝 Extracting baseline text embeddings from Fashion-MNIST...") - text_embeddings, _, text_hierarchies = self.extract_baseline_embeddings_batch(dataloader, 'text', max_samples) - print(f" Baseline text embeddings shape: {text_embeddings.shape} (using all {text_embeddings.shape[1]} dimensions)") - text_hierarchy_metrics = self.compute_similarity_metrics(text_embeddings, text_hierarchies) - text_hierarchy_classification = self.evaluate_classification_performance( - text_embeddings, text_hierarchies, "Baseline Fashion-MNIST Text Embeddings - Hierarchy", "Hierarchy" - ) - - text_hierarchy_metrics.update(text_hierarchy_classification) - results['text'] = { - 'hierarchy': text_hierarchy_metrics - } - - # Clear memory - del text_embeddings - torch.cuda.empty_cache() if torch.cuda.is_available() else None - - # Evaluate image embeddings - print("\n🖼️ Extracting baseline image embeddings from Fashion-MNIST...") - image_embeddings, image_colors, image_hierarchies = self.extract_baseline_embeddings_batch(dataloader, 'image', max_samples) - print(f" Baseline image embeddings shape: {image_embeddings.shape} (using all {image_embeddings.shape[1]} dimensions)") - image_hierarchy_metrics = self.compute_similarity_metrics(image_embeddings, image_hierarchies) - - image_hierarchy_classification = self.evaluate_classification_performance( - image_embeddings, image_hierarchies, "Baseline Fashion-MNIST Image Embeddings - Hierarchy", "Hierarchy" - ) - - image_hierarchy_metrics.update(image_hierarchy_classification) - results['image'] = { - 'hierarchy': image_hierarchy_metrics - } - - # Clear memory - del image_embeddings - torch.cuda.empty_cache() if torch.cuda.is_available() else None - - # ========== SAVE VISUALIZATIONS ========== - os.makedirs(self.directory, exist_ok=True) - for key in ['text', 'image']: - for subkey in ['hierarchy']: - figure = results[key][subkey]['figure'] - figure.savefig( - f"{self.directory}/fashion_baseline_{key}_{subkey}_confusion_matrix.png", - dpi=300, - bbox_inches='tight', - ) - plt.close(figure) - return results - def evaluate_baseline_kaggle_marqo(self, max_samples=10000): - """Evaluate baseline Fashion CLIP model on KAGL Marqo dataset""" - print(f"\n{'='*60}") - print("Evaluating Baseline Fashion CLIP on KAGL Marqo Dataset") - print(f"Max samples: {max_samples}") - print(f"{'='*60}") - - # Load KAGL Marqo dataset - kaggle_dataset = load_kaggle_marqo_dataset(self, max_samples) - if kaggle_dataset is None: - print("❌ Failed to load KAGL dataset") - return None - - # Create dataloader - dataloader = DataLoader(kaggle_dataset, batch_size=8, shuffle=False, num_workers=0) - + def evaluate_baseline_generic(self, dataloader, dataset_name, max_samples=10000): + """Evaluate baseline Fashion-CLIP color + hierarchy on any 
dataset.""" + print(f"\n{'=' * 60}") + print(f"Evaluating Baseline Fashion-CLIP on {dataset_name} (Color + Hierarchy)") + print(f"{'=' * 60}") + results = {} - - # Evaluate text embeddings - print("\n📝 Extracting baseline text embeddings from KAGL Marqo...") - text_embeddings, text_colors, text_hierarchies = self.extract_baseline_embeddings_batch(dataloader, 'text', max_samples) - print(f" Baseline text embeddings shape: {text_embeddings.shape} (using all {text_embeddings.shape[1]} dimensions)") - text_color_metrics = self.compute_similarity_metrics(text_embeddings, text_colors) - text_hierarchy_metrics = self.compute_similarity_metrics(text_embeddings, text_hierarchies) - - text_color_classification = self.evaluate_classification_performance( - text_embeddings, text_colors, "Baseline KAGL Marqo Text Embeddings - Color", "Color" + + # --- text --- + print("\nExtracting baseline text embeddings...") + text_emb, text_colors, text_hier = self.extract_baseline_embeddings_batch(dataloader, 'text', max_samples) + print(f" Baseline text shape: {text_emb.shape}") + + text_color_metrics = compute_similarity_metrics(text_emb, text_colors) + text_color_class = self.evaluate_classification_performance( + text_emb, text_colors, + f"Baseline Text Color – {dataset_name}", "Color", method="nn", ) - text_hierarchy_classification = self.evaluate_classification_performance( - text_embeddings, text_hierarchies, "Baseline KAGL Marqo Text Embeddings - Hierarchy", "Hierarchy" + text_color_metrics.update(text_color_class) + + text_hier_metrics = compute_similarity_metrics(text_emb, text_hier) + text_hier_class = self.evaluate_classification_performance( + text_emb, text_hier, + f"Baseline Text Hierarchy – {dataset_name}", "Hierarchy", method="nn", ) - - text_color_metrics.update(text_color_classification) - text_hierarchy_metrics.update(text_hierarchy_classification) - results['text'] = { - 'color': text_color_metrics, - 'hierarchy': text_hierarchy_metrics - } - - # Clear memory - del text_embeddings - torch.cuda.empty_cache() if torch.cuda.is_available() else None - - # Evaluate image embeddings - print("\n🖼️ Extracting baseline image embeddings from KAGL Marqo...") - image_embeddings, image_colors, image_hierarchies = self.extract_baseline_embeddings_batch(dataloader, 'image', max_samples) - print(f" Baseline image embeddings shape: {image_embeddings.shape} (using all {image_embeddings.shape[1]} dimensions)") - image_color_metrics = self.compute_similarity_metrics(image_embeddings, image_colors) - image_hierarchy_metrics = self.compute_similarity_metrics(image_embeddings, image_hierarchies) - - image_color_classification = self.evaluate_classification_performance( - image_embeddings, image_colors, "Baseline KAGL Marqo Image Embeddings - Color", "Color" + text_hier_metrics.update(text_hier_class) + + results['text'] = {'color': text_color_metrics, 'hierarchy': text_hier_metrics} + + del text_emb + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # --- image --- + print("\nExtracting baseline image embeddings...") + img_emb, img_colors, img_hier = self.extract_baseline_embeddings_batch(dataloader, 'image', max_samples) + print(f" Baseline image shape: {img_emb.shape}") + + img_color_metrics = compute_similarity_metrics(img_emb, img_colors) + img_color_class = self.evaluate_classification_performance( + img_emb, img_colors, + f"Baseline Image Color – {dataset_name}", "Color", method="nn", ) - image_hierarchy_classification = self.evaluate_classification_performance( - image_embeddings, 
image_hierarchies, "Baseline KAGL Marqo Image Embeddings - Hierarchy", "Hierarchy" + img_color_metrics.update(img_color_class) + + img_hier_metrics = compute_similarity_metrics(img_emb, img_hier) + img_hier_class = self.evaluate_classification_performance( + img_emb, img_hier, + f"Baseline Image Hierarchy – {dataset_name}", "Hierarchy", method="nn", ) - - image_color_metrics.update(image_color_classification) - image_hierarchy_metrics.update(image_hierarchy_classification) - results['image'] = { - 'color': image_color_metrics, - 'hierarchy': image_hierarchy_metrics - } - - # Clear memory - del image_embeddings - torch.cuda.empty_cache() if torch.cuda.is_available() else None - - # ========== SAVE VISUALIZATIONS ========== - os.makedirs(self.directory, exist_ok=True) + img_hier_metrics.update(img_hier_class) + + results['image'] = {'color': img_color_metrics, 'hierarchy': img_hier_metrics} + + del img_emb + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + prefix = dataset_name.lower().replace(" ", "_") for key in ['text', 'image']: - for subkey in ['color', 'hierarchy']: - figure = results[key][subkey]['figure'] - figure.savefig( - f"{self.directory}/kaggle_baseline_{key}_{subkey}_confusion_matrix.png", - dpi=300, - bbox_inches='tight', + for attr in ['color', 'hierarchy']: + fig = results[key][attr]['figure'] + fig.savefig( + os.path.join(self.directory, f"baseline_{prefix}_{key}_{attr}_confusion_matrix.png"), + dpi=300, bbox_inches='tight', + ) + self.save_confusion_matrix_table( + results[key][attr]['confusion_matrix'], + results[key][attr]['labels'], + os.path.join(self.directory, f"baseline_{prefix}_{key}_{attr}_confusion_matrix.csv"), ) - plt.close(figure) - + plt.close(fig) + return results - def evaluate_baseline_local_validation(self, max_samples=10000): - """Evaluate baseline Fashion CLIP model on local validation dataset""" - print(f"\n{'='*60}") - print("Evaluating Baseline Fashion CLIP on Local Validation Dataset") - print(f"Max samples: {max_samples}") - print(f"{'='*60}") - - # Load local validation dataset - local_dataset = load_local_validation_dataset(max_samples) - if local_dataset is None: - print("❌ Failed to load local validation dataset") - return None - - # Filter to only include hierarchies that exist in our model - if len(local_dataset.dataframe) > 0: - valid_df = local_dataset.dataframe[local_dataset.dataframe['hierarchy'].isin(self.hierarchy_classes)] - if len(valid_df) == 0: - print("❌ No samples left after hierarchy filtering.") - return None - if len(valid_df) < len(local_dataset.dataframe): - print(f"📊 Filtered to model hierarchies: {len(valid_df)} samples (from {len(local_dataset.dataframe)})") - local_dataset = LocalDataset(valid_df) - - # Create dataloader - dataloader = DataLoader(local_dataset, batch_size=8, shuffle=False, num_workers=0) - - results = {} - - # Evaluate text embeddings - print("\n📝 Extracting baseline text embeddings from Local Validation...") - text_embeddings, text_colors, text_hierarchies = self.extract_baseline_embeddings_batch(dataloader, 'text', max_samples) - print(f" Baseline text embeddings shape: {text_embeddings.shape} (using all {text_embeddings.shape[1]} dimensions)") - text_color_metrics = self.compute_similarity_metrics(text_embeddings, text_colors) - text_hierarchy_metrics = self.compute_similarity_metrics(text_embeddings, text_hierarchies) - - text_color_classification = self.evaluate_classification_performance( - text_embeddings, text_colors, "Baseline Local Validation Text Embeddings - Color", "Color" - ) - 
text_hierarchy_classification = self.evaluate_classification_performance( - text_embeddings, text_hierarchies, "Baseline Local Validation Text Embeddings - Hierarchy", "Hierarchy" + # ================================================================== + # 6. Full evaluation across all datasets + # ================================================================== + def run_full_evaluation(self, max_samples=10000, batch_size=8): + """Run color + hierarchy evaluation on all 3 datasets for both models.""" + all_results = {} + + # --- Fashion-MNIST --- + shared_dataset, shared_dataloader, shared_counts = self.prepare_shared_fashion_mnist( + max_samples=max_samples, batch_size=batch_size, ) - - text_color_metrics.update(text_color_classification) - text_hierarchy_metrics.update(text_hierarchy_classification) - results['text'] = { - 'color': text_color_metrics, - 'hierarchy': text_hierarchy_metrics - } - - # Clear memory - del text_embeddings - torch.cuda.empty_cache() if torch.cuda.is_available() else None - - # Evaluate image embeddings - print("\n🖼️ Extracting baseline image embeddings from Local Validation...") - image_embeddings, image_colors, image_hierarchies = self.extract_baseline_embeddings_batch(dataloader, 'image', max_samples) - print(f" Baseline image embeddings shape: {image_embeddings.shape} (using all {image_embeddings.shape[1]} dimensions)") - image_color_metrics = self.compute_similarity_metrics(image_embeddings, image_colors) - image_hierarchy_metrics = self.compute_similarity_metrics(image_embeddings, image_hierarchies) - - image_color_classification = self.evaluate_classification_performance( - image_embeddings, image_colors, "Baseline Local Validation Image Embeddings - Color", "Color" + all_results['fashion_mnist_gap'] = self.evaluate_gap_clip_fashion_mnist( + max_samples=max_samples, dataloader=shared_dataloader, expected_counts=shared_counts, ) - image_hierarchy_classification = self.evaluate_classification_performance( - image_embeddings, image_hierarchies, "Baseline Local Validation Image Embeddings - Hierarchy", "Hierarchy" + all_results['fashion_mnist_baseline'] = self.evaluate_baseline_fashion_mnist( + max_samples=max_samples, dataloader=shared_dataloader, expected_counts=shared_counts, ) - - image_color_metrics.update(image_color_classification) - image_hierarchy_metrics.update(image_hierarchy_classification) - results['image'] = { - 'color': image_color_metrics, - 'hierarchy': image_hierarchy_metrics - } - - # Clear memory - del image_embeddings - torch.cuda.empty_cache() if torch.cuda.is_available() else None - - # ========== SAVE VISUALIZATIONS ========== - os.makedirs(self.directory, exist_ok=True) - for key in ['text', 'image']: - for subkey in ['color', 'hierarchy']: - figure = results[key][subkey]['figure'] - figure.savefig( - f"{self.directory}/local_baseline_{key}_{subkey}_confusion_matrix.png", - dpi=300, - bbox_inches='tight', - ) - plt.close(figure) - - return results + # --- KAGL Marqo --- + try: + kaggle_dataset = load_kaggle_marqo_with_hierarchy( + max_samples=max_samples, + hierarchy_classes=self.validation_hierarchy_classes or self.hierarchy_classes, + raw_df=self.kaggle_raw_df, + ) + if kaggle_dataset is not None and len(kaggle_dataset) > 0: + kaggle_dataloader = DataLoader(kaggle_dataset, batch_size=batch_size, shuffle=False, num_workers=0) + all_results['kaggle_gap'] = self.evaluate_gap_clip_generic( + kaggle_dataloader, "KAGL Marqo", max_samples, + ) + all_results['kaggle_baseline'] = self.evaluate_baseline_generic( + kaggle_dataloader, "KAGL 
Marqo", max_samples, + ) + else: + print("WARNING: KAGL Marqo dataset empty after hierarchy mapping, skipping.") + except Exception as e: + print(f"WARNING: Could not evaluate on KAGL Marqo: {e}") + # --- Internal (local validation) --- + try: + local_dataset = load_local_validation_with_hierarchy( + max_samples=max_samples, + hierarchy_classes=self.validation_hierarchy_classes or self.hierarchy_classes, + raw_df=self.local_raw_df, + ) + if local_dataset is not None and len(local_dataset) > 0: + local_dataloader = DataLoader(local_dataset, batch_size=batch_size, shuffle=False, num_workers=0) + all_results['local_gap'] = self.evaluate_gap_clip_generic( + local_dataloader, "Internal", max_samples, + ) + all_results['local_baseline'] = self.evaluate_baseline_generic( + local_dataloader, "Internal", max_samples, + ) + else: + print("WARNING: Local validation dataset empty after hierarchy filtering, skipping.") + except Exception as e: + print(f"WARNING: Could not evaluate on internal dataset: {e}") + + # --- Print summary --- + print(f"\n{'=' * 70}") + print("COLOR + HIERARCHY NN ACCURACY EVALUATION SUMMARY (Table 3)") + print(f"{'=' * 70}") + for dataset_key, label in [ + ('fashion_mnist_gap', 'Fashion-MNIST (GAP-CLIP)'), + ('fashion_mnist_baseline', 'Fashion-MNIST (Baseline)'), + ('kaggle_gap', 'KAGL Marqo (GAP-CLIP)'), + ('kaggle_baseline', 'KAGL Marqo (Baseline)'), + ('local_gap', 'Internal (GAP-CLIP)'), + ('local_baseline', 'Internal (Baseline)'), + ]: + if dataset_key not in all_results: + continue + res = all_results[dataset_key] + print(f"\n{label}:") + # GAP-CLIP format + if 'text_color' in res: + tc = res['text_color'] + ic = res['image_color'] + print(f" Color – Text NN: {tc['accuracy']*100:.1f}% | Image NN: {ic['accuracy']*100:.1f}%") + if 'text_hierarchy' in res: + th = res['text_hierarchy'] + ih = res['image_hierarchy'] + print(f" Hierarchy – Text NN: {th['accuracy']*100:.1f}% | Image NN: {ih['accuracy']*100:.1f}%") + if 'ensemble_accuracy' in ih: + print(f" Hierarchy – Image Ensemble: {ih['ensemble_accuracy']*100:.1f}%") + # Baseline format + if 'text' in res and isinstance(res['text'], dict): + t = res['text'] + i = res['image'] + if 'color' in t: + print(f" Color – Text NN: {t['color']['accuracy']*100:.1f}% | Image NN: {i['color']['accuracy']*100:.1f}%") + if 'hierarchy' in t: + print(f" Hierarchy – Text NN: {t['hierarchy']['accuracy']*100:.1f}% | Image NN: {i['hierarchy']['accuracy']*100:.1f}%") + + return all_results + + +# ============================================================================ +# 7. 
Main +# ============================================================================ if __name__ == "__main__": device = torch.device("mps" if torch.backends.mps.is_available() else "cpu") print(f"Using device: {device}") - directory = 'figures/confusion_matrices' + directory = 'main_model_analysis' max_samples = 10000 evaluator = ColorHierarchyEvaluator(device=device, directory=directory) - - # Evaluate Fashion-MNIST - print("\n" + "="*60) - print("🚀 Starting evaluation of Fashion-MNIST Hierarchy embeddings") - print("="*60) - results_fashion = evaluator.evaluate_fashion_mnist(max_samples=max_samples) - - print(f"\n{'='*60}") - print("FASHION-MNIST EVALUATION SUMMARY") - print(f"{'='*60}") - - print("\n📋 HIERARCHY CLASSIFICATION RESULTS (dims 16-79):") - print(f" Text - NN Acc: {results_fashion['text_hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {results_fashion['text_hierarchy']['centroid_accuracy']*100:.1f}% | Separation: {results_fashion['text_hierarchy']['separation_score']:.4f}") - print(f" Image - NN Acc: {results_fashion['image_hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {results_fashion['image_hierarchy']['centroid_accuracy']*100:.1f}% | Separation: {results_fashion['image_hierarchy']['separation_score']:.4f}") - - # Evaluate Baseline Fashion CLIP on Fashion-MNIST - print("\n" + "="*60) - print("🚀 Starting evaluation of Baseline Fashion CLIP on Fashion-MNIST") - print("="*60) - results_baseline = evaluator.evaluate_baseline_fashion_mnist(max_samples=max_samples) - - print(f"\n{'='*60}") - print("BASELINE FASHION-MNIST EVALUATION SUMMARY") - print(f"{'='*60}") - - print("\n📋 HIERARCHY CLASSIFICATION RESULTS (Baseline):") - print(f" Text - NN Acc: {results_baseline['text']['hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline['text']['hierarchy']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline['text']['hierarchy']['separation_score']:.4f}") - print(f" Image - NN Acc: {results_baseline['image']['hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline['image']['hierarchy']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline['image']['hierarchy']['separation_score']:.4f}") - - - # Evaluate KAGL Marqo - print("\n" + "="*60) - print("🚀 Starting evaluation of KAGL Marqo with Color & Hierarchy embeddings") - print("="*60) - results_kaggle = evaluator.evaluate_kaggle_marqo(max_samples=max_samples) - - if results_kaggle is not None: - print(f"\n{'='*60}") - print("KAGL MARQO EVALUATION SUMMARY") - print(f"{'='*60}") - - print("\n🎨 COLOR CLASSIFICATION RESULTS (dims 0-15):") - print(f" Text - NN Acc: {results_kaggle['text_color']['accuracy']*100:.1f}% | Centroid Acc: {results_kaggle['text_color']['centroid_accuracy']*100:.1f}% | Separation: {results_kaggle['text_color']['separation_score']:.4f}") - print(f" Image - NN Acc: {results_kaggle['image_color']['accuracy']*100:.1f}% | Centroid Acc: {results_kaggle['image_color']['centroid_accuracy']*100:.1f}% | Separation: {results_kaggle['image_color']['separation_score']:.4f}") - - print("\n📋 HIERARCHY CLASSIFICATION RESULTS (dims 16-79):") - print(f" Text - NN Acc: {results_kaggle['text_hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {results_kaggle['text_hierarchy']['centroid_accuracy']*100:.1f}% | Separation: {results_kaggle['text_hierarchy']['separation_score']:.4f}") - print(f" Image - NN Acc: {results_kaggle['image_hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {results_kaggle['image_hierarchy']['centroid_accuracy']*100:.1f}% | Separation: 
{results_kaggle['image_hierarchy']['separation_score']:.4f}") - - # Evaluate Baseline Fashion CLIP on KAGL Marqo - print("\n" + "="*60) - print("🚀 Starting evaluation of Baseline Fashion CLIP on KAGL Marqo") - print("="*60) - results_baseline_kaggle = evaluator.evaluate_baseline_kaggle_marqo(max_samples=max_samples) - - if results_baseline_kaggle is not None: - print(f"\n{'='*60}") - print("BASELINE KAGL MARQO EVALUATION SUMMARY") - print(f"{'='*60}") - - print("\n🎨 COLOR CLASSIFICATION RESULTS (Baseline):") - print(f" Text - NN Acc: {results_baseline_kaggle['text']['color']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline_kaggle['text']['color']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline_kaggle['text']['color']['separation_score']:.4f}") - print(f" Image - NN Acc: {results_baseline_kaggle['image']['color']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline_kaggle['image']['color']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline_kaggle['image']['color']['separation_score']:.4f}") - - print("\n📋 HIERARCHY CLASSIFICATION RESULTS (Baseline):") - print(f" Text - NN Acc: {results_baseline_kaggle['text']['hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline_kaggle['text']['hierarchy']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline_kaggle['text']['hierarchy']['separation_score']:.4f}") - print(f" Image - NN Acc: {results_baseline_kaggle['image']['hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline_kaggle['image']['hierarchy']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline_kaggle['image']['hierarchy']['separation_score']:.4f}") - - # Evaluate Local Validation Dataset - print("\n" + "="*60) - print("🚀 Starting evaluation of Local Validation Dataset with Color & Hierarchy embeddings") - print("="*60) - results_local = evaluator.evaluate_local_validation(max_samples=max_samples) - - if results_local is not None: - print(f"\n{'='*60}") - print("LOCAL VALIDATION DATASET EVALUATION SUMMARY") - print(f"{'='*60}") - - print("\n🎨 COLOR CLASSIFICATION RESULTS (dims 0-15):") - print(f" Text - NN Acc: {results_local['text_color']['accuracy']*100:.1f}% | Centroid Acc: {results_local['text_color']['centroid_accuracy']*100:.1f}% | Separation: {results_local['text_color']['separation_score']:.4f}") - print(f" Image - NN Acc: {results_local['image_color']['accuracy']*100:.1f}% | Centroid Acc: {results_local['image_color']['centroid_accuracy']*100:.1f}% | Separation: {results_local['image_color']['separation_score']:.4f}") - - print("\n📋 HIERARCHY CLASSIFICATION RESULTS (dims 16-79):") - print(f" Text - NN Acc: {results_local['text_hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {results_local['text_hierarchy']['centroid_accuracy']*100:.1f}% | Separation: {results_local['text_hierarchy']['separation_score']:.4f}") - print(f" Image - NN Acc: {results_local['image_hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {results_local['image_hierarchy']['centroid_accuracy']*100:.1f}% | Separation: {results_local['image_hierarchy']['separation_score']:.4f}") - - # Evaluate Baseline Fashion CLIP on Local Validation - print("\n" + "="*60) - print("🚀 Starting evaluation of Baseline Fashion CLIP on Local Validation") - print("="*60) - results_baseline_local = evaluator.evaluate_baseline_local_validation(max_samples=max_samples) - - if results_baseline_local is not None: - print(f"\n{'='*60}") - print("BASELINE LOCAL VALIDATION EVALUATION SUMMARY") - print(f"{'='*60}") - - print("\n🎨 COLOR CLASSIFICATION RESULTS 
(Baseline):") - print(f" Text - NN Acc: {results_baseline_local['text']['color']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline_local['text']['color']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline_local['text']['color']['separation_score']:.4f}") - print(f" Image - NN Acc: {results_baseline_local['image']['color']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline_local['image']['color']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline_local['image']['color']['separation_score']:.4f}") - - print("\n📋 HIERARCHY CLASSIFICATION RESULTS (Baseline):") - print(f" Text - NN Acc: {results_baseline_local['text']['hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline_local['text']['hierarchy']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline_local['text']['hierarchy']['separation_score']:.4f}") - print(f" Image - NN Acc: {results_baseline_local['image']['hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline_local['image']['hierarchy']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline_local['image']['hierarchy']['separation_score']:.4f}") + evaluator.run_full_evaluation(max_samples=max_samples, batch_size=8)
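+
+
+# ----------------------------------------------------------------------------
+# Reference sketch: the leave-one-out nearest-neighbour accuracy behind the
+# "NN Acc" numbers in Table 3. The pipeline itself relies on
+# utils.metrics.predict_labels_from_embeddings, which is assumed to be
+# equivalent; the helper below is illustrative only and is not called above.
+# ----------------------------------------------------------------------------
+def _loo_nn_accuracy_sketch(embeddings, labels):
+    """Label each sample with the label of its most similar *other* sample."""
+    sims = cosine_similarity(embeddings)  # (N, N) pairwise cosine similarities
+    np.fill_diagonal(sims, -np.inf)       # a sample must not match itself
+    preds = np.asarray(labels)[sims.argmax(axis=1)]
+    return accuracy_score(labels, preds)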
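+
+
+# ----------------------------------------------------------------------------
+# Reference sketch: the image + text-prototype ensemble reported for
+# Fashion-MNIST above (alpha=0.7). The ColorHierarchyEvaluator method
+# predict_labels_image_ensemble is assumed to blend the two score matrices in
+# this spirit; the exact weighting inside the real method may differ, and this
+# helper is illustrative only. Assumes every class in class_names occurs at
+# least twice in labels.
+# ----------------------------------------------------------------------------
+def _image_text_ensemble_sketch(img_emb, labels, text_protos, class_names, alpha=0.7):
+    """alpha * leave-one-out image-image score + (1 - alpha) * text-prototype score."""
+    labels = np.asarray(labels)
+    # Zero-shot scores against L2-normalised class text prototypes, shape (N, C).
+    text_scores = normalize(img_emb) @ normalize(text_protos).T
+    # Best similarity to another image of each class (leave-one-out per sample).
+    sims = cosine_similarity(img_emb)
+    np.fill_diagonal(sims, -np.inf)  # single-sample classes score -inf for themselves
+    img_scores = np.stack(
+        [sims[:, labels == c].max(axis=1) for c in class_names], axis=1
+    )
+    blended = alpha * img_scores + (1.0 - alpha) * text_scores
+    return np.asarray(class_names)[blended.argmax(axis=1)]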