| """ |
| Shared dataset classes and loading utilities for GAP-CLIP evaluation scripts. |
| |
| Provides: |
| - FashionMNISTDataset (Fashion-MNIST grayscale images) |
| - KaggleDataset (KAGL Marqo HuggingFace dataset) |
| - LocalDataset (internal local validation dataset) |
| - Matching load_* convenience functions |
| - collate_fn_filter_none (for DataLoader) |
| - normalize_hierarchy_label (text normalisation helper) |
| """ |
|
|
| from __future__ import annotations |
|
|
| import difflib |
| import hashlib |
| import os |
| import sys |
| from pathlib import Path |
| from io import BytesIO |
| from typing import List, Optional |
|
|
| import numpy as np |
| import pandas as pd |
| import torch |
| from PIL import Image |
| import requests |
| from torch.utils.data import Dataset |
| from torchvision import transforms |
|
|
| |
# Put the project root (two directory levels above this file's directory) at
# the front of sys.path, unless it is already there, so that the top-level
# ``config`` module imported below resolves no matter which working directory
# the evaluation script is launched from. Must run before ``from config ...``.
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
if str(_PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(_PROJECT_ROOT))
|
|
| from config import ( |
| ROOT_DIR, |
| column_local_image_path, |
| fashion_mnist_csv, |
| local_dataset_path, |
| images_dir, |
| ) |
|
|
| |
| |
| |
|
|
def get_fashion_mnist_labels() -> dict:
    """Build the Fashion-MNIST class-id -> class-name lookup.

    Returns:
        A freshly created dict mapping each integer label 0-9 to its
        human-readable Fashion-MNIST class name (e.g. ``0 -> "T-shirt/top"``).
    """
    class_names = (
        "T-shirt/top",
        "Trouser",
        "Pullover",
        "Dress",
        "Coat",
        "Sandal",
        "Shirt",
        "Sneaker",
        "Bag",
        "Ankle boot",
    )
    # Label ids are simply the positions in the canonical class order.
    return dict(enumerate(class_names))
|
|
|
|
# Hand-written fallbacks, consulted only when a Fashion-MNIST label has neither
# an exact nor a substring match among the hierarchy classes. Keys are the
# lower-cased Fashion-MNIST names; values are lower-cased hierarchy candidates
# tried in order of preference. "shirt" has no fallback entry.
_FASHION_MNIST_FALLBACKS = {
    "t-shirt/top": ("top",),
    "trouser": ("bottom", "pants", "trousers", "trouser", "pant"),
    "pullover": ("sweater", "pullover"),
    "dress": ("dress",),
    "coat": ("jacket", "outerwear", "coat"),
    "sandal": ("shoes", "shoe", "sandal", "sneaker", "boot"),
    "sneaker": ("shoes", "shoe", "sandal", "sneaker", "boot"),
    "ankle boot": ("shoes", "shoe", "sandal", "sneaker", "boot"),
    "bag": ("bag",),
}


def _match_hierarchy_class(fm_label_lower, hierarchy_classes, lower_to_original):
    """Return the hierarchy class best matching one lower-cased label, or None."""
    # 1. Exact (case-insensitive) match.
    if fm_label_lower in lower_to_original:
        return lower_to_original[fm_label_lower]

    # 2. Substring match in either direction, first class in caller order wins.
    #    Blank names are skipped: "" is contained in every string and would
    #    otherwise capture every label (and "" is then neither None nor a
    #    meaningful class, so it slips past the caller's "no match" filter).
    for h_class in hierarchy_classes:
        h_lower = h_class.lower()
        if h_lower.strip() and (h_lower in fm_label_lower or fm_label_lower in h_lower):
            return h_class

    # 3. Hand-written synonyms, in preference order.
    for candidate in _FASHION_MNIST_FALLBACKS.get(fm_label_lower, ()):
        if candidate in lower_to_original:
            return lower_to_original[candidate]

    # 4. Fuzzy string similarity as a last resort.
    close = difflib.get_close_matches(fm_label_lower, list(lower_to_original), n=1, cutoff=0.6)
    if close:
        return lower_to_original[close[0]]
    return None


def create_fashion_mnist_to_hierarchy_mapping(hierarchy_classes: List[str]) -> dict:
    """Map Fashion-MNIST integer labels to the nearest hierarchy class name.

    For every Fashion-MNIST label the strategies below are tried in order and
    the first hit wins. Comparisons are case-insensitive; the returned name
    keeps the spelling used in ``hierarchy_classes`` (first occurrence wins
    for case-insensitive duplicates):

    1. exact match;
    2. substring match in either direction (blank class names ignored);
    3. the hand-written synonyms in ``_FASHION_MNIST_FALLBACKS``;
    4. :func:`difflib.get_close_matches` with cutoff 0.6.

    Each ``label -> match`` pair is printed as a side effect.

    Args:
        hierarchy_classes: Candidate class names to map onto.

    Returns:
        Dict ``{label_id: matched_class_name or None}``; ``None`` means no
        strategy matched and the label should be filtered out.
    """
    # Lower-case name -> original spelling; setdefault keeps the first
    # occurrence, i.e. the same element list.index() would pick.
    lower_to_original = {}
    for h_class in hierarchy_classes:
        lower_to_original.setdefault(h_class.lower(), h_class)

    mapping = {}
    for fm_label_id, fm_label in get_fashion_mnist_labels().items():
        matched_hierarchy = _match_hierarchy_class(
            fm_label.lower(), hierarchy_classes, lower_to_original
        )
        mapping[fm_label_id] = matched_hierarchy
        status = matched_hierarchy if matched_hierarchy else "NO MATCH (will be filtered out)"
        print(f" {fm_label} ({fm_label_id}) -> {status}")

    return mapping
|
|
|
|
def convert_fashion_mnist_to_image(pixel_values) -> Image.Image:
    """Turn one flattened 28x28 Fashion-MNIST sample into an RGB PIL image.

    Args:
        pixel_values: 784 grey-level values in row-major order (any
            array-like numpy accepts); they are cast to ``uint8``.

    Returns:
        A 28x28 RGB ``PIL.Image`` whose three channels are identical copies
        of the grey-level plane.
    """
    grey = np.asarray(pixel_values).reshape(28, 28).astype(np.uint8)
    # Replicate the single channel so RGB-expecting transforms work unchanged.
    rgb = np.repeat(grey[:, :, np.newaxis], 3, axis=2)
    return Image.fromarray(rgb)
|
|
|
|
class FashionMNISTDataset(Dataset):
    """PyTorch dataset wrapping Fashion-MNIST CSV rows.

    Each item is a 4-tuple ``(image, description, color, hierarchy)``:

    * ``image`` -- tensor produced by ``self.transform`` (resize, to-tensor,
      normalise);
    * ``description`` -- the Fashion-MNIST class name of the row's label;
    * ``color`` -- always ``"unknown"``;
    * ``hierarchy`` -- ``label_mapping[label_id]`` whenever the mapping has
      an entry for the label (the entry may be ``None``), otherwise the
      Fashion-MNIST class name.

    Args:
        dataframe: Rows with a ``label`` column and ``pixel1``..``pixel784``.
        image_size: Side length images are resized to.
        label_mapping: Optional ``{label_id: hierarchy_name}`` override.
    """

    # Names of the 784 flattened pixel columns, built once rather than on
    # every __getitem__ call. Must stay a list: a tuple would be treated as a
    # single key when indexing a pandas Series.
    _PIXEL_COLUMNS = [f"pixel{i}" for i in range(1, 785)]

    def __init__(self, dataframe: pd.DataFrame, image_size: int = 224, label_mapping: Optional[dict] = None):
        self.dataframe = dataframe
        self.image_size = image_size
        self.labels_map = get_fashion_mnist_labels()
        self.label_mapping = label_mapping

        # Resize, convert to a tensor, then normalise with the standard
        # ImageNet channel statistics.
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self) -> int:
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        image = convert_fashion_mnist_to_image(row[self._PIXEL_COLUMNS].values)
        image = self.transform(image)

        label_id = int(row["label"])
        description = self.labels_map[label_id]
        color = "unknown"
        # An explicit mapping entry wins (even a None one); otherwise fall
        # back to the raw Fashion-MNIST class name.
        if self.label_mapping and label_id in self.label_mapping:
            hierarchy = self.label_mapping[label_id]
        else:
            hierarchy = description
        return image, description, color, hierarchy
|
|
|
|
def load_fashion_mnist_dataset(
    max_samples: int = 10000,
    hierarchy_classes: Optional[List[str]] = None,
    csv_path: Optional[str] = None,
) -> FashionMNISTDataset:
    """Read the Fashion-MNIST test CSV and wrap it in a FashionMNISTDataset.

    When ``hierarchy_classes`` is given, the Fashion-MNIST labels are mapped
    onto those classes first and rows whose label has no match are dropped.
    The first ``max_samples`` remaining rows (file order, no shuffling) are
    kept.

    Args:
        max_samples: Upper bound on the number of rows kept.
        hierarchy_classes: Optional target classes for the label mapping.
        csv_path: CSV location; ``config.fashion_mnist_csv`` when omitted.

    Returns:
        A FashionMNISTDataset over the selected rows, carrying the label
        mapping (``None`` when no hierarchy classes were given).
    """
    source = fashion_mnist_csv if csv_path is None else csv_path

    print("Loading Fashion-MNIST test dataset...")
    frame = pd.read_csv(source)
    print(f"Fashion-MNIST dataset loaded: {len(frame)} samples")

    label_mapping = None
    if hierarchy_classes is not None:
        print("\nCreating mapping from Fashion-MNIST labels to hierarchy classes:")
        label_mapping = create_fashion_mnist_to_hierarchy_mapping(hierarchy_classes)
        # Keep only rows whose label found a hierarchy class.
        keep_ids = [lid for lid, target in label_mapping.items() if target is not None]
        frame = frame[frame["label"].isin(keep_ids)]
        print(f"\nAfter filtering to mappable labels: {len(frame)} samples")

    selected = frame.head(max_samples)
    print(f"Using {len(selected)} samples for evaluation")
    return FashionMNISTDataset(selected, label_mapping=label_mapping)
|
|
|
|
| |
| |
| |
|
|
class KaggleDataset(Dataset):
    """Dataset class for the KAGL Marqo HuggingFace dataset.

    Each item is ``(image, description, color)``, or
    ``(image, description, color, hierarchy)`` when ``include_hierarchy`` is
    true. ``description``/``color`` come from the ``text``/``color`` columns
    and ``hierarchy`` from an optional ``hierarchy`` column (default
    ``"unknown"``).

    Args:
        dataframe: Rows with ``image_url`` (encoded image data), ``text`` and
            ``color`` columns.
        image_size: Side length images are resized to (default 224).
        include_hierarchy: Whether items carry the fourth hierarchy element.
    """

    def __init__(self, dataframe: pd.DataFrame, image_size: int = 224, include_hierarchy: bool = False):
        self.dataframe = dataframe
        self.image_size = image_size
        self.include_hierarchy = include_hierarchy

        # Resize to the requested size (honouring ``image_size``), convert to
        # a tensor, then normalise with the standard ImageNet statistics.
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self) -> int:
        return len(self.dataframe)

    @staticmethod
    def _decode_image(image_data) -> Image.Image:
        """Decode one ``image_url`` cell into an RGB PIL image.

        Handles a dict carrying encoded ``"bytes"`` (presumably the HuggingFace
        ``Image`` feature layout -- confirm), any object with a ``convert``
        method (e.g. an already-decoded PIL image), and raw encoded bytes.
        """
        if isinstance(image_data, dict) and "bytes" in image_data:
            return Image.open(BytesIO(image_data["bytes"])).convert("RGB")
        if hasattr(image_data, "convert"):
            return image_data.convert("RGB")
        return Image.open(BytesIO(image_data)).convert("RGB")

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        image = self.transform(self._decode_image(row["image_url"]))
        description = row["text"]
        color = row["color"]

        if self.include_hierarchy:
            hierarchy = row.get("hierarchy", "unknown")
            return image, description, color, hierarchy
        return image, description, color
|
|
|
|
def download_kaggle_raw_df() -> pd.DataFrame:
    """Fetch the raw KAGL Marqo table from the HuggingFace Hub.

    This is the slow, network-bound step. Keep the returned DataFrame and
    pass it to :func:`load_kaggle_marqo_dataset` through ``raw_df`` to avoid
    downloading it again.

    Returns:
        The ``"data"`` split of ``Marqo/KAGL`` converted to a pandas DataFrame.
    """
    # Imported here so the ``datasets`` package is only required by callers
    # that actually download KAGL.
    from datasets import load_dataset

    print("Downloading KAGL Marqo dataset from HuggingFace...")
    raw = load_dataset("Marqo/KAGL")["data"].to_pandas()
    print(f"KAGL dataset downloaded: {len(raw)} samples, columns: {list(raw.columns)}")
    return raw
|
|
|
|
def load_kaggle_marqo_dataset(
    max_samples: int = 5000,
    include_hierarchy: bool = False,
    raw_df: Optional[pd.DataFrame] = None,
) -> KaggleDataset:
    """Prepare the KAGL Marqo HuggingFace dataset for evaluation.

    Rows missing ``text``/``image`` or a usable ``baseColour`` are dropped
    *before* sampling, so exactly ``max_samples`` rows are returned whenever
    that many valid rows exist. Colours are lower-cased and the substring
    "grey" is rewritten to "gray".

    Args:
        max_samples: Maximum number of samples to return (random sample with
            ``random_state=42`` when more are available).
        include_hierarchy: If True, dataset tuples include a hierarchy element.
            NOTE(review): the frame handed to KaggleDataset has no
            ``hierarchy`` column, so that element falls back to
            KaggleDataset's default -- confirm this is intended.
        raw_df: Pre-downloaded DataFrame (from :func:`download_kaggle_raw_df`).
            If *None*, the dataset is downloaded from HuggingFace. It is
            copied, never modified.

    Returns:
        A KaggleDataset over ``image_url``/``text``/``color`` columns.
    """
    if raw_df is not None:
        df = raw_df.copy()
        print(f"Using cached KAGL DataFrame: {len(df)} samples")
    else:
        df = download_kaggle_raw_df()

    df = df.dropna(subset=["text", "image"])

    # Normalise and filter colours up-front. Dropping missing colours only
    # after sampling would silently shrink the result below ``max_samples``
    # and make the "Sampled" log line inaccurate.
    df = df.assign(color=df["baseColour"].str.lower().str.replace("grey", "gray"))
    df = df.dropna(subset=["color"])

    if len(df) > max_samples:
        df = df.sample(n=max_samples, random_state=42)
        print(f"Sampled {max_samples} items")

    kaggle_df = pd.DataFrame({
        "image_url": df["image"],
        "text": df["text"],
        "color": df["color"],
    })

    print(f"Colors: {sorted(kaggle_df['color'].unique())}")

    return KaggleDataset(kaggle_df, include_hierarchy=include_hierarchy)
|
|
|
|
| |
| |
| |
|
|
class LocalDataset(Dataset):
    """Dataset class for the internal local validation dataset.

    Each item is ``(image, description, color)``, or
    ``(image, description, color, hierarchy)`` when ``include_hierarchy`` is
    true. ``description``/``color`` come from the ``text``/``color`` columns
    and ``hierarchy`` from an optional ``hierarchy`` column (default
    ``"unknown"``).

    Images are read from the local-path column named by
    ``config.column_local_image_path`` when it holds a non-empty string;
    otherwise from ``image_url``, which may be a dict with encoded ``"bytes"``
    or a URL string fetched with ``requests`` and cached as JPEG under
    ``config.images_dir``. Any loading error is printed and replaced by a
    uniform gray placeholder so one bad row does not abort a run.

    Args:
        dataframe: Source rows (see above for the columns read).
        image_size: Side length images are resized to (default 224).
        include_hierarchy: Whether items carry the fourth hierarchy element.
    """

    def __init__(self, dataframe: pd.DataFrame, image_size: int = 224, include_hierarchy: bool = False):
        self.dataframe = dataframe
        self.image_size = image_size
        self.include_hierarchy = include_hierarchy

        # Resize to the requested size (honouring ``image_size``), convert to
        # a tensor, then normalise with the standard ImageNet statistics.
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self) -> int:
        return len(self.dataframe)

    @staticmethod
    def _load_remote_image(image_url: str) -> Image.Image:
        """Fetch ``image_url`` as RGB through an on-disk JPEG cache.

        The cache file lives in ``config.images_dir`` and is named after the
        MD5 hex digest of the URL, so later requests for the same URL are
        served from disk.
        """
        cache_dir = Path(images_dir)
        cache_dir.mkdir(parents=True, exist_ok=True)
        # MD5 only derives a stable, filesystem-safe name; not security-relevant.
        url_hash = hashlib.md5(image_url.encode("utf-8")).hexdigest()
        cache_path = cache_dir / f"{url_hash}.jpg"
        if cache_path.exists():
            with Image.open(cache_path) as img:
                return img.convert("RGB")

        resp = requests.get(image_url, timeout=10)
        resp.raise_for_status()
        image = Image.open(BytesIO(resp.content)).convert("RGB")
        # A failed cache write must not throw away the successfully
        # downloaded image (it would otherwise become the gray placeholder).
        try:
            image.save(cache_path, "JPEG", quality=85, optimize=True)
        except OSError as e:
            print(f"Warning: could not cache image at {cache_path}: {e}")
        return image

    def _load_image(self, row) -> Image.Image:
        """Load one row's image as RGB, preferring the local path over ``image_url``.

        Raises:
            ValueError: If the row has neither a usable local path nor an
                ``image_url``. Errors from PIL, ``requests`` or the
                filesystem propagate unchanged.
        """
        image_path = row.get(column_local_image_path) if hasattr(row, "get") else None
        if isinstance(image_path, str) and image_path:
            if not os.path.isabs(image_path):
                # Relative paths are resolved against ROOT_DIR; wrapping it in
                # Path() works whether ROOT_DIR is a str or a Path.
                image_path = Path(ROOT_DIR) / image_path
            # ``with`` guarantees the file handle is released.
            with Image.open(image_path) as img:
                return img.convert("RGB")

        image_url = row.get("image_url") if hasattr(row, "get") else None
        if isinstance(image_url, dict) and "bytes" in image_url:
            return Image.open(BytesIO(image_url["bytes"])).convert("RGB")
        if isinstance(image_url, str) and image_url:
            return self._load_remote_image(image_url)
        raise ValueError("Missing image_path and image_url")

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        try:
            image = self._load_image(row)
        except Exception as e:
            # Deliberate best-effort: report the problem and substitute a
            # placeholder instead of failing the whole DataLoader iteration.
            print(f"Error loading image: {e}")
            image = Image.new("RGB", (224, 224), color="gray")
        image = self.transform(image)

        description = row["text"]
        color = row["color"]

        if self.include_hierarchy:
            hierarchy = row.get("hierarchy", "unknown")
            return image, description, color, hierarchy
        return image, description, color
|
|
|
|
def load_local_validation_dataset(
    max_samples: int = 5000,
    include_hierarchy: bool = False,
    raw_df: Optional[pd.DataFrame] = None,
) -> LocalDataset:
    """Load and prepare the internal local validation dataset.

    Rows with a missing local image path (when that column exists) or a
    missing ``color`` (when that column exists) are dropped before sampling.

    Args:
        max_samples: Maximum number of samples to return (random sample with
            ``random_state=42`` when more are available).
        include_hierarchy: If True, dataset tuples include a hierarchy element.
        raw_df: Pre-loaded DataFrame. If *None*, the CSV at
            ``config.local_dataset_path`` is read. It is copied, never modified.

    Returns:
        A LocalDataset over the selected rows.
    """
    if raw_df is not None:
        df = raw_df.copy()
        print(f"Using cached local DataFrame: {len(df)} samples")
    else:
        print("Loading local validation dataset...")
        df = pd.read_csv(local_dataset_path)
        print(f"Dataset loaded: {len(df)} samples")

    if column_local_image_path in df.columns:
        df = df.dropna(subset=[column_local_image_path])
        print(f"After filtering NaN image paths: {len(df)} samples")
    else:
        print(f"Column '{column_local_image_path}' not found; falling back to 'image_url'.")

    if "color" in df.columns:
        # Actually perform the colour filtering the log line reports. Besides
        # removing unlabeled rows, this keeps ``sorted`` from raising
        # TypeError on a column that mixes strings with NaN.
        df = df.dropna(subset=["color"])
        print(f"After color filtering: {len(df)} samples, colors: {sorted(df['color'].unique())}")

    if len(df) > max_samples:
        df = df.sample(n=max_samples, random_state=42)
        print(f"Sampled {max_samples} items")

    print(f"Using {len(df)} samples for evaluation")
    return LocalDataset(df, include_hierarchy=include_hierarchy)
|
|
|
|
| |
| |
| |
|
|
def collate_fn_filter_none(batch):
    """Collate dataset items into a batch after dropping ``None`` entries.

    Each item is a tuple whose first element is an image tensor. Images are
    stacked with :func:`torch.stack`, and every remaining field is gathered
    into a list. This covers the ``(image, text, color)`` and
    ``(image, text, color, hierarchy)`` layouts, and any other item arity.

    The number of dropped items and empty batches are reported with ``print``.

    Args:
        batch: Sequence of dataset items, some of which may be ``None``.

    Returns:
        ``(images, field_1_list, field_2_list, ...)``. For an empty batch the
        result is ``(torch.tensor([]), [], [])`` whatever the item layout
        would have been, because it cannot be inferred without any items.
    """
    kept = [item for item in batch if item is not None]
    dropped = len(batch) - len(kept)
    if dropped:
        print(f"Filtered out {dropped} None values from batch")
    if not kept:
        print("Empty batch after filtering None values")
        return torch.tensor([]), [], []

    images, *fields = zip(*kept)
    return (torch.stack(images), *(list(field) for field in fields))
|
|
|
|
| |
| |
| |
|
|
def normalize_hierarchy_label(label: str) -> str:
    """Lower-case and strip a hierarchy label for consistent comparison.

    Missing values as they come out of pandas rows (``None``, NaN,
    ``pd.NA``, ``NaT``) and other falsy values normalise to ``""``; non-string
    scalars such as numeric codes are compared via their text form.

    Args:
        label: The raw label, usually a str.

    Returns:
        The lower-cased, whitespace-stripped label, or ``""`` if missing.
    """
    if isinstance(label, str):
        return label.lower().strip()
    # NaN is truthy and pd.NA cannot be used in a boolean test, so detect
    # pandas missing values explicitly before the falsy check below.
    if label is None or (pd.api.types.is_scalar(label) and pd.isna(label)):
        return ""
    return str(label).lower().strip() if label else ""
|
|