Leacb4 committed on
Commit
7a5206c
·
verified ·
1 Parent(s): 8278380

Upload evaluation/run_all_evaluations.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. evaluation/run_all_evaluations.py +140 -35
evaluation/run_all_evaluations.py CHANGED
@@ -22,8 +22,7 @@ Available steps
22
  sec51 §5.1 Colour model accuracy (Table 1)
23
  sec52 §5.2 Category model confusion matrix (Table 2)
24
  sec533 §5.3.3 NN classification accuracy (Table 3)
25
- sec5354 §5.3.4+5 Separation & zero-shot semantic eval
26
- sec536 §5.3.6 Embedding structure Tests A/B/C (Table 4)
27
  annex92 Annex 9.2 Pairwise colour similarity heatmaps
28
  annex93 Annex 9.3 t-SNE visualisations
29
  annex94 Annex 9.4 Fashion search demo
@@ -37,10 +36,95 @@ import traceback
37
  from datetime import datetime
38
  from pathlib import Path
39
 
40
- # Make sure the repo root is on the path so that `config` is importable.
 
41
  sys.path.insert(0, str(Path(__file__).parent.parent))
 
42
 
43
- ALL_STEPS = ["sec51", "sec52", "sec533", "sec5354", "sec536", "annex92", "annex93", "annex94"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
 
46
  class EvaluationRunner:
@@ -51,6 +135,7 @@ class EvaluationRunner:
51
  self.output_dir.mkdir(exist_ok=True, parents=True)
52
  self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
53
  self.results: dict[str, str] = {} # step -> "ok" | "failed" | "skipped"
 
54
 
55
  # ------------------------------------------------------------------
56
  # Individual section runners (lazy imports to allow partial execution)
@@ -59,54 +144,75 @@ class EvaluationRunner:
59
  def run_sec51(self):
60
  """§5.1 – Colour model accuracy (Table 1)."""
61
  from sec51_color_model_eval import ColorEvaluator
62
- import torch
63
- device = "mps" if torch.backends.mps.is_available() else "cpu"
64
- evaluator = ColorEvaluator(device=device, output_dir=str(self.output_dir / "sec51"))
65
- evaluator.run_full_evaluation()
 
 
 
 
 
 
 
 
 
 
 
66
 
67
  def run_sec52(self):
68
  """§5.2 – Category model confusion matrix (Table 2)."""
69
  from sec52_category_model_eval import CategoryModelEvaluator
70
- import torch
71
- device = "mps" if torch.backends.mps.is_available() else "cpu"
72
- evaluator = CategoryModelEvaluator(device=device, directory=str(self.output_dir / "sec52"))
 
 
 
 
 
 
 
 
 
 
73
  evaluator.run_full_evaluation()
74
 
75
  def run_sec533(self):
76
  """§5.3.3 – Nearest-neighbour classification accuracy (Table 3)."""
77
  from sec533_clip_nn_accuracy import ColorHierarchyEvaluator
78
- import torch
79
- device = "mps" if torch.backends.mps.is_available() else "cpu"
80
  evaluator = ColorHierarchyEvaluator(
81
- device=device,
82
  directory=str(self.output_dir / "sec533"),
 
 
 
 
 
 
 
83
  )
84
- max_samples = 10_000
85
- evaluator.evaluate_fashion_mnist(max_samples=max_samples)
86
- evaluator.evaluate_kaggle_marqo(max_samples=max_samples)
87
- evaluator.evaluate_local_validation(max_samples=max_samples)
88
- evaluator.evaluate_baseline_fashion_mnist(max_samples=max_samples)
89
- evaluator.evaluate_baseline_kaggle_marqo(max_samples=max_samples)
90
- evaluator.evaluate_baseline_local_validation(max_samples=max_samples)
91
-
92
- def run_sec5354(self):
93
- """§5.3.4+5 – Embedding separation & zero-shot semantic eval."""
94
- # sec5354 has a self-contained __main__ block that handles dataset loading.
95
- import runpy
96
- runpy.run_path(
97
- str(Path(__file__).parent / "sec5354_separation_semantic.py"),
98
- run_name="__main__",
99
- )
100
 
101
  def run_sec536(self):
102
- """§5.3.6 – Embedding structure Tests A/B/C."""
103
  from sec536_embedding_structure import main as sec536_main
104
- sec536_main(selected_tests=["A", "B", "C"])
 
 
 
 
 
 
 
 
105
 
106
  def run_annex92(self):
107
  """Annex 9.2 – Pairwise colour similarity heatmaps."""
108
  # annex92 is a self-contained script; run its __main__ guard.
109
- import importlib, runpy
110
  runpy.run_path(
111
  str(Path(__file__).parent / "annex92_color_heatmaps.py"),
112
  run_name="__main__",
@@ -189,8 +295,7 @@ def main():
189
  " sec51 §5.1 Colour model (Table 1)",
190
  " sec52 §5.2 Category model (Table 2)",
191
  " sec533 §5.3.3 NN accuracy (Table 3)",
192
- " sec5354 §5.3.4+5 Separation & semantic eval",
193
- " sec536 §5.3.6 Embedding structure tests (Table 4)",
194
  " annex92 Annex 9.2 Colour heatmaps",
195
  " annex93 Annex 9.3 t-SNE",
196
  " annex94 Annex 9.4 Search demo",
 
22
  sec51 §5.1 Colour model accuracy (Table 1)
23
  sec52 §5.2 Category model confusion matrix (Table 2)
24
  sec533 §5.3.3 NN classification accuracy (Table 3)
25
+ sec536 §5.3.6 Embedding structure Tests A/B/C/D (Table 4)
 
26
  annex92 Annex 9.2 Pairwise colour similarity heatmaps
27
  annex93 Annex 9.3 t-SNE visualisations
28
  annex94 Annex 9.4 Fashion search demo
 
36
  from datetime import datetime
37
  from pathlib import Path
38
 
39
+ # Make sure the repo root is on the path so that `config` is importable,
40
+ # and the evaluation directory so that secXX modules can be imported.
41
  sys.path.insert(0, str(Path(__file__).parent.parent))
42
+ sys.path.insert(0, str(Path(__file__).parent))
43
 
44
+ ALL_STEPS = ["sec51", "sec52", "sec533", "sec536", "annex92", "annex93", "annex94"]
45
+
46
+
47
+ class ResourceCache:
48
+ """Lazy-loading cache for shared models and raw datasets.
49
+
50
+ Each property is loaded at most once and cached for reuse across
51
+ evaluation sections. This avoids re-downloading Kaggle data (~30s),
52
+ re-loading Fashion-CLIP (~15s) and GAP-CLIP (~20s) multiple times.
53
+ """
54
+
55
+ def __init__(self, device=None):
56
+ import torch
57
+ if device is None:
58
+ device = "mps" if torch.backends.mps.is_available() else "cpu"
59
+ self.device = torch.device(device) if isinstance(device, str) else device
60
+
61
+ self._gap_clip = None
62
+ self._fashion_clip = None
63
+ self._color_model = None
64
+ self._hierarchy_classes = None
65
+ self._kaggle_raw_df = None
66
+ self._local_raw_df = None
67
+
68
+ @property
69
+ def gap_clip(self):
70
+ """(model, processor) for GAP-CLIP."""
71
+ if self._gap_clip is None:
72
+ from config import main_model_path
73
+ from utils.model_loader import load_gap_clip
74
+ print("[ResourceCache] Loading GAP-CLIP...")
75
+ self._gap_clip = load_gap_clip(main_model_path, self.device)
76
+ return self._gap_clip
77
+
78
+ @property
79
+ def fashion_clip(self):
80
+ """(model, processor) for Fashion-CLIP baseline."""
81
+ if self._fashion_clip is None:
82
+ from utils.model_loader import load_baseline_fashion_clip
83
+ print("[ResourceCache] Loading Fashion-CLIP baseline...")
84
+ self._fashion_clip = load_baseline_fashion_clip(self.device)
85
+ return self._fashion_clip
86
+
87
+ @property
88
+ def color_model(self):
89
+ """ColorCLIP model instance."""
90
+ if self._color_model is None:
91
+ from config import color_model_path
92
+ from utils.model_loader import load_color_model
93
+ print("[ResourceCache] Loading ColorCLIP model...")
94
+ self._color_model, _ = load_color_model(color_model_path, self.device)
95
+ return self._color_model
96
+
97
+ @property
98
+ def hierarchy_classes(self):
99
+ """List of hierarchy class names from the hierarchy model checkpoint."""
100
+ if self._hierarchy_classes is None:
101
+ import torch
102
+ from config import hierarchy_model_path
103
+ print("[ResourceCache] Loading hierarchy classes...")
104
+ checkpoint = torch.load(hierarchy_model_path, map_location=self.device)
105
+ self._hierarchy_classes = checkpoint.get('hierarchy_classes', [])
106
+ print(f"[ResourceCache] Found {len(self._hierarchy_classes)} hierarchy classes")
107
+ return self._hierarchy_classes
108
+
109
+ @property
110
+ def kaggle_raw_df(self):
111
+ """Raw Kaggle KAGL DataFrame (downloaded once from HuggingFace)."""
112
+ if self._kaggle_raw_df is None:
113
+ from utils.datasets import download_kaggle_raw_df
114
+ print("[ResourceCache] Downloading Kaggle KAGL dataset...")
115
+ self._kaggle_raw_df = download_kaggle_raw_df()
116
+ return self._kaggle_raw_df
117
+
118
+ @property
119
+ def local_raw_df(self):
120
+ """Raw local validation DataFrame (read once from CSV)."""
121
+ if self._local_raw_df is None:
122
+ import pandas as pd
123
+ from config import local_dataset_path
124
+ print("[ResourceCache] Loading local validation CSV...")
125
+ self._local_raw_df = pd.read_csv(local_dataset_path)
126
+ print(f"[ResourceCache] Local dataset: {len(self._local_raw_df)} rows")
127
+ return self._local_raw_df
128
 
129
 
130
  class EvaluationRunner:
 
135
  self.output_dir.mkdir(exist_ok=True, parents=True)
136
  self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
137
  self.results: dict[str, str] = {} # step -> "ok" | "failed" | "skipped"
138
+ self.cache = ResourceCache()
139
 
140
  # ------------------------------------------------------------------
141
  # Individual section runners (lazy imports to allow partial execution)
 
144
  def run_sec51(self):
145
  """§5.1 – Colour model accuracy (Table 1)."""
146
  from sec51_color_model_eval import ColorEvaluator
147
+ baseline_model, baseline_processor = self.cache.fashion_clip
148
+ evaluator = ColorEvaluator(
149
+ device=self.cache.device,
150
+ directory=str(self.output_dir / "sec51"),
151
+ baseline_model=baseline_model,
152
+ baseline_processor=baseline_processor,
153
+ color_model=self.cache.color_model,
154
+ kaggle_raw_df=self.cache.kaggle_raw_df,
155
+ local_raw_df=self.cache.local_raw_df,
156
+ )
157
+ max_samples = 5000
158
+ evaluator.evaluate_kaggle_marqo(max_samples=max_samples)
159
+ evaluator.evaluate_local_validation(max_samples=max_samples)
160
+ evaluator.evaluate_baseline_kaggle_marqo(max_samples=max_samples)
161
+ evaluator.evaluate_baseline_local_validation(max_samples=max_samples)
162
 
163
  def run_sec52(self):
164
  """§5.2 – Category model confusion matrix (Table 2)."""
165
  from sec52_category_model_eval import CategoryModelEvaluator
166
+ gap_model, gap_processor = self.cache.gap_clip
167
+ baseline_model, baseline_processor = self.cache.fashion_clip
168
+ evaluator = CategoryModelEvaluator(
169
+ device=self.cache.device,
170
+ directory=str(self.output_dir / "sec52"),
171
+ gap_clip_model=gap_model,
172
+ gap_clip_processor=gap_processor,
173
+ baseline_model=baseline_model,
174
+ baseline_processor=baseline_processor,
175
+ hierarchy_classes=self.cache.hierarchy_classes,
176
+ kaggle_raw_df=self.cache.kaggle_raw_df,
177
+ local_raw_df=self.cache.local_raw_df,
178
+ )
179
  evaluator.run_full_evaluation()
180
 
181
  def run_sec533(self):
182
  """§5.3.3 – Nearest-neighbour classification accuracy (Table 3)."""
183
  from sec533_clip_nn_accuracy import ColorHierarchyEvaluator
184
+ gap_model, gap_processor = self.cache.gap_clip
185
+ baseline_model, baseline_processor = self.cache.fashion_clip
186
  evaluator = ColorHierarchyEvaluator(
187
+ device=self.cache.device,
188
  directory=str(self.output_dir / "sec533"),
189
+ gap_clip_model=gap_model,
190
+ gap_clip_processor=gap_processor,
191
+ baseline_model=baseline_model,
192
+ baseline_processor=baseline_processor,
193
+ hierarchy_classes=self.cache.hierarchy_classes,
194
+ kaggle_raw_df=self.cache.kaggle_raw_df,
195
+ local_raw_df=self.cache.local_raw_df,
196
  )
197
+ evaluator.run_full_evaluation(max_samples=10_000)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
 
199
  def run_sec536(self):
200
+ """§5.3.6 – Embedding structure Tests A/B/C/D."""
201
  from sec536_embedding_structure import main as sec536_main
202
+ gap_model, gap_processor = self.cache.gap_clip
203
+ baseline_model, baseline_processor = self.cache.fashion_clip
204
+ sec536_main(
205
+ selected_tests={"A", "B", "C", "D"},
206
+ model=gap_model,
207
+ processor=gap_processor,
208
+ baseline_model=baseline_model,
209
+ baseline_processor=baseline_processor,
210
+ )
211
 
212
  def run_annex92(self):
213
  """Annex 9.2 – Pairwise colour similarity heatmaps."""
214
  # annex92 is a self-contained script; run its __main__ guard.
215
+ import runpy
216
  runpy.run_path(
217
  str(Path(__file__).parent / "annex92_color_heatmaps.py"),
218
  run_name="__main__",
 
295
  " sec51 §5.1 Colour model (Table 1)",
296
  " sec52 §5.2 Category model (Table 2)",
297
  " sec533 §5.3.3 NN accuracy (Table 3)",
298
+ " sec536 §5.3.6 Embedding structure tests A/B/C/D (Table 4)",
 
299
  " annex92 Annex 9.2 Colour heatmaps",
300
  " annex93 Annex 9.3 t-SNE",
301
  " annex94 Annex 9.4 Search demo",