gap-clip / evaluation /run_all_evaluations.py
Leacb4's picture
Upload evaluation/run_all_evaluations.py with huggingface_hub
7a5206c verified
#!/usr/bin/env python3
"""
GAP-CLIP Evaluation Runner
===========================
Orchestrates all evaluation scripts, one per paper section. Each evaluation
is independent and can be run in isolation via ``--steps``.
Usage
-----
Run everything::

    python evaluation/run_all_evaluations.py

Run specific sections::

    python evaluation/run_all_evaluations.py --steps sec51,sec52
    python evaluation/run_all_evaluations.py --steps annex92,annex93
Available steps
---------------
sec51 §5.1 Colour model accuracy (Table 1)
sec52 §5.2 Category model confusion matrix (Table 2)
sec533 §5.3.3 NN classification accuracy (Table 3)
sec536 §5.3.6 Embedding structure Tests A/B/C/D (Table 4)
annex92 Annex 9.2 Pairwise colour similarity heatmaps
annex93 Annex 9.3 t-SNE visualisations
annex94 Annex 9.4 Fashion search demo
Author: Lea Attia Sarfati
"""
import argparse
import sys
import traceback
from datetime import datetime
from pathlib import Path
# Two import roots are required: the repository root (so that ``config`` and
# ``utils`` resolve) and this evaluation directory (so that the ``secXX``
# modules resolve). Each root is pushed to the front of ``sys.path``, so the
# evaluation directory, inserted last, ends up searched first.
for _import_root in (Path(__file__).parent.parent, Path(__file__).parent):
    sys.path.insert(0, str(_import_root))
del _import_root

# Every step name accepted by ``--steps``, in default execution order.
ALL_STEPS = [
    "sec51",
    "sec52",
    "sec533",
    "sec536",
    "annex92",
    "annex93",
    "annex94",
]
class ResourceCache:
    """Lazy-loading cache for shared models and raw datasets.

    Each property is loaded at most once and cached for reuse across
    evaluation sections. This avoids re-downloading Kaggle data (~30s),
    re-loading Fashion-CLIP (~15s) and GAP-CLIP (~20s) multiple times.
    If a loader raises, nothing is cached and the next access retries.
    """

    def __init__(self, device=None):
        """Create an empty cache bound to a torch device.

        Args:
            device: ``torch.device`` or device string such as ``"cpu"``.
                When ``None``, the first available of CUDA, Apple MPS and
                CPU is used.
        """
        import torch

        if device is None:
            device = self._default_device()
        self.device = torch.device(device) if isinstance(device, str) else device
        # Backing slots for the lazy properties below; ``None`` = not loaded yet.
        self._gap_clip = None
        self._fashion_clip = None
        self._color_model = None
        self._hierarchy_classes = None
        self._kaggle_raw_df = None
        self._local_raw_df = None

    @staticmethod
    def _default_device() -> str:
        """Return the name of the preferred available torch device.

        Preference order: ``"cuda"``, then ``"mps"``, then ``"cpu"``.
        """
        import torch

        if torch.cuda.is_available():
            return "cuda"
        # ``torch.backends.mps`` may be missing on older torch builds.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return "mps"
        return "cpu"

    @property
    def gap_clip(self):
        """(model, processor) for GAP-CLIP."""
        if self._gap_clip is None:
            from config import main_model_path
            from utils.model_loader import load_gap_clip

            print("[ResourceCache] Loading GAP-CLIP...")
            self._gap_clip = load_gap_clip(main_model_path, self.device)
        return self._gap_clip

    @property
    def fashion_clip(self):
        """(model, processor) for Fashion-CLIP baseline."""
        if self._fashion_clip is None:
            from utils.model_loader import load_baseline_fashion_clip

            print("[ResourceCache] Loading Fashion-CLIP baseline...")
            self._fashion_clip = load_baseline_fashion_clip(self.device)
        return self._fashion_clip

    @property
    def color_model(self):
        """ColorCLIP model instance."""
        if self._color_model is None:
            from config import color_model_path
            from utils.model_loader import load_color_model

            print("[ResourceCache] Loading ColorCLIP model...")
            # ``load_color_model`` returns a pair; only the first item (the
            # model) is kept here.
            self._color_model, _ = load_color_model(color_model_path, self.device)
        return self._color_model

    @property
    def hierarchy_classes(self):
        """List of hierarchy class names from the hierarchy model checkpoint.

        Falls back to an empty list (and prints a warning) when the
        checkpoint provides no ``'hierarchy_classes'`` entry.
        """
        if self._hierarchy_classes is None:
            import torch
            from config import hierarchy_model_path

            print("[ResourceCache] Loading hierarchy classes...")
            # Only the class-name list is read, so keep every tensor on the
            # CPU instead of copying the whole checkpoint onto the accelerator.
            checkpoint = torch.load(hierarchy_model_path, map_location="cpu")
            self._hierarchy_classes = checkpoint.get('hierarchy_classes', [])
            print(f"[ResourceCache] Found {len(self._hierarchy_classes)} hierarchy classes")
            if not self._hierarchy_classes:
                print(
                    "[ResourceCache] ⚠️ Checkpoint provides no "
                    "'hierarchy_classes'; using an empty list."
                )
        return self._hierarchy_classes

    @property
    def kaggle_raw_df(self):
        """Raw Kaggle KAGL DataFrame (downloaded once from HuggingFace)."""
        if self._kaggle_raw_df is None:
            from utils.datasets import download_kaggle_raw_df

            print("[ResourceCache] Downloading Kaggle KAGL dataset...")
            self._kaggle_raw_df = download_kaggle_raw_df()
        return self._kaggle_raw_df

    @property
    def local_raw_df(self):
        """Raw local validation DataFrame (read once from CSV)."""
        if self._local_raw_df is None:
            import pandas as pd
            from config import local_dataset_path

            print("[ResourceCache] Loading local validation CSV...")
            self._local_raw_df = pd.read_csv(local_dataset_path)
            print(f"[ResourceCache] Local dataset: {len(self._local_raw_df)} rows")
        return self._local_raw_df
class EvaluationRunner:
    """Runs one or more evaluation sections and collects pass/fail status.

    Each ``run_<step>`` method runs one paper section; ``_run_step`` looks the
    method up by name, so a step name maps directly onto a method. Shared
    models and datasets come from one ``ResourceCache`` so each is loaded at
    most once per runner.
    """

    # Sample caps handed to the section evaluators.
    SEC51_MAX_SAMPLES = 5000
    SEC533_MAX_SAMPLES = 10_000

    def __init__(self, output_dir: str = "evaluation_results"):
        """Create the output directory and an empty resource cache.

        Args:
            output_dir: Root directory; per-section sub-directories
                (e.g. ``sec51``) below it are passed to the evaluators.
        """
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(exist_ok=True, parents=True)
        self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.results: dict[str, str] = {}  # step -> "ok" | "failed" | "skipped"
        self.cache = ResourceCache()

    # ------------------------------------------------------------------
    # Individual section runners (lazy imports to allow partial execution)
    # NOTE: the one-line docstrings below are printed by ``_run_step``.
    # ------------------------------------------------------------------
    def run_sec51(self):
        """§5.1 – Colour model accuracy (Table 1)."""
        from sec51_color_model_eval import ColorEvaluator

        baseline_model, baseline_processor = self.cache.fashion_clip
        evaluator = ColorEvaluator(
            device=self.cache.device,
            directory=str(self.output_dir / "sec51"),
            baseline_model=baseline_model,
            baseline_processor=baseline_processor,
            color_model=self.cache.color_model,
            kaggle_raw_df=self.cache.kaggle_raw_df,
            local_raw_df=self.cache.local_raw_df,
        )
        max_samples = self.SEC51_MAX_SAMPLES
        evaluator.evaluate_kaggle_marqo(max_samples=max_samples)
        evaluator.evaluate_local_validation(max_samples=max_samples)
        evaluator.evaluate_baseline_kaggle_marqo(max_samples=max_samples)
        evaluator.evaluate_baseline_local_validation(max_samples=max_samples)

    def run_sec52(self):
        """§5.2 – Category model confusion matrix (Table 2)."""
        from sec52_category_model_eval import CategoryModelEvaluator

        gap_model, gap_processor = self.cache.gap_clip
        baseline_model, baseline_processor = self.cache.fashion_clip
        evaluator = CategoryModelEvaluator(
            device=self.cache.device,
            directory=str(self.output_dir / "sec52"),
            gap_clip_model=gap_model,
            gap_clip_processor=gap_processor,
            baseline_model=baseline_model,
            baseline_processor=baseline_processor,
            hierarchy_classes=self.cache.hierarchy_classes,
            kaggle_raw_df=self.cache.kaggle_raw_df,
            local_raw_df=self.cache.local_raw_df,
        )
        evaluator.run_full_evaluation()

    def run_sec533(self):
        """§5.3.3 – Nearest-neighbour classification accuracy (Table 3)."""
        from sec533_clip_nn_accuracy import ColorHierarchyEvaluator

        gap_model, gap_processor = self.cache.gap_clip
        baseline_model, baseline_processor = self.cache.fashion_clip
        evaluator = ColorHierarchyEvaluator(
            device=self.cache.device,
            directory=str(self.output_dir / "sec533"),
            gap_clip_model=gap_model,
            gap_clip_processor=gap_processor,
            baseline_model=baseline_model,
            baseline_processor=baseline_processor,
            hierarchy_classes=self.cache.hierarchy_classes,
            kaggle_raw_df=self.cache.kaggle_raw_df,
            local_raw_df=self.cache.local_raw_df,
        )
        evaluator.run_full_evaluation(max_samples=self.SEC533_MAX_SAMPLES)

    def run_sec536(self):
        """§5.3.6 – Embedding structure Tests A/B/C/D."""
        from sec536_embedding_structure import main as sec536_main

        gap_model, gap_processor = self.cache.gap_clip
        baseline_model, baseline_processor = self.cache.fashion_clip
        # NOTE(review): unlike the other sections, no output directory is
        # passed here — confirm where sec536 writes its results.
        sec536_main(
            selected_tests={"A", "B", "C", "D"},
            model=gap_model,
            processor=gap_processor,
            baseline_model=baseline_model,
            baseline_processor=baseline_processor,
        )

    def run_annex92(self):
        """Annex 9.2 – Pairwise colour similarity heatmaps."""
        # annex92 is a self-contained script; run its __main__ guard.
        self._run_script("annex92_color_heatmaps.py")

    def run_annex93(self):
        """Annex 9.3 – t-SNE visualisations."""
        self._run_script("annex93_tsne.py")

    def run_annex94(self):
        """Annex 9.4 – Fashion search demo."""
        self._run_script("annex94_search_demo.py")

    @staticmethod
    def _run_script(filename: str) -> None:
        """Execute an evaluation script from this directory as ``__main__``.

        ``runpy.run_path`` only rewrites ``sys.argv[0]``, so without
        intervention the runner's own arguments (e.g. ``--steps annex92``)
        would be visible to the script and could be rejected by any argument
        parser it builds. ``sys.argv`` is therefore reduced to just the
        script path while it runs and restored afterwards, even on error.

        Args:
            filename: Script file name, resolved relative to this file's
                directory.
        """
        import runpy

        script = Path(__file__).parent / filename
        saved_argv = sys.argv
        sys.argv = [str(script)]
        try:
            runpy.run_path(str(script), run_name="__main__")
        finally:
            sys.argv = saved_argv

    # ------------------------------------------------------------------
    # Orchestration
    # ------------------------------------------------------------------
    def _run_step(self, step: str) -> bool:
        """Run a single step and record its status in ``self.results``.

        Args:
            step: Step name; ``run_<step>`` (with ``-`` mapped to ``_``) must
                exist on this runner, otherwise the step is ``"skipped"``.

        Returns:
            ``True`` if the step completed, ``False`` if it was skipped or
            failed. Exceptions are printed, never propagated; a
            ``SystemExit`` with status 0/None counts as success, any other
            status as failure, so one section cannot abort the whole run.
        """
        method = getattr(self, f"run_{step.replace('-', '_')}", None)
        if method is None:
            print(f"⚠️ Unknown step '{step}' – skipping.")
            self.results[step] = "skipped"
            return False
        print(f"\n{'='*70}")
        print(f"▶ Running {step} ({method.__doc__ or ''})")
        print(f"{'='*70}")
        try:
            method()
        except SystemExit as exc:
            # Scripts executed as __main__ commonly end with sys.exit().
            if exc.code not in (None, 0):
                self.results[step] = "failed"
                print(f"❌ {step} FAILED: exited with status {exc.code!r}")
                return False
        except Exception:
            self.results[step] = "failed"
            print(f"❌ {step} FAILED:")
            traceback.print_exc()
            return False
        self.results[step] = "ok"
        print(f"✅ {step} completed successfully.")
        return True

    def run(self, steps: list[str]) -> bool:
        """Run *steps* in order, print a summary, and report overall success.

        Args:
            steps: Step names such as ``"sec51"``; unknown names are recorded
                as ``"skipped"``.

        Returns:
            ``False`` if any step failed, otherwise ``True`` (skipped steps
            do not make the run fail).
        """
        print("=" * 70)
        print(f"🚀 GAP-CLIP Evaluation ({self.timestamp})")
        print(f" Steps: {', '.join(steps)}")
        print(f" Output: {self.output_dir}")
        print("=" * 70)
        for step in steps:
            self._run_step(step)
        # Summary
        print(f"\n{'='*70}")
        print("📊 Summary")
        print(f"{'='*70}")
        all_ok = True
        for step in steps:
            status = self.results.get(step, "skipped")
            icon = {"ok": "✅", "failed": "❌", "skipped": "⚠️ "}.get(status, "?")
            print(f" {icon} {step:15s} {status}")
            if status == "failed":
                all_ok = False
        print("=" * 70)
        return all_ok
def main():
    """Parse CLI arguments, run the selected evaluation steps, and exit.

    Step names are case-insensitive and de-duplicated (first occurrence
    wins). Unknown names, or a ``--steps`` value naming no step at all, are
    rejected up front via ``argparse`` (exit status 2) rather than being
    silently skipped. Otherwise exits 0 when every step succeeded, else 1.
    """
    parser = argparse.ArgumentParser(
        description="Run GAP-CLIP evaluations.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="\n".join(
            [
                "Available steps:",
                " sec51 §5.1 Colour model (Table 1)",
                " sec52 §5.2 Category model (Table 2)",
                " sec533 §5.3.3 NN accuracy (Table 3)",
                " sec536 §5.3.6 Embedding structure tests A/B/C/D (Table 4)",
                " annex92 Annex 9.2 Colour heatmaps",
                " annex93 Annex 9.3 t-SNE",
                " annex94 Annex 9.4 Search demo",
            ]
        ),
    )
    parser.add_argument(
        "--steps",
        type=str,
        default="all",
        help=(
            "Comma-separated list of steps to run, or 'all' to run everything "
            "(default: all). Example: --steps sec51,sec52,sec536"
        ),
    )
    parser.add_argument(
        "--output",
        type=str,
        default="evaluation_results",
        help="Directory to save results (default: evaluation_results).",
    )
    args = parser.parse_args()
    if args.steps.strip().lower() == "all":
        # Copy so the runner never aliases the module-level constant.
        steps = list(ALL_STEPS)
    else:
        requested = [s.strip().lower() for s in args.steps.split(",") if s.strip()]
        if not requested:
            parser.error("--steps must name at least one step (or 'all').")
        unknown = [s for s in requested if s not in ALL_STEPS]
        if unknown:
            # Fail fast: a typo would otherwise be "skipped" with exit code 0.
            parser.error(
                f"unknown step(s): {', '.join(unknown)}. "
                f"Choose from: {', '.join(ALL_STEPS)}, or 'all'."
            )
        # Keep the user's order but run each step only once.
        steps = list(dict.fromkeys(requested))
    runner = EvaluationRunner(output_dir=args.output)
    success = runner.run(steps)
    sys.exit(0 if success else 1)


if __name__ == "__main__":
    main()