weijiang99 commited on
Commit
cffeecf
·
verified ·
1 Parent(s): 022c2d7

Upload folder using huggingface_hub

Browse files
README.md CHANGED
@@ -1,12 +1,25 @@
1
  ---
2
  title: SpatialBench
3
- emoji: 🌖
4
- colorFrom: pink
5
- colorTo: purple
6
  sdk: gradio
7
- sdk_version: 6.11.0
8
  app_file: app.py
9
- pinned: false
 
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  title: SpatialBench
3
+ emoji: 🧩
4
+ colorFrom: blue
5
+ colorTo: indigo
6
  sdk: gradio
7
+ sdk_version: "4.44.0"
8
  app_file: app.py
9
+ pinned: true
10
+ short_description: Do LLMs Build Spatial World Models? Evidence from Maze Tasks
11
  ---
12
 
13
+ # SpatialBench
14
+
15
+ Evaluation platform for **"Do LLMs Build Spatial World Models? Evidence from Grid-World Maze Tasks"** (ICLR 2026 Workshop).
16
+
17
+ Three tasks probe whether LLMs construct internal spatial representations:
18
+
19
+ | Task | Type | Description |
20
+ |------|------|-------------|
21
+ | **Maze Navigation** | Planning | Find shortest path from start to goal |
22
+ | **Sequential Point Reuse** | Reasoning | Q3 = Q0 — do models reuse earlier computation? |
23
+ | **Compositional Distance** | Reasoning | Compose corner→center distances for Q2 |
24
+
25
+ Models evaluated: Gemini 2.5 Flash, GPT-5 Mini, Claude Haiku 4.5, DeepSeek Chat.
app.py ADDED
@@ -0,0 +1,688 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ app.py — SpatialBench Gradio application
3
+ -----------------------------------------
4
+ Entrypoint for the HuggingFace Space "SpatialBench".
5
+
6
+ Two tabs:
7
+ 1. Leaderboard — visualize pre-computed results from all three tasks
8
+ 2. Run — launch experiments directly via API keys (no SLURM needed)
9
+ (on HF Space, set API keys as Space Secrets)
10
+
11
+ To run locally:
12
+ cd pipeline/
13
+ python app.py
14
+
15
+ To deploy on HuggingFace Spaces:
16
+ - Set Space Secrets: GEMINI_API_KEY, OPENAI_API_KEY, ANTHROPIC_API_KEY, DEEPSEEK_API_KEY
17
+ - The Space entrypoint is this file (app.py)
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ import os
23
+ import sys
24
+ import threading
25
+ import time
26
+ from pathlib import Path
27
+
28
+ import gradio as gr
29
+ import pandas as pd
30
+ import plotly.express as px
31
+ import plotly.graph_objects as go
32
+
33
# Load .env if running locally
def _load_dotenv(path: Path) -> None:
    """Populate os.environ from a simple KEY=VALUE dotenv file.

    Blank lines, comment lines starting with '#', and lines without '='
    are skipped.  Values wrapped in matching single or double quotes are
    unquoted (standard .env convention).  Existing environment variables
    are never overwritten (setdefault semantics), so real Space Secrets
    always take precedence over the local file.
    """
    with open(path) as fh:
        for raw in fh:
            line = raw.strip()
            if not line or line.startswith("#") or "=" not in line:
                continue
            key, value = line.split("=", 1)
            value = value.strip()
            # Strip matching surrounding quotes — a value like FOO="bar"
            # should yield bar, not "bar".
            if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
                value = value[1:-1]
            os.environ.setdefault(key.strip(), value)


_env = Path(__file__).parent / ".env"
if _env.exists():
    _load_dotenv(_env)
42
+
43
# Add repo root to path so pipeline imports work
sys.path.insert(0, str(Path(__file__).parent))

from pipeline.task_builder import load_config, build_all_jobs
from pipeline.job_monitor import JobMonitor, submit_direct
from pipeline.results_loader import (
    load_all_results,
    maze_navigation_leaderboard,
    point_reuse_leaderboard,
    compositional_distance_leaderboard,
)

# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
# experiments.yaml is the single source of truth for models, tasks and grids.
CONFIG_PATH = Path(__file__).parent / "configs" / "experiments.yaml"
CFG = load_config(CONFIG_PATH)
# Internal model ids (config keys) and their human-readable display names.
MODEL_CHOICES = list(CFG["models"].keys())
MODEL_DISPLAY = {k: v["display_name"] for k, v in CFG["models"].items()}

# Global job monitor (direct mode only — HF Space has no SLURM)
# Shared across all Gradio sessions; mutations are guarded by _monitor_lock.
_monitor = JobMonitor(mode="direct")
_monitor_lock = threading.Lock()
66
+
67
+
68
+ # ---------------------------------------------------------------------------
69
+ # Leaderboard helpers
70
+ # ---------------------------------------------------------------------------
71
+
72
+ def _load_results():
73
+ try:
74
+ return load_all_results(CONFIG_PATH)
75
+ except Exception as e:
76
+ return {"maze_navigation": pd.DataFrame(), "point_reuse": pd.DataFrame(), "compositional_distance": pd.DataFrame()}
77
+
78
+
79
def _make_empty_fig(msg: str) -> go.Figure:
    """Return a blank placeholder figure with *msg* centred on it."""
    placeholder = go.Figure()
    placeholder.add_annotation(
        text=msg,
        x=0.5,
        y=0.5,
        xref="paper",
        yref="paper",
        showarrow=False,
        font=dict(size=16),
    )
    # Hide both axes and use a transparent background so the placeholder
    # blends into the surrounding page.
    placeholder.update_layout(
        xaxis_visible=False,
        yaxis_visible=False,
        height=300,
        paper_bgcolor="rgba(0,0,0,0)",
        plot_bgcolor="rgba(0,0,0,0)",
    )
    return placeholder
87
+
88
+
89
+ # ── Task 1 plots ────────────────────────────────────────────────────────────
90
+
91
def plot_task1_accuracy(k_shot: int, input_format: str) -> tuple[go.Figure, pd.DataFrame]:
    """Accuracy-vs-grid-size lines for Task 1, plus the Task 1 leaderboard."""
    df = _load_results()["maze_navigation"]
    if df.empty:
        return _make_empty_fig("No Task 1 results found.\nRun experiments first."), pd.DataFrame()

    selector = (df["k_shot"] == k_shot) & (df["input_format"] == input_format)
    subset = df[selector]
    if subset.empty:
        return _make_empty_fig(f"No results for k={k_shot}, format={input_format}"), pd.DataFrame()

    line_fig = px.line(
        subset,
        x="grid_size",
        y="accuracy",
        color="display_name",
        line_dash="prompt_strategy",
        markers=True,
        labels={"grid_size": "Grid Size (n×n)", "accuracy": "Accuracy",
                "display_name": "Model", "prompt_strategy": "Strategy"},
        title=f"Task 1 — Maze Navigation ({input_format} format, {k_shot}-shot)",
        color_discrete_sequence=px.colors.qualitative.Set2,
    )
    line_fig.update_layout(
        yaxis_range=[0, 1],
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
        height=420,
    )
    # The leaderboard is built from the full (unfiltered) frame, not the
    # plotted subset — only the k-shot value is passed through.
    return line_fig, maze_navigation_leaderboard(df, k_shot=k_shot)
117
+
118
+
119
def plot_task1_format_comparison() -> go.Figure:
    """Grouped bars comparing raw vs visual input formats, averaged over grid sizes."""
    df = _load_results()["maze_navigation"]
    if df.empty:
        return _make_empty_fig("No Task 1 results found.")

    # Prefer the 0-shot CoT slice; fall back to all 0-shot rows if no CoT
    # results exist.
    zero_shot_cot = df[(df["k_shot"] == 0) & (df["prompt_strategy"] == "cot")]
    selection = zero_shot_cot if not zero_shot_cot.empty else df[df["k_shot"] == 0]

    mean_acc = (
        selection.groupby(["display_name", "input_format"])["accuracy"]
        .mean()
        .reset_index()
    )

    bar_fig = px.bar(
        mean_acc,
        x="display_name",
        y="accuracy",
        color="input_format",
        barmode="group",
        labels={"display_name": "Model", "accuracy": "Mean Accuracy",
                "input_format": "Input Format"},
        title="Task 1 — Raw vs Visual Format (0-shot, CoT, averaged over grid sizes)",
        color_discrete_map={"raw": "#2196F3", "visual": "#FF9800"},
    )
    bar_fig.update_layout(yaxis_range=[0, 1], height=380)
    return bar_fig
141
+
142
+
143
+ # ── Task 2 plots ────────────────────────────────────────────────────────────
144
+
145
def plot_task2_q0_q3(grid_size: int) -> tuple[go.Figure, pd.DataFrame]:
    """Grouped bars of Q0 vs Q3 accuracy per model for one grid size."""
    df = _load_results()["point_reuse"]
    if df.empty:
        return _make_empty_fig("No Task 2 results found.\nRun experiments first."), pd.DataFrame()

    at_size = df[df["grid_size"] == grid_size]
    if at_size.empty:
        return _make_empty_fig(f"No Task 2 results for {grid_size}×{grid_size}"), pd.DataFrame()

    def _question_mean(idx: int, name: str) -> pd.Series:
        # Per-model mean accuracy for a single question index.
        return (
            at_size[at_size["question_idx"] == idx]
            .groupby("display_name")["accuracy"]
            .mean()
            .rename(name)
        )

    melted = (
        pd.concat([_question_mean(0, "Q0"), _question_mean(3, "Q3")], axis=1)
        .reset_index()
        .melt(id_vars="display_name", var_name="Question", value_name="Accuracy")
    )

    bar_fig = px.bar(
        melted,
        x="display_name",
        y="Accuracy",
        color="Question",
        barmode="group",
        labels={"display_name": "Model"},
        title=(
            f"Task 2 — Q0 vs Q3 Accuracy ({grid_size}×{grid_size} maze)\n"
            "Q3 = Q0 (same question repeated — tests information reuse)"
        ),
        color_discrete_map={"Q0": "#4CAF50", "Q3": "#F44336"},
    )
    bar_fig.update_layout(yaxis_range=[0, 1], height=400)
    # Leaderboard spans all grid sizes, not just the selected one.
    return bar_fig, point_reuse_leaderboard(df)
171
+
172
+
173
def plot_task2_by_grid() -> go.Figure:
    """Line plot: Q3 (= repeated Q0) accuracy per model across grid sizes."""
    df = _load_results()["point_reuse"]
    if df.empty:
        return _make_empty_fig("No Task 2 results found.")

    q3_means = (
        df[df["question_idx"] == 3]
        .groupby(["display_name", "grid_size"])["accuracy"]
        .mean()
        .reset_index()
    )

    line_fig = px.line(
        q3_means,
        x="grid_size",
        y="accuracy",
        color="display_name",
        markers=True,
        labels={"grid_size": "Grid Size", "accuracy": "Q3 Accuracy",
                "display_name": "Model"},
        title="Task 2 — Q3 Accuracy by Grid Size (Q3 = Q0 repeated)",
        color_discrete_sequence=px.colors.qualitative.Set2,
    )
    line_fig.update_layout(yaxis_range=[0, 1], height=380)
    return line_fig
192
+
193
+
194
+ # ── Task 3 plots ────────────────────────────────────────────────────────────
195
+
196
def plot_task3_compositional() -> tuple[go.Figure, pd.DataFrame]:
    """Grouped bars of Q0/Q1/Q2 accuracy per model, plus Task 3 leaderboard."""
    df = _load_results()["compositional_distance"]
    if df.empty:
        return _make_empty_fig("No Task 3 results found.\nRun experiments first."), pd.DataFrame()

    per_question = (
        df.groupby(["display_name", "question_idx"])["accuracy"].mean().reset_index()
    )
    # Human-readable labels for the three question indices.
    per_question["Question"] = per_question["question_idx"].map(
        {0: "Q0: A→M", 1: "Q1: D→M", 2: "Q2: B→C (compositional)"}
    )

    bar_fig = px.bar(
        per_question,
        x="display_name",
        y="accuracy",
        color="Question",
        barmode="group",
        labels={"display_name": "Model", "accuracy": "Accuracy"},
        title=(
            "Task 3 — Compositional Distance Comparison\n"
            "Q2 can be composed from Q0+Q1 (corner→center distances)"
        ),
        color_discrete_map={
            "Q0: A→M": "#2196F3",
            "Q1: D→M": "#9C27B0",
            "Q2: B→C (compositional)": "#FF5722",
        },
    )
    bar_fig.update_layout(yaxis_range=[0, 1], height=420)
    return bar_fig, compositional_distance_leaderboard(df)
221
+
222
+
223
def plot_task3_by_grid() -> go.Figure:
    """Line plot: Q2 (compositional) accuracy per model across grid sizes."""
    df = _load_results()["compositional_distance"]
    if df.empty:
        return _make_empty_fig("No Task 3 results found.")

    q2_means = (
        df[df["question_idx"] == 2]
        .groupby(["display_name", "grid_size"])["accuracy"]
        .mean()
        .reset_index()
    )

    line_fig = px.line(
        q2_means,
        x="grid_size",
        y="accuracy",
        color="display_name",
        markers=True,
        labels={"grid_size": "Grid Size", "accuracy": "Q2 Accuracy",
                "display_name": "Model"},
        title="Task 3 — Q2 (Compositional) Accuracy by Grid Size",
        color_discrete_sequence=px.colors.qualitative.Set2,
    )
    line_fig.update_layout(yaxis_range=[0, 1], height=380)
    return line_fig
242
+
243
+
244
+ # ---------------------------------------------------------------------------
245
+ # Run-experiments tab
246
+ # ---------------------------------------------------------------------------
247
+
248
# Map from env-var name → user-provided key (populated at runtime from form)
# NOTE(review): _USER_KEYS and _USER_KEYS_LOCK appear unused in this module —
# launch_experiments builds its own local key map instead.  Confirm nothing
# imports them from elsewhere before removing.
_USER_KEYS: dict[str, str] = {}
_USER_KEYS_LOCK = threading.Lock()
251
+
252
+
253
def launch_experiments(
    tasks: list[str],
    models: list[str],
    grid_sizes_str: str,
    formats: list[str],
    strategies: list[str],
    gemini_key: str,
    openai_key: str,
    anthropic_key: str,
    deepseek_key: str,
) -> tuple[str, list[list[str]]]:
    """Called when the user clicks 'Run' in the Gradio UI.

    Validates the form inputs, expands them into concrete jobs via
    build_all_jobs, and launches each job as a direct subprocess tracked
    by the global monitor.  Jobs whose provider key was not supplied are
    skipped rather than failed.

    Args:
        tasks: Task display names from the UI checkbox group.
        models: Internal model ids (keys of CFG["models"]).
        grid_sizes_str: Comma-separated integers, e.g. "5,6,7".
        formats: Input formats for Task 1 ("raw"/"visual").
        strategies: Prompt strategy ids ("base"/"cot"/"reasoning").
        gemini_key/openai_key/anthropic_key/deepseek_key: Keys typed by the
            user; by design the server environment is never consulted.

    Returns:
        (status message, job-status table rows for the UI Dataframe).
    """
    # Build a key map from only what the user explicitly typed — never os.environ
    user_keys: dict[str, str] = {}
    if gemini_key.strip():
        user_keys["GEMINI_API_KEY"] = gemini_key.strip()
    if openai_key.strip():
        user_keys["OPENAI_API_KEY"] = openai_key.strip()
    if anthropic_key.strip():
        user_keys["ANTHROPIC_API_KEY"] = anthropic_key.strip()
    if deepseek_key.strip():
        user_keys["DEEPSEEK_API_KEY"] = deepseek_key.strip()

    if not user_keys:
        return (
            "No API keys provided. Please enter at least one API key to run experiments.",
            [],
        )

    # Parse grid sizes
    try:
        grid_sizes = [int(g.strip()) for g in grid_sizes_str.split(",") if g.strip()]
    except ValueError:
        return "Invalid grid sizes — enter comma-separated integers, e.g. 5,6,7", []

    if not tasks:
        return "Select at least one task.", []
    if not models:
        return "Select at least one model.", []

    # Map display choices back to internal IDs
    task_map = {
        "Maze Navigation": "maze_navigation",
        "Sequential Point Reuse": "point_reuse",
        "Compositional Distance Comparison": "compositional_distance",
    }
    selected_tasks = [task_map[t] for t in tasks if t in task_map]

    jobs = build_all_jobs(
        cfg=CFG,
        tasks=selected_tasks,
        models=models,
        grid_sizes=grid_sizes or None,
        input_formats=formats or None,
        prompt_strategies=strategies or None,
        config_path=CONFIG_PATH,
    )

    if not jobs:
        return "No jobs matched the selected filters.", []

    launched = 0
    skipped = 0
    skipped_models: list[str] = []
    with _monitor_lock:
        for job in jobs:
            # Only use the key the user provided — never fall back to server env
            api_key = user_keys.get(job.api_key_env, "")
            if not api_key:
                skipped += 1
                skipped_models.append(job.model)
                continue
            job.output_dir.mkdir(parents=True, exist_ok=True)
            # The key is passed to the subprocess via its environment only.
            proc = submit_direct(
                cmd=job.python_cmd,
                working_dir=str(job.working_dir),
                env={job.api_key_env: api_key},
            )
            _monitor.add_direct(
                proc=proc,
                label=job.label,
                task_id=job.task_id,
                model=job.model,
                output_dir=str(job.output_dir),
            )
            launched += 1
            # NOTE(review): this sleep runs while _monitor_lock is held, so
            # concurrent sessions' launches are fully serialized — confirm
            # that is acceptable for multi-user Spaces.
            time.sleep(1)  # avoid API rate limits on burst start

    status_msg = f"Launched {launched} job(s)."
    if skipped:
        missing = sorted(set(skipped_models))
        status_msg += (
            f" Skipped {skipped} job(s) for {', '.join(missing)} "
            f"— no API key provided for those models."
        )

    return status_msg, _monitor.as_table()
350
+
351
+
352
def refresh_status() -> tuple[list[list[str]], str]:
    """Poll the global job monitor; return (table rows, one-line summary)."""
    _monitor.refresh()
    counts = _monitor.summary()["counts"]
    if counts:
        msg = " ".join(f"{s}: {n}" for s, n in counts.items())
    else:
        msg = "No jobs submitted yet."
    return _monitor.as_table(), msg
358
+
359
+
360
+ # ---------------------------------------------------------------------------
361
+ # Gradio UI
362
+ # ---------------------------------------------------------------------------
363
+
364
# Markdown summary rendered at the top of the Leaderboard tab.
PAPER_ABSTRACT = """
**Do LLMs Build Spatial World Models? Evidence from Grid-World Maze Tasks**

We systematically evaluate the spatial understanding of large language models through maze tasks—a
controlled testing context requiring multi-step planning and spatial abstraction. Across experiments
with Gemini-2.5-Flash, GPT-5-mini, Claude-Haiku-4.5, and DeepSeek-Chat, we uncover significant
discrepancies in spatial reasoning that challenge assumptions about LLM planning capabilities.

Key findings:
- **Representation sensitivity**: Gemini drops from 86% (raw tokenized) to 34% (visual grid) on 5×5 mazes with CoT
- **Prompting dependency**: Claude-Haiku fails completely without CoT, recovers to 78% with it
- **No spatial memory**: Models treat sequential questions independently, failing to reuse computed spatial knowledge
"""

# Custom CSS: compact leaderboard tables, colored status badges, and a
# hidden Gradio footer.
CSS = """
.leaderboard-table { font-size: 0.9em; }
.status-badge-running { color: #2196F3; font-weight: bold; }
.status-badge-completed { color: #4CAF50; font-weight: bold; }
.status-badge-failed { color: #F44336; font-weight: bold; }
footer { display: none !important; }
"""
385
+
386
def build_ui() -> gr.Blocks:
    """Assemble the Gradio Blocks app.

    Three tabs: Leaderboard (pre-computed results with interactive
    filters), Run Experiments (direct API-key launches), and About
    (paper summary and citation).  Returns the un-launched Blocks
    object; the caller invokes .launch().
    """
    with gr.Blocks(
        title="SpatialBench — Do LLMs Build Spatial World Models?",
        css=CSS,
        theme=gr.themes.Soft(primary_hue="blue"),
    ) as demo:

        gr.Markdown("# 🧩 SpatialBench")
        gr.Markdown(
            "**Evaluating Spatial World Models in Large Language Models** · "
            "[Paper (ICLR 2026 Workshop)](https://arxiv.org/abs/...) · "
            "[Code](https://github.com/...)"
        )

        with gr.Tabs():

            # ================================================================
            # Tab 1: Leaderboard
            # ================================================================
            with gr.Tab("📊 Leaderboard"):
                gr.Markdown(PAPER_ABSTRACT)

                gr.Markdown("---")
                gr.Markdown("## Task 1 — Maze Navigation (Planning)")
                gr.Markdown(
                    "Models find shortest paths through mazes. "
                    "Two input formats: **raw** tokenized adjacency lists vs **visual** character grids."
                )

                with gr.Row():
                    t1_k = gr.Radio(
                        choices=[0, 3, 5], value=0, label="K-shot",
                        info="Number of in-context examples",
                    )
                    t1_fmt = gr.Radio(
                        choices=["raw", "visual"], value="raw", label="Input Format",
                    )

                t1_plot = gr.Plot(label="Accuracy by Grid Size")
                t1_lb = gr.Dataframe(
                    label="Leaderboard (mean accuracy across grid sizes)",
                    elem_classes=["leaderboard-table"],
                )
                t1_fmt_plot = gr.Plot(label="Raw vs Visual Format Comparison")

                # Re-render all Task 1 outputs whenever either control changes.
                def refresh_task1(k, fmt):
                    fig, lb = plot_task1_accuracy(int(k), fmt)
                    fmt_fig = plot_task1_format_comparison()
                    return fig, lb, fmt_fig

                for inp in [t1_k, t1_fmt]:
                    inp.change(
                        refresh_task1, inputs=[t1_k, t1_fmt],
                        outputs=[t1_plot, t1_lb, t1_fmt_plot],
                    )

                gr.Markdown("---")
                gr.Markdown("## Task 2 — Sequential Reasoning with Point Reuse")
                gr.Markdown(
                    "Models answer 4 proximity questions. **Q3 = Q0** (same question repeated). "
                    "Do models reuse their earlier computation, or start from scratch?"
                )

                t2_grid = gr.Slider(minimum=5, maximum=9, step=1, value=5,
                                    label="Grid Size")
                t2_plot = gr.Plot(label="Q0 vs Q3 Accuracy")
                t2_grid_plot = gr.Plot(label="Q3 Accuracy Across Grid Sizes")
                t2_lb = gr.Dataframe(
                    label="Leaderboard",
                    elem_classes=["leaderboard-table"],
                )

                def refresh_task2(gs):
                    fig, lb = plot_task2_q0_q3(int(gs))
                    grid_fig = plot_task2_by_grid()
                    return fig, grid_fig, lb

                t2_grid.change(
                    refresh_task2, inputs=[t2_grid],
                    outputs=[t2_plot, t2_grid_plot, t2_lb],
                )

                gr.Markdown("---")
                gr.Markdown("## Task 3 — Compositional Distance Comparison")
                gr.Markdown(
                    "Models answer 3 questions about maze corners (A, B, C, D) and center M. "
                    "**Q2** (B→C) can potentially be composed from Q0 (A→M) and Q1 (D→M). "
                    "Δ = Q2 accuracy − avg(Q0, Q1)."
                )

                # Task 3 has no interactive filter controls; its plots refresh
                # only via the button below and the initial demo.load.
                t3_plot = gr.Plot(label="Q0 / Q1 / Q2 Accuracy by Model")
                t3_grid_plot = gr.Plot(label="Q2 Accuracy Across Grid Sizes")
                t3_lb = gr.Dataframe(
                    label="Leaderboard (Δ shows compositional benefit)",
                    elem_classes=["leaderboard-table"],
                )

                with gr.Row():
                    refresh_lb_btn = gr.Button("🔄 Refresh Results", variant="secondary")

                # Rebuild every leaderboard output using the default control
                # values (k=0, raw format, 5×5 grid).  The return order must
                # match the outputs= lists below exactly.
                def refresh_all_leaderboard(_=None):
                    t1_fig, t1_table = plot_task1_accuracy(0, "raw")
                    t1_ff = plot_task1_format_comparison()
                    t2_fig, t2_lb_table = plot_task2_q0_q3(5)
                    t2_gfig = plot_task2_by_grid()
                    t3_fig, t3_lb_table = plot_task3_compositional()
                    t3_gfig = plot_task3_by_grid()
                    return (
                        t1_fig, t1_table, t1_ff,
                        t2_fig, t2_gfig, t2_lb_table,
                        t3_fig, t3_gfig, t3_lb_table,
                    )

                refresh_lb_btn.click(
                    refresh_all_leaderboard,
                    outputs=[
                        t1_plot, t1_lb, t1_fmt_plot,
                        t2_plot, t2_grid_plot, t2_lb,
                        t3_plot, t3_grid_plot, t3_lb,
                    ],
                )

                # Initial load
                demo.load(
                    refresh_all_leaderboard,
                    outputs=[
                        t1_plot, t1_lb, t1_fmt_plot,
                        t2_plot, t2_grid_plot, t2_lb,
                        t3_plot, t3_grid_plot, t3_lb,
                    ],
                )

            # ================================================================
            # Tab 2: Run Experiments
            # ================================================================
            with gr.Tab("⚡ Run Experiments"):
                gr.Markdown(
                    "## Launch Experiments\n"
                    "Experiments call LLM APIs directly — no compute cluster needed.\n\n"
                    "> **Your API keys are used only for your session and are never stored or logged.** \n"
                    "> Enter keys only for the model(s) you want to evaluate. "
                    "Jobs for models without a key will be skipped."
                )

                with gr.Row():
                    with gr.Column(scale=2):
                        # Task / model / grid selection
                        run_tasks = gr.CheckboxGroup(
                            choices=[
                                "Maze Navigation",
                                "Sequential Point Reuse",
                                "Compositional Distance Comparison",
                            ],
                            value=["Maze Navigation"],
                            label="Tasks",
                        )
                        run_models = gr.CheckboxGroup(
                            choices=MODEL_CHOICES,
                            value=["gemini-2.5-flash"],
                            label="Models",
                        )
                        run_grids = gr.Textbox(
                            value="5,6,7",
                            label="Grid Sizes",
                            info="Comma-separated integers. Maze dataset supports 5–9 (and beyond if regenerated).",
                        )
                        with gr.Row():
                            run_formats = gr.CheckboxGroup(
                                choices=["raw", "visual"],
                                value=["raw"],
                                label="Input Formats (Task 1 only)",
                            )
                            run_strategies = gr.CheckboxGroup(
                                choices=["base", "cot", "reasoning"],
                                value=["cot"],
                                label="Prompt Strategies",
                            )

                    with gr.Column(scale=1):
                        gr.Markdown("### API Keys")
                        gr.Markdown(
                            "Enter the key(s) for the model(s) you selected. "
                            "Keys are used only for this session."
                        )
                        gemini_key = gr.Textbox(
                            label="GEMINI_API_KEY", type="password", placeholder="AIza...",
                        )
                        openai_key = gr.Textbox(
                            label="OPENAI_API_KEY", type="password", placeholder="sk-...",
                        )
                        anthropic_key = gr.Textbox(
                            label="ANTHROPIC_API_KEY", type="password", placeholder="sk-ant-...",
                        )
                        deepseek_key = gr.Textbox(
                            label="DEEPSEEK_API_KEY", type="password",
                        )

                with gr.Row():
                    run_btn = gr.Button("🚀 Launch Experiments", variant="primary", scale=2)
                    refresh_btn = gr.Button("🔄 Refresh Status", scale=1)

                launch_msg = gr.Textbox(label="Launch Status", interactive=False)

                job_table = gr.Dataframe(
                    headers=["Task", "Model", "Label", "Status", "Elapsed", "Started"],
                    label="Job Status",
                    interactive=False,
                    wrap=True,
                )
                status_summary = gr.Textbox(
                    label="Summary", interactive=False,
                )

                run_btn.click(
                    launch_experiments,
                    inputs=[
                        run_tasks, run_models, run_grids,
                        run_formats, run_strategies,
                        gemini_key, openai_key, anthropic_key, deepseek_key,
                    ],
                    outputs=[launch_msg, job_table],
                )

                refresh_btn.click(
                    refresh_status,
                    outputs=[job_table, status_summary],
                )

            # ================================================================
            # Tab 3: About
            # ================================================================
            with gr.Tab("ℹ️ About"):
                gr.Markdown("""
## About SpatialBench

SpatialBench is the evaluation platform accompanying the paper:

> **Do LLMs Build Spatial World Models? Evidence from Grid-World Maze Tasks**
> *Under review at ICLR 2026 Workshop*

### Three Tasks

| Task | Type | What it tests |
|------|------|---------------|
| **Task 1: Maze Navigation** | Planning | Find shortest path from start to goal |
| **Task 2: Sequential Point Reuse** | Reasoning | Reuse Q0 computation when Q3=Q0 |
| **Task 3: Compositional Distance** | Reasoning | Compose corner→center distances for Q2 |

### Input Representations

- **Raw (tokenized)**: `<ADJLIST_START> (0,0) <--> (0,1) ... <ADJLIST_END>`
- **Visual (grid)**: `Row 0: ['.', 'S', '.', '#'] Row 1: ['#', '.', '.', 'E']`

### Models Evaluated

| Model | Provider |
|-------|----------|
| Gemini 2.5 Flash | Google |
| GPT-5 Mini | OpenAI |
| Claude Haiku 4.5 | Anthropic |
| DeepSeek Chat | DeepSeek |

### Grid Sizes

Experiments run on n×n grids for n ∈ {5, 6, 7, 8, 9} by default.
The underlying `maze-dataset` library supports larger grids — adjust in the **Run** tab.

### Adding a New Model

Edit `pipeline/configs/experiments.yaml`:
```yaml
models:
  your-model-id:
    api_key_env: YOUR_API_KEY_ENV_VAR
    display_name: "Your Model Name"
```
Then add inference support in `utils/llm_inference.py`.

### Citation
```bibtex
@inproceedings{spatialbench2026,
  title = {Do {LLMs} Build Spatial World Models? Evidence from Grid-World Maze Tasks},
  author = {Anonymous},
  booktitle = {ICLR 2026 Workshop},
  year = {2026},
}
```
""")

    return demo
676
+
677
+
678
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    demo = build_ui()
    # Bind to 0.0.0.0 so the app is reachable from outside the HF Space
    # container; errors are surfaced in the UI for easier debugging.
    demo.launch(
        server_name="0.0.0.0",
        share=False,
        show_error=True,
    )
configs/experiments.yaml ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # =============================================================================
2
+ # SpatialBench Experiment Configuration
3
+ # =============================================================================
4
+ # This file is the single source of truth for all experiments.
5
+ # Add a new model by adding an entry under `models`.
6
+ # Add a new grid size by extending `grid_sizes` in any task.
7
+ # All paths are relative to llm-maze-solver/ (the repo root).
8
+ # =============================================================================
9
+
10
+ # ---------------------------------------------------------------------------
11
+ # Global defaults — overridden per-task or per-experiment as needed
12
+ # ---------------------------------------------------------------------------
13
+ defaults:
14
+ n_test_mazes: 50
15
+ seed: 42
16
+ temperature: 0.1
17
+ max_tokens: 8192
18
+ sbatch:
19
+ cpus: 2
20
+ mem: "8G"
21
+ time: "10:00:00"
22
+ partition: "short"
23
+ log_dir: "maze-solver/eval_llm_logs"
24
+
25
+ # ---------------------------------------------------------------------------
26
+ # Models
27
+ # Each entry defines the model identifier used in API calls and the
28
+ # environment variable that must hold the API key.
29
+ # ---------------------------------------------------------------------------
30
+ models:
31
+ gemini-2.5-flash:
32
+ api_key_env: GEMINI_API_KEY
33
+ display_name: "Gemini 2.5 Flash"
34
+ gpt-5-mini:
35
+ api_key_env: OPENAI_API_KEY
36
+ display_name: "GPT-5 Mini"
37
+ claude-haiku-4-5:
38
+ api_key_env: ANTHROPIC_API_KEY
39
+ display_name: "Claude Haiku 4.5"
40
+ deepseek-chat:
41
+ api_key_env: DEEPSEEK_API_KEY
42
+ display_name: "DeepSeek Chat"
43
+
44
+ # ---------------------------------------------------------------------------
45
+ # Maze Navigation (Planning)
46
+ # Paper: Table 1, Table 5 (3-shot), Table 6 (5-shot)
47
+ # Script: maze-solver/eval_llm_maze_solver.py
48
+ # ---------------------------------------------------------------------------
49
+ maze_navigation:
50
+ description: >
51
+ Models find shortest paths through mazes represented in two formats
52
+ (raw tokenized adjacency lists vs visual character grids), tested
53
+ across k-shot settings and three prompting strategies.
54
+ script: "maze-solver/eval_llm_maze_solver.py"
55
+ working_dir: "maze-solver"
56
+ output_base: "maze-solver/llm-maze-evaluation-results"
57
+
58
+ # Grid sizes: paper used 5-9; extend freely up to maze-dataset limits
59
+ grid_sizes: [5, 6, 7, 8, 9]
60
+
61
+ # Input representations
62
+ input_formats: ["raw", "visual"]
63
+
64
+ # Prompting strategies (maps to script flags)
65
+ prompt_strategies:
66
+ base:
67
+ flags: []
68
+ display_name: "Base"
69
+ cot:
70
+ flags: ["--chain_of_thought"]
71
+ display_name: "Chain-of-Thought"
72
+ reasoning:
73
+ flags: ["--reasoning"]
74
+ display_name: "Post-hoc Reasoning"
75
+
76
+ # K-shot values tested simultaneously in one script run
77
+ k_shots: "0,3,5"
78
+
79
+ # Fixed params
80
+ maze_type: "cycles"
81
+ percolation_p: 0.2
82
+ visualize: true
83
+
84
+ sbatch:
85
+ time: "10:00:00"
86
+
87
+ # ---------------------------------------------------------------------------
88
+ # Sequential Reasoning with Point Reuse (Q3 = Q0)
89
+ # Paper: Table 2, Table 7
90
+ # Script: maze-solver/spatial_reasoning/eval_proximity_comparison.py
91
+ # ---------------------------------------------------------------------------
92
+ point_reuse:
93
+ description: >
94
+ Models answer four sequential proximity questions about the same maze.
95
+ Q3 is identical to Q0, probing whether models reuse previously
96
+ computed spatial information or treat each question independently.
97
+ script: "maze-solver/spatial_reasoning/eval_proximity_comparison.py"
98
+ working_dir: "spatial_reasoning/spatial_reasoning_experiments"
99
+
100
+ # Paper used 5-9; extend freely
101
+ grid_sizes: [5, 6, 7, 8, 9]
102
+
103
+ input_format: "raw"
104
+ strategy: "point_reuse"
105
+ reuse_pattern: "last_first_same"
106
+ n_questions_per_maze: 4
107
+ sequential_questions: true
108
+
109
+ # Prompting strategies
110
+ prompt_strategies:
111
+ base:
112
+ prompt_type: "baseline"
113
+ display_name: "Base"
114
+ cot:
115
+ prompt_type: "cot"
116
+ display_name: "Chain-of-Thought"
117
+ reasoning:
118
+ prompt_type: "reasoning"
119
+ display_name: "Post-hoc Reasoning"
120
+
121
+ output_base: "spatial_reasoning/spatial-reasoning-results-point-reuse-q3-q0"
122
+ visualize: true
123
+ save_details: true
124
+
125
+ sbatch:
126
+ time: "10:30:00"
127
+
128
+ # ---------------------------------------------------------------------------
129
+ # Compositional Distance Comparison
130
+ # Paper: Table 3, Table 8, Table 9
131
+ # Script: maze-solver/spatial_reasoning/eval_extended_experiments.py
132
+ # Corner pattern: corners_to_center (Q0: top-left→center,
133
+ # Q1: bottom-right→center,
134
+ # Q2: corner→corner compositional)
135
+ # ---------------------------------------------------------------------------
136
+ compositional_distance:
137
+ description: >
138
+ Models answer three questions about maze corners (A=top-left,
139
+ B=top-right, C=bottom-left, D=bottom-right) and center M.
140
+ Q2 can be composed from information established in Q0 and Q1,
141
+ probing whether models build cumulative spatial knowledge.
142
+ script: "maze-solver/spatial_reasoning/eval_extended_experiments.py"
143
+ working_dir: "spatial_reasoning/spatial_reasoning_experiments"
144
+
145
+ # Paper reported 5-9 in Tables 8/9; scripts originally only ran 5-7
146
+ # Extended to match paper
147
+ grid_sizes: [5, 6, 7, 8, 9]
148
+
149
+ input_format: "raw"
150
+ strategy: "orthogonal"
151
+ corner_pattern: "corners_to_center" # matches paper Q0/Q1/Q2 design
152
+ n_questions_per_maze: 3
153
+
154
+ # Prompting strategies
155
+ prompt_strategies:
156
+ base:
157
+ prompt_type: "baseline"
158
+ display_name: "Base"
159
+ cot:
160
+ prompt_type: "cot"
161
+ display_name: "Chain-of-Thought"
162
+ reasoning:
163
+ prompt_type: "reasoning"
164
+ display_name: "Post-hoc Reasoning"
165
+
166
+ output_base: "spatial_reasoning/spatial-reasoning-results-orthogonal"
167
+ visualize: true
168
+ save_details: true
169
+
170
+ sbatch:
171
+ time: "06:30:00"
pipeline/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ "SpatialBench pipeline modules."
pipeline/job_monitor.py ADDED
@@ -0,0 +1,314 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ job_monitor.py
3
+ --------------
4
+ Tracks and displays the status of experiment jobs.
5
+
6
+ Supports two backends:
7
+ - SLURM : polls `squeue` for cluster jobs (used when running locally)
8
+ - Direct : tracks subprocess-launched jobs (used when running via API keys
9
+ on HF Space or without a cluster)
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import subprocess
15
+ import time
16
+ import threading
17
+ from dataclasses import dataclass, field
18
+ from datetime import datetime
19
+ from enum import Enum
20
+ from typing import Callable
21
+
22
+
23
+ # ---------------------------------------------------------------------------
24
+ # Status model
25
+ # ---------------------------------------------------------------------------
26
+
27
class JobStatus(str, Enum):
    """Lifecycle state of an experiment job.

    Subclasses ``str`` so values serialize/compare directly as plain strings
    (used by JobRecord.as_dict and the Gradio table).
    """
    PENDING = "pending"      # submitted, not yet running
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
    UNKNOWN = "unknown"      # squeue reported a state we don't map
34
+
35
+
36
@dataclass
class JobRecord:
    """Bookkeeping for one submitted experiment job (SLURM or subprocess)."""
    job_id: str          # SLURM job-id or process PID (as str)
    label: str           # human-readable experiment label
    task_id: str
    model: str
    status: JobStatus = JobStatus.PENDING
    submitted_at: datetime = field(default_factory=datetime.now)
    finished_at: datetime | None = None   # set when a terminal status is recorded
    output_dir: str = ""
    log_out: str = ""    # stdout log path, if any
    log_err: str = ""    # stderr log path, if any

    def elapsed(self) -> str:
        """Wall-clock time since submission, formatted HH:MM:SS.

        Freezes at finished_at once the job is done; ticks with "now"
        while it is still pending/running.
        """
        end = self.finished_at or datetime.now()
        secs = int((end - self.submitted_at).total_seconds())
        h, rem = divmod(secs, 3600)
        m, s = divmod(rem, 60)
        return f"{h:02d}:{m:02d}:{s:02d}"

    def as_dict(self) -> dict:
        """JSON-friendly snapshot (consumed by JobMonitor.summary())."""
        return {
            "job_id": self.job_id,
            "label": self.label,
            "task_id": self.task_id,
            "model": self.model,
            "status": self.status.value,
            "submitted_at": self.submitted_at.strftime("%Y-%m-%d %H:%M:%S"),
            "elapsed": self.elapsed(),
            "output_dir": self.output_dir,
        }
67
+
68
+
69
+ # ---------------------------------------------------------------------------
70
+ # SLURM monitor
71
+ # ---------------------------------------------------------------------------
72
+
73
# Map squeue state codes → JobStatus.
# "CG" (completing) still counts as RUNNING; "TO" (timeout) and "OOM"
# (out-of-memory) are treated as failures. Unlisted codes fall back to
# JobStatus.UNKNOWN at the lookup site.
_SLURM_STATE_MAP = {
    "PD": JobStatus.PENDING,
    "R": JobStatus.RUNNING,
    "CG": JobStatus.RUNNING,
    "CD": JobStatus.COMPLETED,
    "F": JobStatus.FAILED,
    "CA": JobStatus.CANCELLED,
    "TO": JobStatus.FAILED,
    "OOM": JobStatus.FAILED,
}
84
+
85
+
86
+ def _query_slurm(job_ids: list[str]) -> dict[str, JobStatus]:
87
+ """Return {job_id: JobStatus} for a batch of SLURM job ids."""
88
+ if not job_ids:
89
+ return {}
90
+ try:
91
+ result = subprocess.run(
92
+ ["squeue", "--jobs", ",".join(job_ids), "--format=%i %t", "--noheader"],
93
+ capture_output=True, text=True, timeout=15,
94
+ )
95
+ statuses: dict[str, JobStatus] = {}
96
+ for line in result.stdout.strip().splitlines():
97
+ parts = line.split()
98
+ if len(parts) >= 2:
99
+ jid, state = parts[0], parts[1]
100
+ statuses[jid] = _SLURM_STATE_MAP.get(state, JobStatus.UNKNOWN)
101
+ return statuses
102
+ except Exception:
103
+ return {}
104
+
105
+
106
+ def submit_sbatch(script_text: str, script_path: str) -> str | None:
107
+ """Write script_text to script_path, submit via sbatch, return job_id."""
108
+ with open(script_path, "w") as f:
109
+ f.write(script_text)
110
+ try:
111
+ result = subprocess.run(
112
+ ["sbatch", script_path],
113
+ capture_output=True, text=True, timeout=30,
114
+ )
115
+ # sbatch output: "Submitted batch job 12345"
116
+ for token in result.stdout.split():
117
+ if token.isdigit():
118
+ return token
119
+ except Exception:
120
+ pass
121
+ return None
122
+
123
+
124
+ # ---------------------------------------------------------------------------
125
+ # Direct (subprocess) monitor
126
+ # ---------------------------------------------------------------------------
127
+
128
def submit_direct(
    cmd: list[str],
    working_dir: str,
    env: dict | None = None,
    on_finish: Callable[[int], None] | None = None,
) -> subprocess.Popen:
    """Launch a job as a subprocess and optionally call on_finish(returncode).

    The child inherits the current environment, updated with *env*.
    When *on_finish* is given, a daemon thread waits on the process and
    invokes the callback with its return code.

    NOTE(review): stdout/stderr are PIPEd but nothing in this module
    drains them; a chatty child can fill the OS pipe buffer and block
    forever, deadlocking both the child and the waiting thread — confirm
    that some caller reads proc.stdout/proc.stderr, or consider
    redirecting to log files instead.
    """
    import os  # local import keeps the module-level import list unchanged
    proc_env = os.environ.copy()
    if env:
        proc_env.update(env)

    proc = subprocess.Popen(
        cmd,
        cwd=working_dir,
        env=proc_env,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
    )

    if on_finish:
        # Fire-and-forget watcher; daemon=True so it never blocks shutdown.
        def _wait():
            proc.wait()
            on_finish(proc.returncode)
        threading.Thread(target=_wait, daemon=True).start()

    return proc
156
+
157
+
158
+ # ---------------------------------------------------------------------------
159
+ # JobMonitor — unified tracker
160
+ # ---------------------------------------------------------------------------
161
+
162
class JobMonitor:
    """
    Tracks a collection of JobRecords across two backends.

    Usage (SLURM):
        monitor = JobMonitor(mode="slurm")
        record = monitor.add(job_id="12345", label="Task1|gemini|raw|cot", ...)
        monitor.refresh()   # updates statuses from squeue
        monitor.wait_all()  # blocks until all done

    Usage (direct):
        monitor = JobMonitor(mode="direct")
        proc = submit_direct(cmd, wdir, env)
        record = monitor.add_direct(proc, label=..., ...)
        monitor.wait_all()

    Thread-safety: every access to the record/process maps is guarded by
    self._lock; watcher threads spawned by add_direct take the same lock
    when flipping a record to a terminal status.
    """

    def __init__(self, mode: str = "slurm"):
        # Only two backends exist; fail fast on typos.
        assert mode in ("slurm", "direct")
        self.mode = mode
        self._records: dict[str, JobRecord] = {}  # job_id → JobRecord
        self._procs: dict[str, subprocess.Popen] = {}  # direct mode only
        self._lock = threading.Lock()

    # -- adding jobs --------------------------------------------------------

    def add(
        self,
        job_id: str,
        label: str,
        task_id: str,
        model: str,
        output_dir: str = "",
        log_out: str = "",
        log_err: str = "",
    ) -> JobRecord:
        """Register an already-submitted job (e.g. a SLURM job id)."""
        record = JobRecord(
            job_id=job_id, label=label,
            task_id=task_id, model=model,
            output_dir=output_dir, log_out=log_out, log_err=log_err,
        )
        with self._lock:
            self._records[job_id] = record
        return record

    def add_direct(
        self,
        proc: subprocess.Popen,
        label: str,
        task_id: str,
        model: str,
        output_dir: str = "",
    ) -> JobRecord:
        """Register a subprocess job; a daemon thread marks it terminal.

        The PID doubles as the job id. Return code 0 → COMPLETED,
        anything else → FAILED.
        """
        job_id = str(proc.pid)
        record = self.add(
            job_id=job_id, label=label,
            task_id=task_id, model=model, output_dir=output_dir,
        )
        record.status = JobStatus.RUNNING
        with self._lock:
            self._procs[job_id] = proc

        def _monitor():
            proc.wait()
            with self._lock:
                record.status = (
                    JobStatus.COMPLETED if proc.returncode == 0
                    else JobStatus.FAILED
                )
                record.finished_at = datetime.now()

        threading.Thread(target=_monitor, daemon=True).start()
        return record

    # -- status refreshing --------------------------------------------------

    def refresh(self) -> None:
        """Update statuses from SLURM (no-op for direct mode).

        Collects the still-active ids under the lock, queries squeue
        WITHOUT holding the lock (the subprocess call can take seconds),
        then applies the results under the lock again.

        NOTE(review): any active job absent from the squeue reply is
        assumed finished; since _query_slurm returns {} on error, a
        transient squeue failure will mark every RUNNING job COMPLETED.
        Also, vanished PENDING jobs are left PENDING forever. Consider
        falling back to `sacct` to distinguish these cases.
        """
        if self.mode != "slurm":
            return
        with self._lock:
            active_ids = [
                jid for jid, r in self._records.items()
                if r.status in (JobStatus.PENDING, JobStatus.RUNNING)
            ]
        if not active_ids:
            return
        statuses = _query_slurm(active_ids)
        with self._lock:
            for jid, status in statuses.items():
                if jid in self._records:
                    old = self._records[jid].status
                    self._records[jid].status = status
                    # Stamp finished_at only on the transition into a
                    # terminal state so elapsed() freezes correctly.
                    if old != status and status in (
                        JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED
                    ):
                        self._records[jid].finished_at = datetime.now()
            # Any job no longer appearing in squeue is done
            for jid in active_ids:
                if jid not in statuses and jid in self._records:
                    r = self._records[jid]
                    if r.status == JobStatus.RUNNING:
                        r.status = JobStatus.COMPLETED
                        r.finished_at = datetime.now()

    def wait_all(self, poll_interval: int = 30, callback: Callable | None = None) -> None:
        """Block until all jobs are in a terminal state.

        *callback*, if given, receives self.summary() once per poll cycle
        (useful for streaming progress to a UI).
        """
        while True:
            self.refresh()
            active = self.active_jobs()
            if callback:
                callback(self.summary())
            if not active:
                break
            time.sleep(poll_interval)

    # -- queries ------------------------------------------------------------

    def active_jobs(self) -> list[JobRecord]:
        """Records still PENDING or RUNNING (snapshot under the lock)."""
        with self._lock:
            return [
                r for r in self._records.values()
                if r.status in (JobStatus.PENDING, JobStatus.RUNNING)
            ]

    def all_records(self) -> list[JobRecord]:
        """Snapshot of every tracked record."""
        with self._lock:
            return list(self._records.values())

    def summary(self) -> dict:
        """Aggregate view: total count, per-status counts, serialized records."""
        records = self.all_records()
        counts: dict[str, int] = {}
        for r in records:
            counts[r.status.value] = counts.get(r.status.value, 0) + 1
        return {
            "total": len(records),
            "counts": counts,
            "records": [r.as_dict() for r in records],
        }

    def as_table(self) -> list[list[str]]:
        """Return rows suitable for a Gradio Dataframe component.

        Column order matches TABLE_HEADERS below.
        """
        records = self.all_records()
        rows = []
        for r in records:
            rows.append([
                r.task_id, r.model, r.label,
                r.status.value, r.elapsed(),
                r.submitted_at.strftime("%H:%M:%S"),
            ])
        return rows

    # Column labels for as_table(); kept adjacent so they stay in sync.
    TABLE_HEADERS = ["Task", "Model", "Label", "Status", "Elapsed", "Started"]
pipeline/results_loader.py ADDED
@@ -0,0 +1,402 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ results_loader.py
3
+ -----------------
4
+ Scans experiment output directories and assembles results into pandas
5
+ DataFrames ready for display in the Gradio leaderboard.
6
+
7
+ Output directory conventions (from experiments.yaml):
8
+ Task 1: <output_base>/<model>/<fmt>_input_<strat>/
9
+ → results_{grid_size}x{grid_size}_k{k}.csv OR summary.json
10
+ Task 2: <output_base>/<model>/point_reuse_q3q0_<strat>/
11
+ → proximity_comparison_results.csv
12
+ Task 3: <output_base>/<model>/orthogonal_corners_to_center_<strat>/
13
+ → results.csv OR summary_stats.json
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import json
19
+ import os
20
+ from pathlib import Path
21
+
22
+ import pandas as pd
23
+ import yaml
24
+
25
+
26
+ # ---------------------------------------------------------------------------
27
+ # Config helpers
28
+ # ---------------------------------------------------------------------------
29
+
30
def load_config(config_path: str | Path) -> dict:
    """Parse the experiments YAML config into a plain dict."""
    text = Path(config_path).read_text()
    return yaml.safe_load(text)
33
+
34
+
35
+ def _model_display(cfg: dict, model_id: str) -> str:
36
+ return cfg["models"].get(model_id, {}).get("display_name", model_id)
37
+
38
+
39
+ # ---------------------------------------------------------------------------
40
+ # Task 1 results
41
+ # ---------------------------------------------------------------------------
42
+
43
def load_maze_navigation_results(cfg: dict, repo_root: Path) -> pd.DataFrame:
    """
    Scan Task 1 output dirs and return a DataFrame with columns:
        model, display_name, input_format, prompt_strategy, grid_size, k_shot, accuracy

    Walks every (model × input_format × prompt_strategy) combination under
    <output_base>; missing directories are skipped silently so partial runs
    still yield a (smaller) table.
    """
    task = cfg["maze_navigation"]
    base = repo_root / task["output_base"]
    rows = []

    for model_id, model_meta in cfg["models"].items():
        display = model_meta["display_name"]
        for fmt in task["input_formats"]:
            for strat in task["prompt_strategies"]:
                # Dots/dashes in model ids are flattened to underscores to
                # mirror how the run scripts name their output directories.
                subdir = base / model_id.replace(".", "_").replace("-", "_") / f"{fmt}_input_{strat}"
                if not subdir.exists():
                    continue
                # Look for summary JSON first, then CSVs
                summary_file = subdir / "summary.json"
                if summary_file.exists():
                    _parse_task1_summary(summary_file, rows, model_id, display, fmt, strat)
                else:
                    # Fall back to per-grid CSVs
                    for csv_file in sorted(subdir.glob("*.csv")):
                        _parse_task1_csv(csv_file, rows, model_id, display, fmt, strat)

    if not rows:
        # Empty frame with the full schema so downstream pivots don't KeyError.
        return pd.DataFrame(columns=[
            "model", "display_name", "input_format", "prompt_strategy",
            "grid_size", "k_shot", "accuracy"
        ])
    return pd.DataFrame(rows)
74
+
75
+
76
+ def _parse_task1_summary(path: Path, rows: list, model_id, display, fmt, strat):
77
+ try:
78
+ with open(path) as f:
79
+ data = json.load(f)
80
+ # Expected: {grid_size: {k: accuracy, ...}, ...}
81
+ for grid_key, k_dict in data.items():
82
+ try:
83
+ grid_size = int(str(grid_key).replace("x", "").split("_")[0])
84
+ except ValueError:
85
+ continue
86
+ if isinstance(k_dict, dict):
87
+ for k, acc in k_dict.items():
88
+ rows.append({
89
+ "model": model_id, "display_name": display,
90
+ "input_format": fmt, "prompt_strategy": strat,
91
+ "grid_size": grid_size, "k_shot": int(k),
92
+ "accuracy": float(acc),
93
+ })
94
+ elif isinstance(k_dict, (int, float)):
95
+ rows.append({
96
+ "model": model_id, "display_name": display,
97
+ "input_format": fmt, "prompt_strategy": strat,
98
+ "grid_size": grid_size, "k_shot": 0,
99
+ "accuracy": float(k_dict),
100
+ })
101
+ except Exception:
102
+ pass
103
+
104
+
105
def _parse_task1_csv(path: Path, rows: list, model_id, display, fmt, strat):
    """Append accuracy rows parsed from one per-grid results CSV.

    Grid size and k-shot are recovered from filename tokens such as
    "results_5x5_k2"; a ``grid_size`` column in the CSV takes precedence
    over the filename. Unreadable files are skipped silently.
    """
    try:
        frame = pd.read_csv(path)
    except Exception:
        return
    try:
        # Recover metadata from the filename stem.
        grid_from_name = None
        k_shot = 0
        for token in path.stem.split("_"):
            if token.startswith("k") and token[1:].isdigit():
                k_shot = int(token[1:])
            if "x" in token:
                try:
                    grid_from_name = int(token.split("x")[0])
                except ValueError:
                    pass

        def _emit(grid: int, sub: pd.DataFrame) -> None:
            rows.append({
                "model": model_id, "display_name": display,
                "input_format": fmt, "prompt_strategy": strat,
                "grid_size": grid, "k_shot": k_shot,
                "accuracy": _df_accuracy(sub),
            })

        if "grid_size" in frame.columns:
            for gs, gdf in frame.groupby("grid_size"):
                _emit(int(gs), gdf)
        elif grid_from_name is not None:
            _emit(grid_from_name, frame)
    except Exception:
        pass
140
+
141
+
142
+ def _df_accuracy(df: pd.DataFrame) -> float:
143
+ for col in ("is_correct", "exact_match", "correct", "accuracy"):
144
+ if col in df.columns:
145
+ return float(df[col].mean())
146
+ return float("nan")
147
+
148
+
149
+ # ---------------------------------------------------------------------------
150
+ # Task 2 results
151
+ # ---------------------------------------------------------------------------
152
+
153
def load_point_reuse_results(cfg: dict, repo_root: Path) -> pd.DataFrame:
    """
    Return DataFrame with columns:
        model, display_name, prompt_strategy, grid_size, question_idx, accuracy

    Scans every (model × prompt_strategy) output directory; tries the
    pipeline's own naming first, then the legacy layout produced by the
    existing scripts. Rows without a detectable question column are
    aggregated per grid size with question_idx == -1.
    """
    task = cfg["point_reuse"]
    base = repo_root / task["output_base"]
    rows = []

    for model_id, model_meta in cfg["models"].items():
        display = model_meta["display_name"]
        for strat, strat_cfg in task["prompt_strategies"].items():
            subdir = (
                base
                / model_id.replace(".", "_").replace("-", "_")
                / f"point_reuse_q3q0_{strat}"
            )
            if not subdir.exists():
                # Also try the pattern used by existing scripts
                subdir = base / model_id / f"proximity_comparison_point_reuse_last_first_same_{strat_cfg['prompt_type']}"
                if not subdir.exists():
                    continue

            csv_files = list(subdir.glob("*.csv"))
            for csv_file in csv_files:
                try:
                    df = pd.read_csv(csv_file)
                    # grid_size is mandatory; skip files that lack it.
                    if "grid_size" not in df.columns:
                        continue
                    # Accept any of the known question-index column names.
                    q_col = next(
                        (c for c in ("question_idx", "question_index", "q_idx") if c in df.columns),
                        None,
                    )
                    for gs, gdf in df.groupby("grid_size"):
                        if q_col:
                            for qi, qdf in gdf.groupby(q_col):
                                rows.append({
                                    "model": model_id, "display_name": display,
                                    "prompt_strategy": strat,
                                    "grid_size": int(gs),
                                    "question_idx": int(qi),
                                    "accuracy": _df_accuracy(qdf),
                                })
                        else:
                            # No per-question breakdown available; -1 marks
                            # a grid-level aggregate.
                            rows.append({
                                "model": model_id, "display_name": display,
                                "prompt_strategy": strat,
                                "grid_size": int(gs),
                                "question_idx": -1,
                                "accuracy": _df_accuracy(gdf),
                            })
                except Exception:
                    # Best effort: a single bad CSV must not sink the scan.
                    pass

    if not rows:
        # Empty frame with the full schema so callers can still select columns.
        return pd.DataFrame(columns=[
            "model", "display_name", "prompt_strategy",
            "grid_size", "question_idx", "accuracy"
        ])
    return pd.DataFrame(rows)
213
+
214
+
215
+ # ---------------------------------------------------------------------------
216
+ # Task 3 results
217
+ # ---------------------------------------------------------------------------
218
+
219
def load_compositional_distance_results(cfg: dict, repo_root: Path) -> pd.DataFrame:
    """
    Return DataFrame with columns:
        model, display_name, prompt_strategy, grid_size, question_idx, accuracy, delta

    ``delta`` is populated only on Q2 rows (Q2 accuracy minus the mean of
    Q0/Q1 accuracies) and is NaN everywhere else. Per directory, a
    summary_stats.json is preferred; raw CSVs are the fallback.
    """
    task = cfg["compositional_distance"]
    base = repo_root / task["output_base"]
    rows = []

    for model_id, model_meta in cfg["models"].items():
        display = model_meta["display_name"]
        for strat, strat_cfg in task["prompt_strategies"].items():
            tag = f"orthogonal_{task['corner_pattern']}_{strat}"
            subdir = (
                base
                / model_id.replace(".", "_").replace("-", "_")
                / tag
            )
            if not subdir.exists():
                continue

            # Prefer summary_stats.json
            stats_file = subdir / "summary_stats.json"
            if stats_file.exists():
                try:
                    with open(stats_file) as f:
                        data = json.load(f)
                    _parse_task3_stats(data, rows, model_id, display, strat)
                    continue
                except Exception:
                    # Fall through to the CSV path on a corrupt stats file.
                    pass

            # Fall back to results.csv
            for csv_file in sorted(subdir.glob("*.csv")):
                try:
                    df = pd.read_csv(csv_file)
                    if "grid_size" not in df.columns:
                        continue
                    q_col = next(
                        (c for c in ("question_idx", "question_index") if c in df.columns),
                        None,
                    )
                    # Note: unlike the Task 2 loader, CSVs without a question
                    # column contribute nothing here (no else-branch).
                    for gs, gdf in df.groupby("grid_size"):
                        if q_col:
                            q_accs = {}
                            for qi, qdf in gdf.groupby(q_col):
                                acc = _df_accuracy(qdf)
                                q_accs[int(qi)] = acc
                                rows.append({
                                    "model": model_id, "display_name": display,
                                    "prompt_strategy": strat,
                                    "grid_size": int(gs),
                                    "question_idx": int(qi),
                                    "accuracy": acc,
                                    "delta": float("nan"),
                                })
                            # Compute delta for Q2 vs avg(Q0, Q1) and backfill
                            # it onto the Q2 row appended above (linear scan
                            # over rows — fine at these result sizes).
                            if 0 in q_accs and 1 in q_accs and 2 in q_accs:
                                delta = q_accs[2] - (q_accs[0] + q_accs[1]) / 2
                                for r in rows:
                                    if (r["model"] == model_id and
                                            r["prompt_strategy"] == strat and
                                            r["grid_size"] == int(gs) and
                                            r["question_idx"] == 2):
                                        r["delta"] = round(delta, 4)
                except Exception:
                    pass

    if not rows:
        # Empty frame with the full schema for downstream consumers.
        return pd.DataFrame(columns=[
            "model", "display_name", "prompt_strategy",
            "grid_size", "question_idx", "accuracy", "delta"
        ])
    return pd.DataFrame(rows)
293
+
294
+
295
+ def _parse_task3_stats(data: dict, rows: list, model_id, display, strat):
296
+ """Parse summary_stats.json for task3."""
297
+ try:
298
+ by_q = data.get("accuracy_by_question", data.get("per_question", {}))
299
+ by_gs = data.get("accuracy_by_grid_size", {})
300
+ for gs_key, gs_data in by_gs.items():
301
+ try:
302
+ gs = int(str(gs_key).replace("x", "").split("_")[0])
303
+ except ValueError:
304
+ continue
305
+ if isinstance(gs_data, dict):
306
+ q_accs = {}
307
+ for qi_key, acc in gs_data.items():
308
+ try:
309
+ qi = int(qi_key)
310
+ q_accs[qi] = float(acc)
311
+ rows.append({
312
+ "model": model_id, "display_name": display,
313
+ "prompt_strategy": strat,
314
+ "grid_size": gs, "question_idx": qi,
315
+ "accuracy": float(acc), "delta": float("nan"),
316
+ })
317
+ except (ValueError, TypeError):
318
+ pass
319
+ if 0 in q_accs and 1 in q_accs and 2 in q_accs:
320
+ delta = q_accs[2] - (q_accs[0] + q_accs[1]) / 2
321
+ for r in rows:
322
+ if (r["model"] == model_id and r["prompt_strategy"] == strat
323
+ and r["grid_size"] == gs and r["question_idx"] == 2):
324
+ r["delta"] = round(delta, 4)
325
+ except Exception:
326
+ pass
327
+
328
+
329
+ # ---------------------------------------------------------------------------
330
+ # Leaderboard aggregators
331
+ # ---------------------------------------------------------------------------
332
+
333
def maze_navigation_leaderboard(df: pd.DataFrame, k_shot: int = 0) -> pd.DataFrame:
    """
    Pivot Task 1 results into a leaderboard table.

    Rows = models, columns = (format × strategy) flattened to
    "<format>_<strategy>", values = mean accuracy at *k_shot*.
    Empty input (or no rows at that k) yields an empty frame.
    """
    if df.empty:
        return pd.DataFrame()
    at_k = df[df["k_shot"] == k_shot]
    if at_k.empty:
        return pd.DataFrame()
    table = at_k.pivot_table(
        index=["display_name"],
        columns=["input_format", "prompt_strategy"],
        values="accuracy",
        aggfunc="mean",
    )
    # Flatten the two-level column index into single "<fmt>_<strat>" labels.
    table.columns = [f"{fmt}_{strat}" for fmt, strat in table.columns]
    table = table.reset_index().rename(columns={"display_name": "Model"})
    return table.round(3)
352
+
353
+
354
def point_reuse_leaderboard(df: pd.DataFrame) -> pd.DataFrame:
    """
    Task 2 leaderboard: per-model mean accuracy at Q0 and Q3 across all
    grid sizes, plus the Q3-Q0 gap (Q3 repeats Q0's question, so a large
    negative gap means the model failed to reuse its earlier answer).
    """
    if df.empty:
        return pd.DataFrame()

    def _mean_at(q: int, name: str) -> pd.Series:
        return (
            df[df["question_idx"] == q]
            .groupby("display_name")["accuracy"]
            .mean()
            .rename(name)
        )

    out = pd.concat([_mean_at(0, "Q0 acc"), _mean_at(3, "Q3 acc")], axis=1)
    out = out.reset_index().rename(columns={"display_name": "Model"})
    out["Q3-Q0 diff"] = (out["Q3 acc"] - out["Q0 acc"]).round(3)
    return out.round(3)
366
+
367
+
368
def compositional_distance_leaderboard(df: pd.DataFrame) -> pd.DataFrame:
    """
    Task 3 leaderboard: per-model mean Q0/Q1/Q2 accuracy plus delta
    (Q2 minus the average of Q0 and Q1). Delta is NaN when any of the
    three question accuracies is missing.
    """
    if df.empty:
        return pd.DataFrame()
    records = []
    for name, sub in df.groupby("display_name"):
        accs = {
            q: sub.loc[sub["question_idx"] == q, "accuracy"].mean()
            for q in (0, 1, 2)
        }
        if any(pd.isna(a) for a in accs.values()):
            delta = float("nan")
        else:
            delta = accs[2] - (accs[0] + accs[1]) / 2
        records.append({
            "Model": name,
            "Q0 (A→M)": round(accs[0], 3),
            "Q1 (D→M)": round(accs[1], 3),
            "Q2 (B→C)": round(accs[2], 3),
            "Δ Q2 vs avg(Q0,Q1)": round(delta, 3),
        })
    return pd.DataFrame(records)
388
+
389
+
390
+ # ---------------------------------------------------------------------------
391
+ # Full results loader (called by app.py)
392
+ # ---------------------------------------------------------------------------
393
+
394
def load_all_results(config_path: str | Path) -> dict[str, pd.DataFrame]:
    """Load results for all three tasks. Returns dict of DataFrames keyed
    "maze_navigation", "point_reuse", "compositional_distance"."""
    cfg = load_config(config_path)
    # pipeline/configs/experiments.yaml → llm-maze-solver/ (three levels up).
    # NOTE(review): assumes the config always lives at exactly that depth —
    # confirm against the caller in app.py.
    repo_root = Path(config_path).parent.parent.parent
    return {
        "maze_navigation": load_maze_navigation_results(cfg, repo_root),
        "point_reuse": load_point_reuse_results(cfg, repo_root),
        "compositional_distance": load_compositional_distance_results(cfg, repo_root),
    }
pipeline/task_builder.py ADDED
@@ -0,0 +1,328 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ task_builder.py
3
+ ---------------
4
+ Translates experiments.yaml into concrete shell commands (direct or sbatch).
5
+
6
+ Each public function returns a list of ExperimentJob dataclasses, one per
7
+ (model × format × prompt_strategy × grid_sizes) combination. The caller
8
+ decides whether to run them directly or wrap them in sbatch.
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import os
14
+ import itertools
15
+ from dataclasses import dataclass, field
16
+ from pathlib import Path
17
+ from typing import Any
18
+
19
+ import yaml
20
+
21
+ # ---------------------------------------------------------------------------
22
+ # Data structures
23
+ # ---------------------------------------------------------------------------
24
+
25
@dataclass
class ExperimentJob:
    """A single runnable experiment unit.

    Produced by the build_* functions; the caller either launches
    python_cmd directly as a subprocess or wraps it in an sbatch script.
    """
    task_id: str            # e.g. "maze_navigation"
    model: str              # e.g. "gemini-2.5-flash"
    label: str              # human-readable label for this job
    working_dir: Path       # where to cd before running
    python_cmd: list[str]   # [python, script.py, --arg, value, ...]
    api_key_env: str        # env-var name that must be set
    output_dir: Path        # where results land
    sbatch_cfg: dict        # mem, time, cpus, partition, log_dir
    grid_sizes: list[int]   # for display / filtering
+
38
+
39
+ # ---------------------------------------------------------------------------
40
+ # Config loader
41
+ # ---------------------------------------------------------------------------
42
+
43
def load_config(config_path: str | Path) -> dict:
    """Parse the experiments YAML config into a plain dict."""
    raw = Path(config_path).read_text()
    return yaml.safe_load(raw)
46
+
47
+
48
+ def _repo_root(config_path: Path) -> Path:
49
+ """pipeline/configs/experiments.yaml → llm-maze-solver/"""
50
+ return config_path.parent.parent.parent
51
+
52
+
53
+ # ---------------------------------------------------------------------------
54
+ # Internal helpers
55
+ # ---------------------------------------------------------------------------
56
+
57
+ def _merge_sbatch(defaults: dict, override: dict) -> dict:
58
+ merged = dict(defaults)
59
+ merged.update(override)
60
+ return merged
61
+
62
+
63
+ def _grid_str(grid_sizes: list[int]) -> str:
64
+ return ",".join(str(g) for g in grid_sizes)
65
+
66
+
67
+ def _output_subdir(base: str, model: str, tag: str) -> str:
68
+ """Produce a deterministic output subdirectory path."""
69
+ return f"{base}/{model.replace('.', '_').replace('-', '_')}/{tag}"
70
+
71
+
72
+ # ---------------------------------------------------------------------------
73
+ # Maze Navigation
74
+ # ---------------------------------------------------------------------------
75
+
76
def build_maze_navigation_jobs(
    cfg: dict,
    models: list[str] | None = None,
    grid_sizes: list[int] | None = None,
    input_formats: list[str] | None = None,
    prompt_strategies: list[str] | None = None,
    config_path: Path | None = None,
) -> list[ExperimentJob]:
    """Build jobs for Maze Navigation (planning, k-shot).

    One ExperimentJob per (model × input format × prompt strategy);
    grid sizes go to the script as a single comma-separated argument.
    Model ids absent from cfg["models"] are skipped silently so callers
    may pass a superset.
    """
    task = cfg["maze_navigation"]
    defaults = cfg["defaults"]
    model_cfg = cfg["models"]

    selected_models = models or list(model_cfg.keys())
    selected_formats = input_formats or task["input_formats"]
    selected_strategies = prompt_strategies or list(task["prompt_strategies"].keys())
    selected_grids = grid_sizes or task["grid_sizes"]

    # Resolve repo-relative paths; default to CWD when no config path given.
    repo = _repo_root(config_path) if config_path else Path(".")
    script = repo / task["script"]
    wdir = repo / task["working_dir"]

    # experiments.yaml may declare k_shots as a YAML list (e.g. [0, 1, 2]);
    # every argv element must be a string, so normalise here. The original
    # passed the raw value straight into the argv list, which makes
    # subprocess fail for list-valued configs.
    k_shots = task["k_shots"]
    if isinstance(k_shots, (list, tuple)):
        k_shots_arg = ",".join(str(k) for k in k_shots)
    else:
        k_shots_arg = str(k_shots)

    jobs: list[ExperimentJob] = []

    for model, fmt, strat in itertools.product(
        selected_models, selected_formats, selected_strategies
    ):
        if model not in model_cfg:
            continue
        strat_cfg = task["prompt_strategies"][strat]
        tag = f"{fmt}_input_{strat}"
        out_dir = repo / _output_subdir(task["output_base"], model, tag)

        cmd = [
            "python", str(script),
            "--model_name", model,
            "--input_format", fmt,
            "--k_shots", k_shots_arg,
            "--n_test_mazes", str(defaults["n_test_mazes"]),
            "--test_grid_sizes", _grid_str(selected_grids),
            "--maze_type", task["maze_type"],
            "--seed", str(defaults["seed"]),
            "--output_dir", str(out_dir),
        ]
        # Strategy-specific CLI flags (e.g. --cot) come straight from config.
        cmd.extend(strat_cfg["flags"])
        if task.get("visualize"):
            cmd.append("--visualize")

        jobs.append(ExperimentJob(
            task_id="maze_navigation",
            model=model,
            label=f"Maze Navigation | {model} | {fmt} | {strat}",
            working_dir=wdir,
            python_cmd=cmd,
            api_key_env=model_cfg[model]["api_key_env"],
            output_dir=out_dir,
            sbatch_cfg=_merge_sbatch(defaults["sbatch"], task.get("sbatch", {})),
            grid_sizes=selected_grids,
        ))

    return jobs
138
+
139
+
140
+ # ---------------------------------------------------------------------------
141
+ # Sequential Reasoning with Point Reuse (Q3 = Q0)
142
+ # ---------------------------------------------------------------------------
143
+
144
def build_point_reuse_jobs(
    cfg: dict,
    models: list[str] | None = None,
    grid_sizes: list[int] | None = None,
    prompt_strategies: list[str] | None = None,
    config_path: Path | None = None,
) -> list[ExperimentJob]:
    """Build jobs for Sequential Reasoning with Point Reuse (Q3=Q0).

    One ExperimentJob per (model × prompt strategy); model ids absent
    from cfg["models"] are skipped silently.
    """
    task = cfg["point_reuse"]
    defaults = cfg["defaults"]
    model_cfg = cfg["models"]

    selected_models = models or list(model_cfg.keys())
    selected_strategies = prompt_strategies or list(task["prompt_strategies"].keys())
    selected_grids = grid_sizes or task["grid_sizes"]

    # Resolve repo-relative paths; default to CWD when no config path given.
    repo = _repo_root(config_path) if config_path else Path(".")
    script = repo / task["script"]
    wdir = repo / task["working_dir"]

    jobs: list[ExperimentJob] = []

    for model, strat in itertools.product(selected_models, selected_strategies):
        if model not in model_cfg:
            continue
        strat_cfg = task["prompt_strategies"][strat]
        tag = f"point_reuse_q3q0_{strat}"
        out_dir = repo / _output_subdir(task["output_base"], model, tag)

        cmd = [
            "python", str(script),
            "--model_name", model,
            "--input_format", task["input_format"],
            "--strategy", task["strategy"],
            "--reuse_pattern", task["reuse_pattern"],
            "--prompt_type", strat_cfg["prompt_type"],
            "--n_questions_per_maze", str(task["n_questions_per_maze"]),
            "--n_test_mazes", str(defaults["n_test_mazes"]),
            "--test_grid_sizes", _grid_str(selected_grids),
            "--output_dir", str(out_dir),
        ]
        # Boolean config switches become bare CLI flags.
        if task.get("sequential_questions"):
            cmd.append("--sequential_questions")
        if task.get("visualize"):
            cmd.append("--visualize")
        if task.get("save_details"):
            cmd.append("--save_details")

        jobs.append(ExperimentJob(
            task_id="point_reuse",
            model=model,
            label=f"Point Reuse | {model} | {strat}",
            working_dir=wdir,
            python_cmd=cmd,
            api_key_env=model_cfg[model]["api_key_env"],
            output_dir=out_dir,
            sbatch_cfg=_merge_sbatch(defaults["sbatch"], task.get("sbatch", {})),
            grid_sizes=selected_grids,
        ))

    return jobs
205
+
206
+
207
+ # ---------------------------------------------------------------------------
208
+ # Compositional Distance Comparison
209
+ # ---------------------------------------------------------------------------
210
+
211
def build_compositional_distance_jobs(
    cfg: dict,
    models: list[str] | None = None,
    grid_sizes: list[int] | None = None,
    prompt_strategies: list[str] | None = None,
    config_path: Path | None = None,
) -> list[ExperimentJob]:
    """Build jobs for Compositional Distance Comparison (corners-to-center).

    Args:
        cfg: Parsed experiments.yaml contents (needs "compositional_distance",
            "defaults" and "models" sections).
        models: Model IDs to run; defaults to every model in cfg["models"].
        grid_sizes: Grid sizes to test; defaults to the task's configured list.
        prompt_strategies: Strategy names to run; defaults to all configured.
        config_path: Path to the YAML file, used to resolve the repo root.
            When None, paths are resolved relative to the current directory.

    Returns:
        One ExperimentJob per (model, strategy) combination.
    """
    task = cfg["compositional_distance"]
    defaults = cfg["defaults"]
    model_cfg = cfg["models"]

    selected_models = models or list(model_cfg.keys())
    selected_strategies = prompt_strategies or list(task["prompt_strategies"].keys())
    selected_grids = grid_sizes or task["grid_sizes"]

    repo = _repo_root(config_path) if config_path else Path(".")
    script = repo / task["script"]
    wdir = repo / task["working_dir"]

    jobs: list[ExperimentJob] = []

    for model, strat in itertools.product(selected_models, selected_strategies):
        # Skip names not present in the config (consistent for both models
        # and strategies — previously an unknown strategy raised KeyError).
        if model not in model_cfg:
            continue
        if strat not in task["prompt_strategies"]:
            continue
        strat_cfg = task["prompt_strategies"][strat]
        tag = f"orthogonal_{task['corner_pattern']}_{strat}"
        out_dir = repo / _output_subdir(task["output_base"], model, tag)

        cmd = [
            "python", str(script),
            "--model_name", model,
            "--input_format", task["input_format"],
            "--strategy", task["strategy"],
            "--corner_pattern", task["corner_pattern"],
            "--prompt_type", strat_cfg["prompt_type"],
            "--n_questions_per_maze", str(task["n_questions_per_maze"]),
            "--n_test_mazes", str(defaults["n_test_mazes"]),
            "--test_grid_sizes", _grid_str(selected_grids),
            "--output_dir", str(out_dir),
        ]
        # Optional boolean flags, driven by the task config.
        if task.get("visualize"):
            cmd.append("--visualize")
        if task.get("save_details"):
            cmd.append("--save_details")

        jobs.append(ExperimentJob(
            task_id="compositional_distance",
            model=model,
            label=f"Compositional Distance | {model} | {strat}",
            working_dir=wdir,
            python_cmd=cmd,
            api_key_env=model_cfg[model]["api_key_env"],
            output_dir=out_dir,
            sbatch_cfg=_merge_sbatch(defaults["sbatch"], task.get("sbatch", {})),
            grid_sizes=selected_grids,
        ))

    return jobs
270
+
271
+
272
+ # ---------------------------------------------------------------------------
273
+ # Unified builder
274
+ # ---------------------------------------------------------------------------
275
+
276
def build_all_jobs(
    cfg: dict,
    tasks: list[str] | None = None,
    models: list[str] | None = None,
    grid_sizes: list[int] | None = None,
    input_formats: list[str] | None = None,
    prompt_strategies: list[str] | None = None,
    config_path: Path | None = None,
) -> list[ExperimentJob]:
    """Build jobs for all requested tasks.

    Args:
        cfg: Parsed experiments.yaml contents.
        tasks: Task IDs to build; defaults to all three tasks.
        models, grid_sizes, prompt_strategies, config_path: Shared filters,
            forwarded to every per-task builder.
        input_formats: Input formats ("raw"/"visual"); only maze_navigation
            uses this filter.

    Returns:
        The concatenated job list, in task order.
    """
    selected_tasks = tasks or ["maze_navigation", "point_reuse", "compositional_distance"]
    jobs: list[ExperimentJob] = []
    # Shared keyword arguments for every per-task builder.
    kw = dict(
        models=models,
        grid_sizes=grid_sizes,
        prompt_strategies=prompt_strategies,
        config_path=config_path,
    )
    if "maze_navigation" in selected_tasks:
        jobs += build_maze_navigation_jobs(cfg, input_formats=input_formats, **kw)
    if "point_reuse" in selected_tasks:
        jobs += build_point_reuse_jobs(cfg, **kw)
    if "compositional_distance" in selected_tasks:
        jobs += build_compositional_distance_jobs(cfg, **kw)
    return jobs
301
+
302
+
303
+ # ---------------------------------------------------------------------------
304
+ # sbatch script generator
305
+ # ---------------------------------------------------------------------------
306
+
307
def make_sbatch_script(job: ExperimentJob, log_dir: Path) -> str:
    """Return the text of an sbatch submission script for a job.

    Creates *log_dir* if needed. The script's -o/-e log files are named
    ``{safe_label}_%j.out`` / ``.err`` (``%j`` is SLURM's job-id pattern).
    The working directory and every command token are shell-quoted so
    paths or arguments containing spaces cannot break the script.
    """
    import shlex  # local: only needed when generating scripts

    s = job.sbatch_cfg
    log_dir.mkdir(parents=True, exist_ok=True)
    # Filesystem-safe version of the human-readable label.
    safe_label = job.label.replace(" ", "_").replace("|", "").replace("/", "_")

    lines = [
        "#!/bin/bash",
        f"#SBATCH -c {s.get('cpus', 2)}",
        f"#SBATCH -t {s.get('time', '10:00:00')}",
        f"#SBATCH -p {s.get('partition', 'short')}",
        f"#SBATCH --mem={s.get('mem', '8G')}",
        f"#SBATCH -o {log_dir}/{safe_label}_%j.out",
        f"#SBATCH -e {log_dir}/{safe_label}_%j.err",
        "",
        f"# {job.label}",
        # Re-export the key inside the batch shell (sbatch propagates the
        # submitting environment, but this makes the dependency explicit).
        f"export {job.api_key_env}=${{{job.api_key_env}}}",
        "",
        f"cd {shlex.quote(str(job.working_dir))}",
        # One quoted token per line-continuation for readability.
        " \\\n    ".join(shlex.quote(str(tok)) for tok in job.python_cmd),
    ]
    return "\n".join(lines) + "\n"
requirements.txt ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SpatialBench — pipeline dependencies
2
+ # Install with: pip install -r requirements.txt
3
+
4
+ # Gradio UI (HuggingFace Space entrypoint)
5
+ gradio>=4.20.0
6
+
7
+ # Plotting
8
+ plotly>=5.18.0
9
+
10
+ # Data
11
+ pandas>=2.0.0
12
+ numpy>=1.24.0
13
+
14
+ # Config parsing
15
+ PyYAML>=6.0
16
+
17
+ # LLM API clients
18
+ openai>=1.14.0
19
+ anthropic>=0.25.0
20
+ google-generativeai>=0.5.0
21
+
22
+ # (DeepSeek uses the OpenAI-compatible client — no extra package needed)
23
+
24
+ # Sentence embeddings for reasoning quality analysis
25
+ sentence-transformers>=2.6.0
26
+
27
+ # ROUGE for reasoning quality analysis
28
+ rouge-score>=0.1.2
29
+
30
+ # Environment variable loading
31
+ python-dotenv>=1.0.0
run_experiments.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ run_experiments.py
4
+ ------------------
5
+ CLI orchestrator for SpatialBench experiments.
6
+
7
+ Run on the cluster with SLURM:
8
+ python run_experiments.py --tasks maze_navigation point_reuse compositional_distance --mode slurm
9
+
10
+ Run directly (uses API keys, no SLURM required):
11
+ python run_experiments.py --tasks maze_navigation --models gemini-2.5-flash --mode direct
12
+
13
+ Dry-run (print commands without executing):
14
+ python run_experiments.py --tasks maze_navigation point_reuse compositional_distance --dry-run
15
+
16
+ Filter experiments:
17
+ python run_experiments.py --tasks maze_navigation \\
18
+ --models gemini-2.5-flash claude-haiku-4-5 \\
19
+ --grid-sizes 5 6 7 \\
20
+ --formats raw \\
21
+ --strategies cot reasoning
22
+
23
+ Show status of running SLURM jobs:
24
+ python run_experiments.py --status
25
+ """
26
+
27
+ from __future__ import annotations
28
+
29
+ import argparse
30
+ import json
31
+ import os
32
+ import subprocess
33
+ import sys
34
+ import tempfile
35
+ import time
36
+ from pathlib import Path
37
+
38
+ # Load .env if present (before importing pipeline modules)
39
+ _env_file = Path(__file__).parent / ".env"
40
+ if _env_file.exists():
41
+ with open(_env_file) as _f:
42
+ for _line in _f:
43
+ _line = _line.strip()
44
+ if _line and not _line.startswith("#") and "=" in _line:
45
+ _k, _v = _line.split("=", 1)
46
+ os.environ.setdefault(_k.strip(), _v.strip())
47
+
48
+ from pipeline.task_builder import (
49
+ load_config, build_all_jobs, make_sbatch_script, ExperimentJob,
50
+ )
51
+ from pipeline.job_monitor import JobMonitor, submit_sbatch, submit_direct
52
+
53
+ CONFIG_PATH = Path(__file__).parent / "configs" / "experiments.yaml"
54
+ REPO_ROOT = CONFIG_PATH.parent.parent.parent # llm-maze-solver/
55
+
56
+
57
+ # ---------------------------------------------------------------------------
58
+ # Helpers
59
+ # ---------------------------------------------------------------------------
60
+
61
+ def _check_api_key(job: ExperimentJob) -> bool:
62
+ val = os.environ.get(job.api_key_env, "")
63
+ if not val:
64
+ print(f" [WARN] {job.api_key_env} not set — skipping: {job.label}")
65
+ return False
66
+ return True
67
+
68
+
69
+ def _print_job(job: ExperimentJob) -> None:
70
+ print(f"\n {job.label}")
71
+ print(f" cmd : {' '.join(job.python_cmd[:4])} ...")
72
+ print(f" wdir: {job.working_dir}")
73
+ print(f" out : {job.output_dir}")
74
+
75
+
76
+ # ---------------------------------------------------------------------------
77
+ # Run modes
78
+ # ---------------------------------------------------------------------------
79
+
80
def run_slurm(jobs: list[ExperimentJob], monitor: JobMonitor, dry_run: bool) -> None:
    """Submit each job via sbatch and register it with *monitor*.

    With *dry_run* the generated sbatch script is printed instead of
    submitted. Jobs whose API-key env var is unset are skipped.
    """
    log_dir = REPO_ROOT / "maze-solver" / "eval_llm_logs"
    log_dir.mkdir(parents=True, exist_ok=True)

    for job in jobs:
        if not _check_api_key(job):
            continue
        script_text = make_sbatch_script(job, log_dir)
        if dry_run:
            _print_job(job)
            print(" --- sbatch script ---")
            print(script_text)
            continue

        # Persist the script (delete=False) so SLURM can read it after
        # submission and the file remains for debugging.
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".sh", prefix="spatialbench_",
            dir=log_dir, delete=False
        ) as tmp:
            tmp.write(script_text)
            script_path = tmp.name

        job_id = submit_sbatch(script_text, script_path)
        if job_id:
            # BUGFIX: register log paths matching the #SBATCH -o/-e pattern
            # in make_sbatch_script ({safe_label}_%j.out) — previously the
            # monitor was pointed at bare {job_id}.out files that SLURM
            # never writes.
            safe_label = job.label.replace(" ", "_").replace("|", "").replace("/", "_")
            monitor.add(
                job_id=job_id,
                label=job.label,
                task_id=job.task_id,
                model=job.model,
                output_dir=str(job.output_dir),
                log_out=str(log_dir / f"{safe_label}_{job_id}.out"),
                log_err=str(log_dir / f"{safe_label}_{job_id}.err"),
            )
            print(f" Submitted {job.label} → SLURM job {job_id}")
        else:
            print(f" [ERROR] Failed to submit: {job.label}")
115
+
116
+
117
def run_direct(jobs: list[ExperimentJob], monitor: JobMonitor, dry_run: bool) -> None:
    """Launch each job as a local subprocess (no SLURM) and track it."""
    for job in jobs:
        if not _check_api_key(job):
            continue
        if dry_run:
            _print_job(job)
            print(f" cmd: {' '.join(job.python_cmd)}\n")
            continue

        job.output_dir.mkdir(parents=True, exist_ok=True)
        key_env = {job.api_key_env: os.environ[job.api_key_env]}

        print(f" Starting: {job.label}")
        child = submit_direct(
            cmd=job.python_cmd,
            working_dir=str(job.working_dir),
            env=key_env,
        )
        monitor.add_direct(
            proc=child,
            label=job.label,
            task_id=job.task_id,
            model=job.model,
            output_dir=str(job.output_dir),
        )
        # Small gap to avoid hammering APIs simultaneously
        time.sleep(2)
144
+
145
+
146
+ # ---------------------------------------------------------------------------
147
+ # Status display
148
+ # ---------------------------------------------------------------------------
149
+
150
def show_status(monitor: JobMonitor) -> None:
    """Refresh job states and print an aggregate plus per-job report."""
    monitor.refresh()
    summary = monitor.summary()
    report = [f"\nTotal jobs: {summary['total']}"]
    report.extend(
        f" {status:12s}: {count}" for status, count in summary["counts"].items()
    )
    report.append("")
    report.extend(
        f" [{rec['status']:9s}] {rec['label']:<60s} elapsed: {rec['elapsed']}"
        for rec in summary["records"]
    )
    print("\n".join(report))
159
+
160
+
161
+ # ---------------------------------------------------------------------------
162
+ # Argument parsing
163
+ # ---------------------------------------------------------------------------
164
+
165
def parse_args() -> argparse.Namespace:
    """Parse the orchestrator's command-line arguments.

    Returns a Namespace with: tasks, models, grid_sizes, formats,
    strategies, mode, dry_run, no_wait, status, job_ids, config,
    poll_interval. See the module docstring (used as --help epilog)
    for worked examples.
    """
    parser = argparse.ArgumentParser(
        description="SpatialBench experiment orchestrator",
        # Raw formatter so the module docstring's example commands keep
        # their line breaks in --help output.
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )

    # --- experiment selection -------------------------------------------
    parser.add_argument(
        "--tasks", nargs="+",
        default=["maze_navigation", "point_reuse", "compositional_distance"],
        choices=["maze_navigation", "point_reuse", "compositional_distance"],
        help="Which tasks to run (default: all three)",
    )
    parser.add_argument(
        "--models", nargs="+", default=None,
        help="Model IDs to run (default: all models in config)",
    )
    parser.add_argument(
        "--grid-sizes", nargs="+", type=int, default=None,
        dest="grid_sizes",
        help="Grid sizes to evaluate, e.g. --grid-sizes 5 6 7 (default: per-task config)",
    )
    parser.add_argument(
        "--formats", nargs="+", default=None,
        choices=["raw", "visual"],
        help="Input formats for Task 1 (default: both raw and visual)",
    )
    parser.add_argument(
        "--strategies", nargs="+", default=None,
        choices=["base", "cot", "reasoning"],
        help="Prompt strategies (default: all)",
    )

    # --- execution control ----------------------------------------------
    parser.add_argument(
        "--mode", default="slurm", choices=["slurm", "direct"],
        help="Execution mode: 'slurm' submits sbatch jobs, 'direct' runs inline (default: slurm)",
    )
    parser.add_argument(
        "--dry-run", action="store_true",
        help="Print commands without executing them",
    )
    parser.add_argument(
        "--no-wait", action="store_true",
        help="Return immediately after submission (don't poll for completion)",
    )

    # --- status / monitoring --------------------------------------------
    parser.add_argument(
        "--status", action="store_true",
        help="Query and display SLURM job status (requires --job-ids or a running monitor)",
    )
    parser.add_argument(
        "--job-ids", nargs="+", default=None,
        help="SLURM job IDs to check status for (used with --status)",
    )
    parser.add_argument(
        "--config", default=str(CONFIG_PATH),
        help=f"Path to experiments.yaml (default: {CONFIG_PATH})",
    )
    parser.add_argument(
        "--poll-interval", type=int, default=60,
        dest="poll_interval",
        help="Seconds between SLURM status polls when waiting (default: 60)",
    )

    return parser.parse_args()
228
+
229
+
230
+ # ---------------------------------------------------------------------------
231
+ # Main
232
+ # ---------------------------------------------------------------------------
233
+
234
def main() -> None:
    """Entry point: parse args, build jobs, submit, optionally wait.

    Flow: (1) --status short-circuits into a query-only mode;
    (2) otherwise build the filtered job list, run it in slurm or
    direct mode, and — unless --dry-run/--no-wait — poll the monitor
    until every job finishes, then print a pass/fail summary.
    """
    args = parse_args()
    cfg = load_config(args.config)

    # Status-only mode: no jobs are built or submitted.
    if args.status:
        monitor = JobMonitor(mode="slurm")
        if args.job_ids:
            # Seed the monitor with externally supplied SLURM job IDs;
            # task/model are unknown here, hence the "?" placeholders.
            for jid in args.job_ids:
                monitor.add(job_id=jid, label=jid, task_id="?", model="?")
        show_status(monitor)
        return

    # Build jobs from the config, applying all CLI filters.
    jobs = build_all_jobs(
        cfg=cfg,
        tasks=args.tasks,
        models=args.models,
        grid_sizes=args.grid_sizes,
        input_formats=args.formats,
        prompt_strategies=args.strategies,
        config_path=Path(args.config),
    )

    if not jobs:
        print("No jobs matched the requested filters.")
        return

    # Echo the effective run configuration before submitting anything.
    print(f"\nSpatialBench — {len(jobs)} job(s) to run")
    print(f" mode : {args.mode}")
    print(f" tasks : {args.tasks}")
    print(f" models : {args.models or 'all'}")
    print(f" grids : {args.grid_sizes or 'per-task default'}")
    print(f" formats : {args.formats or 'per-task default'}")
    print(f" strategies: {args.strategies or 'all'}")
    print(f" dry-run : {args.dry_run}")
    print()

    monitor = JobMonitor(mode=args.mode)

    if args.mode == "slurm":
        run_slurm(jobs, monitor, dry_run=args.dry_run)
    else:
        run_direct(jobs, monitor, dry_run=args.dry_run)

    # Fire-and-forget paths: nothing was submitted on a dry run, so only
    # the --no-wait case prints the submitted count.
    if args.dry_run or args.no_wait:
        if not args.dry_run:
            print(f"\nSubmitted {len(monitor.all_records())} job(s). Use --status to check progress.")
        return

    # Wait for completion, printing a status-count line on each poll.
    print("\nWaiting for jobs to complete...")

    def _progress(summary: dict) -> None:
        # Callback invoked by the monitor on every poll tick.
        counts = summary["counts"]
        parts = [f"{s}: {n}" for s, n in counts.items()]
        print(f" [{time.strftime('%H:%M:%S')}] {' | '.join(parts)}")

    monitor.wait_all(poll_interval=args.poll_interval, callback=_progress)

    # Final summary
    summary = monitor.summary()
    print(f"\nDone. {summary['counts'].get('completed', 0)} completed, "
          f"{summary['counts'].get('failed', 0)} failed.")

    failed = [r for r in summary["records"] if r["status"] == "failed"]
    if failed:
        print("\nFailed jobs:")
        for r in failed:
            print(f" {r['label']} (job_id={r['job_id']})")


if __name__ == "__main__":
    main()