| |
| """ |
| GAP-CLIP Evaluation Runner |
| =========================== |
| |
| Orchestrates all evaluation scripts, one per paper section. Each evaluation |
| is independent and can be run in isolation via ``--steps``. |
| |
| Usage |
| ----- |
| Run everything:: |
| |
| python evaluation/run_all_evaluations.py |
| |
| Run specific sections:: |
| |
| python evaluation/run_all_evaluations.py --steps sec51,sec52 |
| python evaluation/run_all_evaluations.py --steps annex92,annex93 |
| |
| Available steps |
| --------------- |
| sec51 §5.1 Colour model accuracy (Table 1) |
| sec52 §5.2 Category model confusion matrix (Table 2) |
| sec533 §5.3.3 NN classification accuracy (Table 3) |
| sec536 §5.3.6 Embedding structure Tests A/B/C/D (Table 4) |
| annex92 Annex 9.2 Pairwise colour similarity heatmaps |
| annex93 Annex 9.3 t-SNE visualisations |
| annex94 Annex 9.4 Fashion search demo |
| |
| Author: Lea Attia Sarfati |
| """ |
|
|
| import argparse |
| import sys |
| import traceback |
| from datetime import datetime |
| from pathlib import Path |
|
|
| |
| |
# Put this directory first and its parent second on the import path so that
# the lazily imported ``config``/``utils`` modules and the sibling ``sec*``
# evaluation scripts resolve no matter where the runner is launched from.
sys.path[:0] = [
    str(Path(__file__).parent),
    str(Path(__file__).parent.parent),
]

# Every known step identifier, in the order used when ``--steps all`` is given.
ALL_STEPS = [
    "sec51",
    "sec52",
    "sec533",
    "sec536",
    "annex92",
    "annex93",
    "annex94",
]
|
|
|
|
class ResourceCache:
    """Lazy-loading cache for shared models and raw datasets.

    Each resource is exposed as a property that is loaded on first access and
    then reused by every evaluation section. This avoids re-downloading the
    Kaggle data (~30s) and re-loading Fashion-CLIP (~15s) and GAP-CLIP (~20s)
    several times in one run.

    Args:
        device: Where models are placed. May be a device string (for example
            ``"cpu"``), an already constructed device object, or ``None`` to
            auto-detect (CUDA first, then Apple MPS, then CPU).
    """

    def __init__(self, device=None):
        if device is None:
            device = self._detect_device()
        if isinstance(device, str):
            # torch is imported only when a device object has to be built, so
            # passing a ready-made device does not trigger the import here.
            import torch
            device = torch.device(device)
        self.device = device

        # ``None`` means "not loaded yet"; each property fills in its slot.
        self._gap_clip = None
        self._fashion_clip = None
        self._color_model = None
        self._hierarchy_classes = None
        self._kaggle_raw_df = None
        self._local_raw_df = None

    @staticmethod
    def _detect_device() -> str:
        """Return the name of the best available torch device."""
        import torch
        if torch.cuda.is_available():
            return "cuda"
        # Some torch builds ship without the MPS backend; looking it up with
        # getattr avoids an AttributeError on those builds.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return "mps"
        return "cpu"

    @property
    def gap_clip(self):
        """(model, processor) for GAP-CLIP, loaded on first access."""
        if self._gap_clip is None:
            from config import main_model_path
            from utils.model_loader import load_gap_clip
            print("[ResourceCache] Loading GAP-CLIP...")
            self._gap_clip = load_gap_clip(main_model_path, self.device)
        return self._gap_clip

    @property
    def fashion_clip(self):
        """(model, processor) for the Fashion-CLIP baseline, loaded on first access."""
        if self._fashion_clip is None:
            from utils.model_loader import load_baseline_fashion_clip
            print("[ResourceCache] Loading Fashion-CLIP baseline...")
            self._fashion_clip = load_baseline_fashion_clip(self.device)
        return self._fashion_clip

    @property
    def color_model(self):
        """ColorCLIP model instance, loaded on first access."""
        if self._color_model is None:
            from config import color_model_path
            from utils.model_loader import load_color_model
            print("[ResourceCache] Loading ColorCLIP model...")
            # The loader returns a pair; only its first element is kept.
            self._color_model, _ = load_color_model(color_model_path, self.device)
        return self._color_model

    @property
    def hierarchy_classes(self):
        """Class names stored under ``'hierarchy_classes'`` in the hierarchy checkpoint.

        Falls back to an empty list when the checkpoint has no such key.
        """
        if self._hierarchy_classes is None:
            import torch
            from config import hierarchy_model_path
            print("[ResourceCache] Loading hierarchy classes...")
            # Only the class-name entry is read, so the checkpoint is mapped
            # to CPU instead of being copied onto the accelerator.
            checkpoint = torch.load(hierarchy_model_path, map_location="cpu")
            self._hierarchy_classes = checkpoint.get('hierarchy_classes', [])
            print(f"[ResourceCache] Found {len(self._hierarchy_classes)} hierarchy classes")
        return self._hierarchy_classes

    @property
    def kaggle_raw_df(self):
        """Raw Kaggle KAGL DataFrame, fetched once via ``download_kaggle_raw_df``."""
        if self._kaggle_raw_df is None:
            from utils.datasets import download_kaggle_raw_df
            print("[ResourceCache] Downloading Kaggle KAGL dataset...")
            self._kaggle_raw_df = download_kaggle_raw_df()
        return self._kaggle_raw_df

    @property
    def local_raw_df(self):
        """Raw local validation DataFrame, read once from ``local_dataset_path``."""
        if self._local_raw_df is None:
            import pandas as pd
            from config import local_dataset_path
            print("[ResourceCache] Loading local validation CSV...")
            self._local_raw_df = pd.read_csv(local_dataset_path)
            print(f"[ResourceCache] Local dataset: {len(self._local_raw_df)} rows")
        return self._local_raw_df
|
|
|
|
class EvaluationRunner:
    """Runs one or more evaluation sections and collects pass/fail status.

    Each ``run_<step>`` method implements one step; ``run`` dispatches step
    names to those methods and prints a summary. The one-line docstrings of
    the ``run_<step>`` methods are printed in the per-step banner, so they are
    part of the program output.

    Args:
        output_dir: Directory for results; created if it does not exist.
    """

    def __init__(self, output_dir: str = "evaluation_results"):
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(exist_ok=True, parents=True)
        self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        # Maps step name -> "ok" | "failed" | "skipped".
        self.results: dict[str, str] = {}
        # Shared across steps so models and datasets are loaded only once.
        self.cache = ResourceCache()

    # ------------------------------------------------------------------
    # Individual steps
    # ------------------------------------------------------------------

    def run_sec51(self):
        """§5.1 – Colour model accuracy (Table 1)."""
        from sec51_color_model_eval import ColorEvaluator
        baseline_model, baseline_processor = self.cache.fashion_clip
        evaluator = ColorEvaluator(
            device=self.cache.device,
            directory=str(self.output_dir / "sec51"),
            baseline_model=baseline_model,
            baseline_processor=baseline_processor,
            color_model=self.cache.color_model,
            kaggle_raw_df=self.cache.kaggle_raw_df,
            local_raw_df=self.cache.local_raw_df,
        )
        # The same sample cap is applied to all four evaluations.
        max_samples = 5000
        evaluator.evaluate_kaggle_marqo(max_samples=max_samples)
        evaluator.evaluate_local_validation(max_samples=max_samples)
        evaluator.evaluate_baseline_kaggle_marqo(max_samples=max_samples)
        evaluator.evaluate_baseline_local_validation(max_samples=max_samples)

    def run_sec52(self):
        """§5.2 – Category model confusion matrix (Table 2)."""
        from sec52_category_model_eval import CategoryModelEvaluator
        gap_model, gap_processor = self.cache.gap_clip
        baseline_model, baseline_processor = self.cache.fashion_clip
        evaluator = CategoryModelEvaluator(
            device=self.cache.device,
            directory=str(self.output_dir / "sec52"),
            gap_clip_model=gap_model,
            gap_clip_processor=gap_processor,
            baseline_model=baseline_model,
            baseline_processor=baseline_processor,
            hierarchy_classes=self.cache.hierarchy_classes,
            kaggle_raw_df=self.cache.kaggle_raw_df,
            local_raw_df=self.cache.local_raw_df,
        )
        evaluator.run_full_evaluation()

    def run_sec533(self):
        """§5.3.3 – Nearest-neighbour classification accuracy (Table 3)."""
        from sec533_clip_nn_accuracy import ColorHierarchyEvaluator
        gap_model, gap_processor = self.cache.gap_clip
        baseline_model, baseline_processor = self.cache.fashion_clip
        evaluator = ColorHierarchyEvaluator(
            device=self.cache.device,
            directory=str(self.output_dir / "sec533"),
            gap_clip_model=gap_model,
            gap_clip_processor=gap_processor,
            baseline_model=baseline_model,
            baseline_processor=baseline_processor,
            hierarchy_classes=self.cache.hierarchy_classes,
            kaggle_raw_df=self.cache.kaggle_raw_df,
            local_raw_df=self.cache.local_raw_df,
        )
        evaluator.run_full_evaluation(max_samples=10_000)

    def run_sec536(self):
        """§5.3.6 – Embedding structure Tests A/B/C/D."""
        from sec536_embedding_structure import main as sec536_main
        gap_model, gap_processor = self.cache.gap_clip
        baseline_model, baseline_processor = self.cache.fashion_clip
        sec536_main(
            selected_tests={"A", "B", "C", "D"},
            model=gap_model,
            processor=gap_processor,
            baseline_model=baseline_model,
            baseline_processor=baseline_processor,
        )

    def run_annex92(self):
        """Annex 9.2 – Pairwise colour similarity heatmaps."""
        self._run_script("annex92_color_heatmaps.py")

    def run_annex93(self):
        """Annex 9.3 – t-SNE visualisations."""
        self._run_script("annex93_tsne.py")

    def run_annex94(self):
        """Annex 9.4 – Fashion search demo."""
        self._run_script("annex94_search_demo.py")

    # ------------------------------------------------------------------
    # Orchestration
    # ------------------------------------------------------------------

    def _run_script(self, filename: str) -> None:
        """Execute a script located next to this file as if it were ``__main__``."""
        import runpy
        runpy.run_path(
            str(Path(__file__).parent / filename),
            run_name="__main__",
        )

    def _run_step(self, step: str) -> bool:
        """Run a single step and record its status in ``self.results``.

        Args:
            step: Step name; dashes are treated as underscores when looking
                up the matching ``run_<step>`` method.

        Returns:
            ``True`` if the step completed, ``False`` if it failed or is
            unknown (unknown steps are recorded as ``"skipped"``).
        """
        method = getattr(self, f"run_{step.replace('-', '_')}", None)
        if method is None:
            print(f"⚠️ Unknown step '{step}' – skipping.")
            self.results[step] = "skipped"
            return False

        print(f"\n{'='*70}")
        print(f"▶ Running {step} ({method.__doc__ or ''})")
        print(f"{'='*70}")
        try:
            method()
        except SystemExit as exit_request:
            # Scripts executed as ``__main__`` may call ``sys.exit()``.
            # SystemExit is not an ``Exception`` subclass, so without this
            # handler it would abort the whole run before the summary.
            if exit_request.code not in (None, 0):
                self.results[step] = "failed"
                print(f"❌ {step} FAILED: exited with status {exit_request.code!r}")
                return False
        except Exception:
            # Top-level boundary: report the failure and continue with the
            # remaining steps.
            self.results[step] = "failed"
            print(f"❌ {step} FAILED:")
            traceback.print_exc()
            return False

        self.results[step] = "ok"
        print(f"✅ {step} completed successfully.")
        return True

    def run(self, steps: list[str]) -> bool:
        """Run ``steps`` in order and print a summary table.

        Args:
            steps: Step names to execute.

        Returns:
            ``False`` if any step failed, otherwise ``True``. Skipped
            (unknown) steps do not count as failures.
        """
        print("=" * 70)
        print(f"🚀 GAP-CLIP Evaluation ({self.timestamp})")
        print(f" Steps: {', '.join(steps)}")
        print(f" Output: {self.output_dir}")
        print("=" * 70)

        for step in steps:
            self._run_step(step)

        print(f"\n{'='*70}")
        print("📊 Summary")
        print(f"{'='*70}")
        icons = {"ok": "✅", "failed": "❌", "skipped": "⚠️ "}
        all_ok = True
        for step in steps:
            status = self.results.get(step, "skipped")
            print(f" {icons.get(status, '?')} {step:15s} {status}")
            if status == "failed":
                all_ok = False

        print("=" * 70)
        return all_ok
|
|
|
|
def main():
    """Command-line entry point: parse arguments, run the steps, then exit.

    ``--steps`` takes ``all`` or a comma-separated list of step names
    (matched case-insensitively); ``--output`` selects the results directory.
    The process exits with status 0 when no step failed and 1 otherwise. A
    ``--steps`` value that names no step at all is reported as a usage error.
    """
    parser = argparse.ArgumentParser(
        description="Run GAP-CLIP evaluations.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="\n".join(
            [
                "Available steps:",
                " sec51 §5.1 Colour model (Table 1)",
                " sec52 §5.2 Category model (Table 2)",
                " sec533 §5.3.3 NN accuracy (Table 3)",
                " sec536 §5.3.6 Embedding structure tests A/B/C/D (Table 4)",
                " annex92 Annex 9.2 Colour heatmaps",
                " annex93 Annex 9.3 t-SNE",
                " annex94 Annex 9.4 Search demo",
            ]
        ),
    )
    parser.add_argument(
        "--steps",
        type=str,
        default="all",
        help=(
            "Comma-separated list of steps to run, or 'all' to run everything "
            "(default: all). Example: --steps sec51,sec52,sec536"
        ),
    )
    parser.add_argument(
        "--output",
        type=str,
        default="evaluation_results",
        help="Directory to save results (default: evaluation_results).",
    )
    args = parser.parse_args()

    if args.steps.strip().lower() == "all":
        # Copy so the module-level list is never shared with the runner.
        steps = list(ALL_STEPS)
    else:
        # Step names are lower-cased, matching the case-insensitive "all".
        steps = [
            name.strip().lower()
            for name in args.steps.split(",")
            if name.strip()
        ]

    if not steps:
        # Without this check an empty selection would run nothing and still
        # exit with status 0.
        parser.error("--steps must be 'all' or name at least one step")

    runner = EvaluationRunner(output_dir=args.output)
    success = runner.run(steps)
    sys.exit(0 if success else 1)
|
|
|
|
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|