Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files- README.md +19 -6
- app.py +688 -0
- configs/experiments.yaml +171 -0
- pipeline/__init__.py +1 -0
- pipeline/job_monitor.py +314 -0
- pipeline/results_loader.py +402 -0
- pipeline/task_builder.py +328 -0
- requirements.txt +31 -0
- run_experiments.py +307 -0
README.md
CHANGED
|
@@ -1,12 +1,25 @@
|
|
| 1 |
---
|
| 2 |
title: SpatialBench
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version:
|
| 8 |
app_file: app.py
|
| 9 |
-
pinned:
|
|
|
|
| 10 |
---
|
| 11 |
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
title: SpatialBench
|
| 3 |
+
emoji: 🧩
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: indigo
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: "4.44.0"
|
| 8 |
app_file: app.py
|
| 9 |
+
pinned: true
|
| 10 |
+
short_description: Do LLMs Build Spatial World Models? Evidence from Maze Tasks
|
| 11 |
---
|
| 12 |
|
| 13 |
+
# SpatialBench
|
| 14 |
+
|
| 15 |
+
Evaluation platform for **"Do LLMs Build Spatial World Models? Evidence from Grid-World Maze Tasks"** (ICLR 2026 Workshop).
|
| 16 |
+
|
| 17 |
+
Three tasks probe whether LLMs construct internal spatial representations:
|
| 18 |
+
|
| 19 |
+
| Task | Type | Description |
|
| 20 |
+
|------|------|-------------|
|
| 21 |
+
| **Maze Navigation** | Planning | Find shortest path from start to goal |
|
| 22 |
+
| **Sequential Point Reuse** | Reasoning | Q3 = Q0 — do models reuse earlier computation? |
|
| 23 |
+
| **Compositional Distance** | Reasoning | Compose corner→center distances for Q2 |
|
| 24 |
+
|
| 25 |
+
Models evaluated: Gemini 2.5 Flash, GPT-5 Mini, Claude Haiku 4.5, DeepSeek Chat.
|
app.py
ADDED
|
@@ -0,0 +1,688 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
app.py — SpatialBench Gradio application
|
| 3 |
+
-----------------------------------------
|
| 4 |
+
Entrypoint for the HuggingFace Space "SpatialBench".
|
| 5 |
+
|
| 6 |
+
Two tabs:
|
| 7 |
+
1. Leaderboard — visualize pre-computed results from all three tasks
|
| 8 |
+
2. Run — launch experiments directly via API keys (no SLURM needed)
|
| 9 |
+
(on HF Space, set API keys as Space Secrets)
|
| 10 |
+
|
| 11 |
+
To run locally:
|
| 12 |
+
cd pipeline/
|
| 13 |
+
python app.py
|
| 14 |
+
|
| 15 |
+
To deploy on HuggingFace Spaces:
|
| 16 |
+
- Set Space Secrets: GEMINI_API_KEY, OPENAI_API_KEY, ANTHROPIC_API_KEY, DEEPSEEK_API_KEY
|
| 17 |
+
- The Space entrypoint is this file (app.py)
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
from __future__ import annotations
|
| 21 |
+
|
| 22 |
+
import os
|
| 23 |
+
import sys
|
| 24 |
+
import threading
|
| 25 |
+
import time
|
| 26 |
+
from pathlib import Path
|
| 27 |
+
|
| 28 |
+
import gradio as gr
|
| 29 |
+
import pandas as pd
|
| 30 |
+
import plotly.express as px
|
| 31 |
+
import plotly.graph_objects as go
|
| 32 |
+
|
| 33 |
+
# Load .env if running locally.
# Minimal KEY=VALUE parser: skips blanks and '#' comments; setdefault so real
# environment variables (e.g. HF Space Secrets) always win over the file.
_env = Path(__file__).parent / ".env"
if _env.exists():
    with open(_env) as _f:
        for _line in _f:
            _line = _line.strip()
            if _line and not _line.startswith("#") and "=" in _line:
                _k, _v = _line.split("=", 1)
                _k, _v = _k.strip(), _v.strip()
                # Fix: tolerate quoted values (KEY="value" / KEY='value');
                # the previous parser kept the quotes as part of the value.
                if len(_v) >= 2 and _v[0] == _v[-1] and _v[0] in ("'", '"'):
                    _v = _v[1:-1]
                os.environ.setdefault(_k, _v)
|
| 42 |
+
|
| 43 |
+
# Add repo root to path so pipeline imports work
|
| 44 |
+
sys.path.insert(0, str(Path(__file__).parent))
|
| 45 |
+
|
| 46 |
+
from pipeline.task_builder import load_config, build_all_jobs
|
| 47 |
+
from pipeline.job_monitor import JobMonitor, submit_direct
|
| 48 |
+
from pipeline.results_loader import (
|
| 49 |
+
load_all_results,
|
| 50 |
+
maze_navigation_leaderboard,
|
| 51 |
+
point_reuse_leaderboard,
|
| 52 |
+
compositional_distance_leaderboard,
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
# experiments.yaml drives the whole app: the model registry, task matrix,
# and default sweep parameters all come from this one file.
CONFIG_PATH = Path(__file__).parent / "configs" / "experiments.yaml"
CFG = load_config(CONFIG_PATH)
# Internal model ids (config keys) and their human-readable display names.
MODEL_CHOICES = list(CFG["models"].keys())
MODEL_DISPLAY = {k: v["display_name"] for k, v in CFG["models"].items()}

# Global job monitor (direct mode only — HF Space has no SLURM)
_monitor = JobMonitor(mode="direct")
# Serializes job submission/refresh across concurrent Gradio callbacks.
_monitor_lock = threading.Lock()
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
# ---------------------------------------------------------------------------
|
| 69 |
+
# Leaderboard helpers
|
| 70 |
+
# ---------------------------------------------------------------------------
|
| 71 |
+
|
| 72 |
+
def _load_results() -> dict[str, pd.DataFrame]:
    """Load pre-computed results for all three tasks.

    Returns a dict keyed by task id. On any load failure it falls back to
    empty DataFrames so the UI shows "no results" placeholders instead of
    crashing — downstream plot functions all check ``.empty``.
    """
    try:
        return load_all_results(CONFIG_PATH)
    except Exception:
        # Deliberate best-effort fallback (fix: dropped the unused `as e`
        # binding the original caught but never used).
        return {
            task: pd.DataFrame()
            for task in ("maze_navigation", "point_reuse", "compositional_distance")
        }
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def _make_empty_fig(msg: str) -> go.Figure:
    """Build a blank placeholder figure with *msg* centered on it."""
    placeholder = go.Figure()
    placeholder.add_annotation(
        text=msg,
        x=0.5,
        y=0.5,
        xref="paper",
        yref="paper",
        showarrow=False,
        font=dict(size=16),
    )
    # Hide axes and backgrounds so only the message is visible.
    placeholder.update_layout(
        height=300,
        xaxis_visible=False,
        yaxis_visible=False,
        paper_bgcolor="rgba(0,0,0,0)",
        plot_bgcolor="rgba(0,0,0,0)",
    )
    return placeholder
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
# ── Task 1 plots ────────────────────────────────────────────────────────────
|
| 90 |
+
|
| 91 |
+
def plot_task1_accuracy(k_shot: int, input_format: str) -> tuple[go.Figure, pd.DataFrame]:
    """Line chart of Task 1 accuracy vs grid size, plus its leaderboard table."""
    task_df = _load_results()["maze_navigation"]
    if task_df.empty:
        return _make_empty_fig("No Task 1 results found.\nRun experiments first."), pd.DataFrame()

    mask = (task_df["k_shot"] == k_shot) & (task_df["input_format"] == input_format)
    selected = task_df[mask]
    if selected.empty:
        return _make_empty_fig(f"No results for k={k_shot}, format={input_format}"), pd.DataFrame()

    accuracy_fig = px.line(
        selected,
        x="grid_size",
        y="accuracy",
        color="display_name",
        line_dash="prompt_strategy",
        markers=True,
        labels={"grid_size": "Grid Size (n×n)", "accuracy": "Accuracy",
                "display_name": "Model", "prompt_strategy": "Strategy"},
        title=f"Task 1 — Maze Navigation ({input_format} format, {k_shot}-shot)",
        color_discrete_sequence=px.colors.qualitative.Set2,
    )
    accuracy_fig.update_layout(
        yaxis_range=[0, 1],
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
        height=420,
    )
    # Leaderboard aggregates over the FULL task frame, not the filtered view.
    return accuracy_fig, maze_navigation_leaderboard(task_df, k_shot=k_shot)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def plot_task1_format_comparison() -> go.Figure:
    """Bar chart comparing raw vs visual input formats on Task 1.

    Averages accuracy over grid sizes at k=0 with the CoT strategy;
    falls back to all 0-shot rows if no CoT rows exist.
    """
    results = _load_results()
    df = results["maze_navigation"]
    if df.empty:
        return _make_empty_fig("No Task 1 results found.")

    # Average over grid sizes, compare raw vs visual at k=0 with CoT
    sub = df[(df["k_shot"] == 0) & (df["prompt_strategy"] == "cot")]
    if sub.empty:
        sub = df[df["k_shot"] == 0]
    if sub.empty:
        # Fix: previously an empty selection fell through to px.bar on an
        # empty frame, rendering a blank chart with no explanation.
        return _make_empty_fig("No 0-shot Task 1 results found.")
    agg = sub.groupby(["display_name", "input_format"])["accuracy"].mean().reset_index()

    fig = px.bar(
        agg, x="display_name", y="accuracy", color="input_format",
        barmode="group",
        labels={"display_name": "Model", "accuracy": "Mean Accuracy",
                "input_format": "Input Format"},
        title="Task 1 — Raw vs Visual Format (0-shot, CoT, averaged over grid sizes)",
        color_discrete_map={"raw": "#2196F3", "visual": "#FF9800"},
    )
    fig.update_layout(yaxis_range=[0, 1], height=380)
    return fig
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
# ── Task 2 plots ────────────────────────────────────────────────────────────
|
| 144 |
+
|
| 145 |
+
def plot_task2_q0_q3(grid_size: int) -> tuple[go.Figure, pd.DataFrame]:
    """Grouped bars of Q0 vs Q3 accuracy per model, plus the Task 2 leaderboard."""
    reuse_df = _load_results()["point_reuse"]
    if reuse_df.empty:
        return _make_empty_fig("No Task 2 results found.\nRun experiments first."), pd.DataFrame()

    at_size = reuse_df[reuse_df["grid_size"] == grid_size]
    if at_size.empty:
        return _make_empty_fig(f"No Task 2 results for {grid_size}×{grid_size}"), pd.DataFrame()

    # Mean accuracy per model for the first (Q0) and repeated (Q3) question.
    per_question = [
        at_size[at_size["question_idx"] == idx]
        .groupby("display_name")["accuracy"]
        .mean()
        .rename(label)
        for idx, label in ((0, "Q0"), (3, "Q3"))
    ]
    melted = (
        pd.concat(per_question, axis=1)
        .reset_index()
        .melt(id_vars="display_name", var_name="Question", value_name="Accuracy")
    )

    bars = px.bar(
        melted,
        x="display_name",
        y="Accuracy",
        color="Question",
        barmode="group",
        labels={"display_name": "Model"},
        title=f"Task 2 — Q0 vs Q3 Accuracy ({grid_size}×{grid_size} maze)\n"
              "Q3 = Q0 (same question repeated — tests information reuse)",
        color_discrete_map={"Q0": "#4CAF50", "Q3": "#F44336"},
    )
    bars.update_layout(yaxis_range=[0, 1], height=400)
    # Leaderboard uses the full frame (all grid sizes), matching the original.
    return bars, point_reuse_leaderboard(reuse_df)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def plot_task2_by_grid() -> go.Figure:
    """Line chart of Q3 (repeated-question) accuracy across grid sizes."""
    reuse_df = _load_results()["point_reuse"]
    if reuse_df.empty:
        return _make_empty_fig("No Task 2 results found.")

    q3_means = (
        reuse_df[reuse_df["question_idx"] == 3]
        .groupby(["display_name", "grid_size"])["accuracy"]
        .mean()
        .reset_index()
    )

    trend = px.line(
        q3_means,
        x="grid_size",
        y="accuracy",
        color="display_name",
        markers=True,
        labels={"grid_size": "Grid Size", "accuracy": "Q3 Accuracy",
                "display_name": "Model"},
        title="Task 2 — Q3 Accuracy by Grid Size (Q3 = Q0 repeated)",
        color_discrete_sequence=px.colors.qualitative.Set2,
    )
    trend.update_layout(yaxis_range=[0, 1], height=380)
    return trend
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
# ── Task 3 plots ────────────────────────────────────────────────────────────
|
| 195 |
+
|
| 196 |
+
def plot_task3_compositional() -> tuple[go.Figure, pd.DataFrame]:
    """Grouped bars of Q0/Q1/Q2 accuracy per model, plus the Task 3 leaderboard."""
    comp_df = _load_results()["compositional_distance"]
    if comp_df.empty:
        return _make_empty_fig("No Task 3 results found.\nRun experiments first."), pd.DataFrame()

    by_question = (
        comp_df.groupby(["display_name", "question_idx"])["accuracy"]
        .mean()
        .reset_index()
    )
    # Human-readable question labels; Q2 is the compositional target.
    question_names = {0: "Q0: A→M", 1: "Q1: D→M", 2: "Q2: B→C (compositional)"}
    by_question["Question"] = by_question["question_idx"].map(question_names)

    bars = px.bar(
        by_question,
        x="display_name",
        y="accuracy",
        color="Question",
        barmode="group",
        labels={"display_name": "Model", "accuracy": "Accuracy"},
        title="Task 3 — Compositional Distance Comparison\n"
              "Q2 can be composed from Q0+Q1 (corner→center distances)",
        color_discrete_map={
            "Q0: A→M": "#2196F3",
            "Q1: D→M": "#9C27B0",
            "Q2: B→C (compositional)": "#FF5722",
        },
    )
    bars.update_layout(yaxis_range=[0, 1], height=420)
    return bars, compositional_distance_leaderboard(comp_df)
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def plot_task3_by_grid() -> go.Figure:
    """Line chart of Q2 (compositional question) accuracy across grid sizes."""
    comp_df = _load_results()["compositional_distance"]
    if comp_df.empty:
        return _make_empty_fig("No Task 3 results found.")

    q2_means = (
        comp_df[comp_df["question_idx"] == 2]
        .groupby(["display_name", "grid_size"])["accuracy"]
        .mean()
        .reset_index()
    )

    trend = px.line(
        q2_means,
        x="grid_size",
        y="accuracy",
        color="display_name",
        markers=True,
        labels={"grid_size": "Grid Size", "accuracy": "Q2 Accuracy",
                "display_name": "Model"},
        title="Task 3 — Q2 (Compositional) Accuracy by Grid Size",
        color_discrete_sequence=px.colors.qualitative.Set2,
    )
    trend.update_layout(yaxis_range=[0, 1], height=380)
    return trend
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
# ---------------------------------------------------------------------------
|
| 245 |
+
# Run-experiments tab
|
| 246 |
+
# ---------------------------------------------------------------------------
|
| 247 |
+
|
| 248 |
+
# Map from env-var name → user-provided key (populated at runtime from form)
# NOTE(review): neither _USER_KEYS nor _USER_KEYS_LOCK is referenced anywhere
# else in this file — launch_experiments builds its own local key dict.
# Looks like dead code; confirm no external module imports these before removing.
_USER_KEYS: dict[str, str] = {}
_USER_KEYS_LOCK = threading.Lock()
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def launch_experiments(
    tasks: list[str],
    models: list[str],
    grid_sizes_str: str,
    formats: list[str],
    strategies: list[str],
    gemini_key: str,
    openai_key: str,
    anthropic_key: str,
    deepseek_key: str,
) -> tuple[str, list[list[str]]]:
    """Called when the user clicks 'Run' in the Gradio UI.

    Builds the job matrix from the selected filters, then launches one
    direct-mode subprocess per job whose provider key was supplied.
    Jobs for models with no key are skipped, not failed.

    Returns:
        (status message, job-status table rows for the Dataframe widget).
    """

    def _clean(key) -> str:
        # Robustness fix: Gradio normally passes "", but guard against None
        # so .strip() cannot raise AttributeError.
        return (key or "").strip()

    # Build a key map from only what the user explicitly typed — never os.environ
    user_keys: dict[str, str] = {}
    for env_name, raw in (
        ("GEMINI_API_KEY", gemini_key),
        ("OPENAI_API_KEY", openai_key),
        ("ANTHROPIC_API_KEY", anthropic_key),
        ("DEEPSEEK_API_KEY", deepseek_key),
    ):
        cleaned = _clean(raw)
        if cleaned:
            user_keys[env_name] = cleaned

    if not user_keys:
        return (
            "No API keys provided. Please enter at least one API key to run experiments.",
            [],
        )

    # Parse grid sizes
    try:
        grid_sizes = [int(g.strip()) for g in grid_sizes_str.split(",") if g.strip()]
    except ValueError:
        return "Invalid grid sizes — enter comma-separated integers, e.g. 5,6,7", []

    if not tasks:
        return "Select at least one task.", []
    if not models:
        return "Select at least one model.", []

    # Map display choices back to internal IDs
    task_map = {
        "Maze Navigation": "maze_navigation",
        "Sequential Point Reuse": "point_reuse",
        "Compositional Distance Comparison": "compositional_distance",
    }
    selected_tasks = [task_map[t] for t in tasks if t in task_map]

    jobs = build_all_jobs(
        cfg=CFG,
        tasks=selected_tasks,
        models=models,
        grid_sizes=grid_sizes or None,
        input_formats=formats or None,
        prompt_strategies=strategies or None,
        config_path=CONFIG_PATH,
    )

    if not jobs:
        return "No jobs matched the selected filters.", []

    launched = 0
    skipped = 0
    skipped_models: list[str] = []
    with _monitor_lock:
        for job in jobs:
            # Only use the key the user provided — never fall back to server env
            api_key = user_keys.get(job.api_key_env, "")
            if not api_key:
                skipped += 1
                skipped_models.append(job.model)
                continue
            # Fix: throttle *between* launches only. The original slept after
            # every launch — including the last one — while still holding
            # _monitor_lock, blocking status refreshes a full extra second
            # per job for no benefit.
            if launched:
                time.sleep(1)  # avoid API rate limits on burst start
            job.output_dir.mkdir(parents=True, exist_ok=True)
            proc = submit_direct(
                cmd=job.python_cmd,
                working_dir=str(job.working_dir),
                env={job.api_key_env: api_key},
            )
            _monitor.add_direct(
                proc=proc,
                label=job.label,
                task_id=job.task_id,
                model=job.model,
                output_dir=str(job.output_dir),
            )
            launched += 1

    status_msg = f"Launched {launched} job(s)."
    if skipped:
        missing = sorted(set(skipped_models))
        status_msg += (
            f" Skipped {skipped} job(s) for {', '.join(missing)} "
            f"— no API key provided for those models."
        )

    return status_msg, _monitor.as_table()
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
def refresh_status() -> tuple[list[list[str]], str]:
    """Poll job states and return (table rows, one-line status summary)."""
    _monitor.refresh()
    counts = _monitor.summary()["counts"]
    if counts:
        summary_line = " ".join(f"{state}: {num}" for state, num in counts.items())
    else:
        summary_line = "No jobs submitted yet."
    return _monitor.as_table(), summary_line
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
# ---------------------------------------------------------------------------
|
| 361 |
+
# Gradio UI
|
| 362 |
+
# ---------------------------------------------------------------------------
|
| 363 |
+
|
| 364 |
+
# Markdown shown at the top of the Leaderboard tab (rendered by gr.Markdown).
PAPER_ABSTRACT = """
**Do LLMs Build Spatial World Models? Evidence from Grid-World Maze Tasks**

We systematically evaluate the spatial understanding of large language models through maze tasks—a
controlled testing context requiring multi-step planning and spatial abstraction. Across experiments
with Gemini-2.5-Flash, GPT-5-mini, Claude-Haiku-4.5, and DeepSeek-Chat, we uncover significant
discrepancies in spatial reasoning that challenge assumptions about LLM planning capabilities.

Key findings:
- **Representation sensitivity**: Gemini drops from 86% (raw tokenized) to 34% (visual grid) on 5×5 mazes with CoT
- **Prompting dependency**: Claude-Haiku fails completely without CoT, recovers to 78% with it
- **No spatial memory**: Models treat sequential questions independently, failing to reuse computed spatial knowledge
"""

# Custom CSS injected into gr.Blocks; class names are attached via
# elem_classes= on the leaderboard tables. `footer` hides Gradio's footer.
CSS = """
.leaderboard-table { font-size: 0.9em; }
.status-badge-running { color: #2196F3; font-weight: bold; }
.status-badge-completed { color: #4CAF50; font-weight: bold; }
.status-badge-failed { color: #F44336; font-weight: bold; }
footer { display: none !important; }
"""
|
| 385 |
+
|
| 386 |
+
def build_ui() -> gr.Blocks:
    """Construct the full Gradio app: Leaderboard, Run Experiments, and About tabs.

    Returns the (unlaunched) gr.Blocks instance; the caller invokes .launch().
    All plot/refresh callbacks are module-level functions defined above;
    the small closures here only adapt widget values to their signatures.
    """
    with gr.Blocks(
        title="SpatialBench — Do LLMs Build Spatial World Models?",
        css=CSS,
        theme=gr.themes.Soft(primary_hue="blue"),
    ) as demo:

        gr.Markdown("# 🧩 SpatialBench")
        gr.Markdown(
            "**Evaluating Spatial World Models in Large Language Models** · "
            "[Paper (ICLR 2026 Workshop)](https://arxiv.org/abs/...) · "
            "[Code](https://github.com/...)"
        )

        with gr.Tabs():

            # ================================================================
            # Tab 1: Leaderboard
            # ================================================================
            with gr.Tab("📊 Leaderboard"):
                gr.Markdown(PAPER_ABSTRACT)

                gr.Markdown("---")
                gr.Markdown("## Task 1 — Maze Navigation (Planning)")
                gr.Markdown(
                    "Models find shortest paths through mazes. "
                    "Two input formats: **raw** tokenized adjacency lists vs **visual** character grids."
                )

                with gr.Row():
                    t1_k = gr.Radio(
                        choices=[0, 3, 5], value=0, label="K-shot",
                        info="Number of in-context examples",
                    )
                    t1_fmt = gr.Radio(
                        choices=["raw", "visual"], value="raw", label="Input Format",
                    )

                t1_plot = gr.Plot(label="Accuracy by Grid Size")
                t1_lb = gr.Dataframe(
                    label="Leaderboard (mean accuracy across grid sizes)",
                    elem_classes=["leaderboard-table"],
                )
                t1_fmt_plot = gr.Plot(label="Raw vs Visual Format Comparison")

                # Adapter: Radio may deliver k as str/number — coerce to int.
                def refresh_task1(k, fmt):
                    fig, lb = plot_task1_accuracy(int(k), fmt)
                    fmt_fig = plot_task1_format_comparison()
                    return fig, lb, fmt_fig

                # Either control re-renders all three Task 1 outputs.
                for inp in [t1_k, t1_fmt]:
                    inp.change(
                        refresh_task1, inputs=[t1_k, t1_fmt],
                        outputs=[t1_plot, t1_lb, t1_fmt_plot],
                    )

                gr.Markdown("---")
                gr.Markdown("## Task 2 — Sequential Reasoning with Point Reuse")
                gr.Markdown(
                    "Models answer 4 proximity questions. **Q3 = Q0** (same question repeated). "
                    "Do models reuse their earlier computation, or start from scratch?"
                )

                t2_grid = gr.Slider(minimum=5, maximum=9, step=1, value=5,
                                    label="Grid Size")
                t2_plot = gr.Plot(label="Q0 vs Q3 Accuracy")
                t2_grid_plot = gr.Plot(label="Q3 Accuracy Across Grid Sizes")
                t2_lb = gr.Dataframe(
                    label="Leaderboard",
                    elem_classes=["leaderboard-table"],
                )

                # Adapter: Slider delivers a float — coerce to int grid size.
                def refresh_task2(gs):
                    fig, lb = plot_task2_q0_q3(int(gs))
                    grid_fig = plot_task2_by_grid()
                    return fig, grid_fig, lb

                t2_grid.change(
                    refresh_task2, inputs=[t2_grid],
                    outputs=[t2_plot, t2_grid_plot, t2_lb],
                )

                gr.Markdown("---")
                gr.Markdown("## Task 3 — Compositional Distance Comparison")
                gr.Markdown(
                    "Models answer 3 questions about maze corners (A, B, C, D) and center M. "
                    "**Q2** (B→C) can potentially be composed from Q0 (A→M) and Q1 (D→M). "
                    "Δ = Q2 accuracy − avg(Q0, Q1)."
                )

                t3_plot = gr.Plot(label="Q0 / Q1 / Q2 Accuracy by Model")
                t3_grid_plot = gr.Plot(label="Q2 Accuracy Across Grid Sizes")
                t3_lb = gr.Dataframe(
                    label="Leaderboard (Δ shows compositional benefit)",
                    elem_classes=["leaderboard-table"],
                )

                with gr.Row():
                    refresh_lb_btn = gr.Button("🔄 Refresh Results", variant="secondary")

                # Recompute every plot/table on the tab with default widget
                # values (k=0, raw, 5×5). Used by both the refresh button and
                # the initial demo.load below.
                def refresh_all_leaderboard(_=None):
                    t1_fig, t1_table = plot_task1_accuracy(0, "raw")
                    t1_ff = plot_task1_format_comparison()
                    t2_fig, t2_lb_table = plot_task2_q0_q3(5)
                    t2_gfig = plot_task2_by_grid()
                    t3_fig, t3_lb_table = plot_task3_compositional()
                    t3_gfig = plot_task3_by_grid()
                    return (
                        t1_fig, t1_table, t1_ff,
                        t2_fig, t2_gfig, t2_lb_table,
                        t3_fig, t3_gfig, t3_lb_table,
                    )

                refresh_lb_btn.click(
                    refresh_all_leaderboard,
                    outputs=[
                        t1_plot, t1_lb, t1_fmt_plot,
                        t2_plot, t2_grid_plot, t2_lb,
                        t3_plot, t3_grid_plot, t3_lb,
                    ],
                )

                # Initial load
                demo.load(
                    refresh_all_leaderboard,
                    outputs=[
                        t1_plot, t1_lb, t1_fmt_plot,
                        t2_plot, t2_grid_plot, t2_lb,
                        t3_plot, t3_grid_plot, t3_lb,
                    ],
                )

            # ================================================================
            # Tab 2: Run Experiments
            # ================================================================
            with gr.Tab("⚡ Run Experiments"):
                gr.Markdown(
                    "## Launch Experiments\n"
                    "Experiments call LLM APIs directly — no compute cluster needed.\n\n"
                    "> **Your API keys are used only for your session and are never stored or logged.** \n"
                    "> Enter keys only for the model(s) you want to evaluate. "
                    "Jobs for models without a key will be skipped."
                )

                with gr.Row():
                    with gr.Column(scale=2):
                        # Task / model / grid selection
                        run_tasks = gr.CheckboxGroup(
                            choices=[
                                "Maze Navigation",
                                "Sequential Point Reuse",
                                "Compositional Distance Comparison",
                            ],
                            value=["Maze Navigation"],
                            label="Tasks",
                        )
                        run_models = gr.CheckboxGroup(
                            choices=MODEL_CHOICES,
                            value=["gemini-2.5-flash"],
                            label="Models",
                        )
                        run_grids = gr.Textbox(
                            value="5,6,7",
                            label="Grid Sizes",
                            info="Comma-separated integers. Maze dataset supports 5–9 (and beyond if regenerated).",
                        )
                        with gr.Row():
                            run_formats = gr.CheckboxGroup(
                                choices=["raw", "visual"],
                                value=["raw"],
                                label="Input Formats (Task 1 only)",
                            )
                            run_strategies = gr.CheckboxGroup(
                                choices=["base", "cot", "reasoning"],
                                value=["cot"],
                                label="Prompt Strategies",
                            )

                    with gr.Column(scale=1):
                        gr.Markdown("### API Keys")
                        gr.Markdown(
                            "Enter the key(s) for the model(s) you selected. "
                            "Keys are used only for this session."
                        )
                        # type="password" masks input; values flow only into
                        # launch_experiments, never into os.environ.
                        gemini_key = gr.Textbox(
                            label="GEMINI_API_KEY", type="password", placeholder="AIza...",
                        )
                        openai_key = gr.Textbox(
                            label="OPENAI_API_KEY", type="password", placeholder="sk-...",
                        )
                        anthropic_key = gr.Textbox(
                            label="ANTHROPIC_API_KEY", type="password", placeholder="sk-ant-...",
                        )
                        deepseek_key = gr.Textbox(
                            label="DEEPSEEK_API_KEY", type="password",
                        )

                with gr.Row():
                    run_btn = gr.Button("🚀 Launch Experiments", variant="primary", scale=2)
                    refresh_btn = gr.Button("🔄 Refresh Status", scale=1)

                launch_msg = gr.Textbox(label="Launch Status", interactive=False)

                job_table = gr.Dataframe(
                    headers=["Task", "Model", "Label", "Status", "Elapsed", "Started"],
                    label="Job Status",
                    interactive=False,
                    wrap=True,
                )
                status_summary = gr.Textbox(
                    label="Summary", interactive=False,
                )

                # Input order must match launch_experiments' parameter order.
                run_btn.click(
                    launch_experiments,
                    inputs=[
                        run_tasks, run_models, run_grids,
                        run_formats, run_strategies,
                        gemini_key, openai_key, anthropic_key, deepseek_key,
                    ],
                    outputs=[launch_msg, job_table],
                )

                refresh_btn.click(
                    refresh_status,
                    outputs=[job_table, status_summary],
                )

            # ================================================================
            # Tab 3: About
            # ================================================================
            with gr.Tab("ℹ️ About"):
                gr.Markdown("""
## About SpatialBench

SpatialBench is the evaluation platform accompanying the paper:

> **Do LLMs Build Spatial World Models? Evidence from Grid-World Maze Tasks**
> *Under review at ICLR 2026 Workshop*

### Three Tasks

| Task | Type | What it tests |
|------|------|---------------|
| **Task 1: Maze Navigation** | Planning | Find shortest path from start to goal |
| **Task 2: Sequential Point Reuse** | Reasoning | Reuse Q0 computation when Q3=Q0 |
| **Task 3: Compositional Distance** | Reasoning | Compose corner→center distances for Q2 |

### Input Representations

- **Raw (tokenized)**: `<ADJLIST_START> (0,0) <--> (0,1) ... <ADJLIST_END>`
- **Visual (grid)**: `Row 0: ['.', 'S', '.', '#'] Row 1: ['#', '.', '.', 'E']`

### Models Evaluated

| Model | Provider |
|-------|----------|
| Gemini 2.5 Flash | Google |
| GPT-5 Mini | OpenAI |
| Claude Haiku 4.5 | Anthropic |
| DeepSeek Chat | DeepSeek |

### Grid Sizes

Experiments run on n×n grids for n ∈ {5, 6, 7, 8, 9} by default.
The underlying `maze-dataset` library supports larger grids — adjust in the **Run** tab.

### Adding a New Model

Edit `pipeline/configs/experiments.yaml`:
```yaml
models:
  your-model-id:
    api_key_env: YOUR_API_KEY_ENV_VAR
    display_name: "Your Model Name"
```
Then add inference support in `utils/llm_inference.py`.

### Citation
```bibtex
@inproceedings{spatialbench2026,
  title = {Do {LLMs} Build Spatial World Models? Evidence from Grid-World Maze Tasks},
  author = {Anonymous},
  booktitle = {ICLR 2026 Workshop},
  year = {2026},
}
```
""")

    return demo
|
| 676 |
+
|
| 677 |
+
|
| 678 |
+
# ---------------------------------------------------------------------------
|
| 679 |
+
# Entry point
|
| 680 |
+
# ---------------------------------------------------------------------------
|
| 681 |
+
|
| 682 |
+
if __name__ == "__main__":
|
| 683 |
+
demo = build_ui()
|
| 684 |
+
demo.launch(
|
| 685 |
+
server_name="0.0.0.0",
|
| 686 |
+
share=False,
|
| 687 |
+
show_error=True,
|
| 688 |
+
)
|
configs/experiments.yaml
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# =============================================================================
|
| 2 |
+
# SpatialBench Experiment Configuration
|
| 3 |
+
# =============================================================================
|
| 4 |
+
# This file is the single source of truth for all experiments.
|
| 5 |
+
# Add a new model by adding an entry under `models`.
|
| 6 |
+
# Add a new grid size by extending `grid_sizes` in any task.
|
| 7 |
+
# All paths are relative to llm-maze-solver/ (the repo root).
|
| 8 |
+
# =============================================================================
|
| 9 |
+
|
| 10 |
+
# ---------------------------------------------------------------------------
|
| 11 |
+
# Global defaults — overridden per-task or per-experiment as needed
|
| 12 |
+
# ---------------------------------------------------------------------------
|
| 13 |
+
defaults:
|
| 14 |
+
n_test_mazes: 50
|
| 15 |
+
seed: 42
|
| 16 |
+
temperature: 0.1
|
| 17 |
+
max_tokens: 8192
|
| 18 |
+
sbatch:
|
| 19 |
+
cpus: 2
|
| 20 |
+
mem: "8G"
|
| 21 |
+
time: "10:00:00"
|
| 22 |
+
partition: "short"
|
| 23 |
+
log_dir: "maze-solver/eval_llm_logs"
|
| 24 |
+
|
| 25 |
+
# ---------------------------------------------------------------------------
|
| 26 |
+
# Models
|
| 27 |
+
# Each entry defines the model identifier used in API calls and the
|
| 28 |
+
# environment variable that must hold the API key.
|
| 29 |
+
# ---------------------------------------------------------------------------
|
| 30 |
+
models:
|
| 31 |
+
gemini-2.5-flash:
|
| 32 |
+
api_key_env: GEMINI_API_KEY
|
| 33 |
+
display_name: "Gemini 2.5 Flash"
|
| 34 |
+
gpt-5-mini:
|
| 35 |
+
api_key_env: OPENAI_API_KEY
|
| 36 |
+
display_name: "GPT-5 Mini"
|
| 37 |
+
claude-haiku-4-5:
|
| 38 |
+
api_key_env: ANTHROPIC_API_KEY
|
| 39 |
+
display_name: "Claude Haiku 4.5"
|
| 40 |
+
deepseek-chat:
|
| 41 |
+
api_key_env: DEEPSEEK_API_KEY
|
| 42 |
+
display_name: "DeepSeek Chat"
|
| 43 |
+
|
| 44 |
+
# ---------------------------------------------------------------------------
|
| 45 |
+
# Maze Navigation (Planning)
|
| 46 |
+
# Paper: Table 1, Table 5 (3-shot), Table 6 (5-shot)
|
| 47 |
+
# Script: maze-solver/eval_llm_maze_solver.py
|
| 48 |
+
# ---------------------------------------------------------------------------
|
| 49 |
+
maze_navigation:
|
| 50 |
+
description: >
|
| 51 |
+
Models find shortest paths through mazes represented in two formats
|
| 52 |
+
(raw tokenized adjacency lists vs visual character grids), tested
|
| 53 |
+
across k-shot settings and three prompting strategies.
|
| 54 |
+
script: "maze-solver/eval_llm_maze_solver.py"
|
| 55 |
+
working_dir: "maze-solver"
|
| 56 |
+
output_base: "maze-solver/llm-maze-evaluation-results"
|
| 57 |
+
|
| 58 |
+
# Grid sizes: paper used 5-9; extend freely up to maze-dataset limits
|
| 59 |
+
grid_sizes: [5, 6, 7, 8, 9]
|
| 60 |
+
|
| 61 |
+
# Input representations
|
| 62 |
+
input_formats: ["raw", "visual"]
|
| 63 |
+
|
| 64 |
+
# Prompting strategies (maps to script flags)
|
| 65 |
+
prompt_strategies:
|
| 66 |
+
base:
|
| 67 |
+
flags: []
|
| 68 |
+
display_name: "Base"
|
| 69 |
+
cot:
|
| 70 |
+
flags: ["--chain_of_thought"]
|
| 71 |
+
display_name: "Chain-of-Thought"
|
| 72 |
+
reasoning:
|
| 73 |
+
flags: ["--reasoning"]
|
| 74 |
+
display_name: "Post-hoc Reasoning"
|
| 75 |
+
|
| 76 |
+
# K-shot values tested simultaneously in one script run
|
| 77 |
+
k_shots: "0,3,5"
|
| 78 |
+
|
| 79 |
+
# Fixed params
|
| 80 |
+
maze_type: "cycles"
|
| 81 |
+
percolation_p: 0.2
|
| 82 |
+
visualize: true
|
| 83 |
+
|
| 84 |
+
sbatch:
|
| 85 |
+
time: "10:00:00"
|
| 86 |
+
|
| 87 |
+
# ---------------------------------------------------------------------------
|
| 88 |
+
# Sequential Reasoning with Point Reuse (Q3 = Q0)
|
| 89 |
+
# Paper: Table 2, Table 7
|
| 90 |
+
# Script: maze-solver/spatial_reasoning/eval_proximity_comparison.py
|
| 91 |
+
# ---------------------------------------------------------------------------
|
| 92 |
+
point_reuse:
|
| 93 |
+
description: >
|
| 94 |
+
Models answer four sequential proximity questions about the same maze.
|
| 95 |
+
Q3 is identical to Q0, probing whether models reuse previously
|
| 96 |
+
computed spatial information or treat each question independently.
|
| 97 |
+
script: "maze-solver/spatial_reasoning/eval_proximity_comparison.py"
|
| 98 |
+
working_dir: "spatial_reasoning/spatial_reasoning_experiments"
|
| 99 |
+
|
| 100 |
+
# Paper used 5-9; extend freely
|
| 101 |
+
grid_sizes: [5, 6, 7, 8, 9]
|
| 102 |
+
|
| 103 |
+
input_format: "raw"
|
| 104 |
+
strategy: "point_reuse"
|
| 105 |
+
reuse_pattern: "last_first_same"
|
| 106 |
+
n_questions_per_maze: 4
|
| 107 |
+
sequential_questions: true
|
| 108 |
+
|
| 109 |
+
# Prompting strategies
|
| 110 |
+
prompt_strategies:
|
| 111 |
+
base:
|
| 112 |
+
prompt_type: "baseline"
|
| 113 |
+
display_name: "Base"
|
| 114 |
+
cot:
|
| 115 |
+
prompt_type: "cot"
|
| 116 |
+
display_name: "Chain-of-Thought"
|
| 117 |
+
reasoning:
|
| 118 |
+
prompt_type: "reasoning"
|
| 119 |
+
display_name: "Post-hoc Reasoning"
|
| 120 |
+
|
| 121 |
+
output_base: "spatial_reasoning/spatial-reasoning-results-point-reuse-q3-q0"
|
| 122 |
+
visualize: true
|
| 123 |
+
save_details: true
|
| 124 |
+
|
| 125 |
+
sbatch:
|
| 126 |
+
time: "10:30:00"
|
| 127 |
+
|
| 128 |
+
# ---------------------------------------------------------------------------
|
| 129 |
+
# Compositional Distance Comparison
|
| 130 |
+
# Paper: Table 3, Table 8, Table 9
|
| 131 |
+
# Script: maze-solver/spatial_reasoning/eval_extended_experiments.py
|
| 132 |
+
# Corner pattern: corners_to_center (Q0: top-left→center,
|
| 133 |
+
# Q1: bottom-right→center,
|
| 134 |
+
# Q2: corner→corner compositional)
|
| 135 |
+
# ---------------------------------------------------------------------------
|
| 136 |
+
compositional_distance:
|
| 137 |
+
description: >
|
| 138 |
+
Models answer three questions about maze corners (A=top-left,
|
| 139 |
+
B=top-right, C=bottom-left, D=bottom-right) and center M.
|
| 140 |
+
Q2 can be composed from information established in Q0 and Q1,
|
| 141 |
+
probing whether models build cumulative spatial knowledge.
|
| 142 |
+
script: "maze-solver/spatial_reasoning/eval_extended_experiments.py"
|
| 143 |
+
working_dir: "spatial_reasoning/spatial_reasoning_experiments"
|
| 144 |
+
|
| 145 |
+
# Paper reported 5-9 in Tables 8/9; scripts originally only ran 5-7
|
| 146 |
+
# Extended to match paper
|
| 147 |
+
grid_sizes: [5, 6, 7, 8, 9]
|
| 148 |
+
|
| 149 |
+
input_format: "raw"
|
| 150 |
+
strategy: "orthogonal"
|
| 151 |
+
corner_pattern: "corners_to_center" # matches paper Q0/Q1/Q2 design
|
| 152 |
+
n_questions_per_maze: 3
|
| 153 |
+
|
| 154 |
+
# Prompting strategies
|
| 155 |
+
prompt_strategies:
|
| 156 |
+
base:
|
| 157 |
+
prompt_type: "baseline"
|
| 158 |
+
display_name: "Base"
|
| 159 |
+
cot:
|
| 160 |
+
prompt_type: "cot"
|
| 161 |
+
display_name: "Chain-of-Thought"
|
| 162 |
+
reasoning:
|
| 163 |
+
prompt_type: "reasoning"
|
| 164 |
+
display_name: "Post-hoc Reasoning"
|
| 165 |
+
|
| 166 |
+
output_base: "spatial_reasoning/spatial-reasoning-results-orthogonal"
|
| 167 |
+
visualize: true
|
| 168 |
+
save_details: true
|
| 169 |
+
|
| 170 |
+
sbatch:
|
| 171 |
+
time: "06:30:00"
|
pipeline/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"SpatialBench pipeline modules."
|
pipeline/job_monitor.py
ADDED
|
@@ -0,0 +1,314 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
job_monitor.py
|
| 3 |
+
--------------
|
| 4 |
+
Tracks and displays the status of experiment jobs.
|
| 5 |
+
|
| 6 |
+
Supports two backends:
|
| 7 |
+
- SLURM : polls `squeue` for cluster jobs (used when running locally)
|
| 8 |
+
- Direct : tracks subprocess-launched jobs (used when running via API keys
|
| 9 |
+
on HF Space or without a cluster)
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import subprocess
|
| 15 |
+
import time
|
| 16 |
+
import threading
|
| 17 |
+
from dataclasses import dataclass, field
|
| 18 |
+
from datetime import datetime
|
| 19 |
+
from enum import Enum
|
| 20 |
+
from typing import Callable
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# ---------------------------------------------------------------------------
|
| 24 |
+
# Status model
|
| 25 |
+
# ---------------------------------------------------------------------------
|
| 26 |
+
|
| 27 |
+
class JobStatus(str, Enum):
    """Lifecycle states a tracked job can be in."""
    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
    UNKNOWN = "unknown"


@dataclass
class JobRecord:
    """One tracked experiment job: a SLURM job or a local subprocess."""
    job_id: str                  # SLURM job-id or process PID (as str)
    label: str                   # human-readable experiment label
    task_id: str
    model: str
    status: JobStatus = JobStatus.PENDING
    submitted_at: datetime = field(default_factory=datetime.now)
    finished_at: datetime | None = None
    output_dir: str = ""
    log_out: str = ""
    log_err: str = ""

    def elapsed(self) -> str:
        """Wall-clock time from submission to finish (or now) as HH:MM:SS."""
        stop = self.finished_at if self.finished_at is not None else datetime.now()
        total = int((stop - self.submitted_at).total_seconds())
        hours = total // 3600
        minutes = (total % 3600) // 60
        seconds = total % 60
        return f"{hours:02d}:{minutes:02d}:{seconds:02d}"

    def as_dict(self) -> dict:
        """Plain-dict view of this record for UI / JSON display."""
        return {
            "job_id": self.job_id,
            "label": self.label,
            "task_id": self.task_id,
            "model": self.model,
            "status": self.status.value,
            "submitted_at": self.submitted_at.strftime("%Y-%m-%d %H:%M:%S"),
            "elapsed": self.elapsed(),
            "output_dir": self.output_dir,
        }
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# ---------------------------------------------------------------------------
|
| 70 |
+
# SLURM monitor
|
| 71 |
+
# ---------------------------------------------------------------------------
|
| 72 |
+
|
| 73 |
+
# Map squeue state codes → JobStatus
# (codes per SLURM's squeue %t short-form output)
_SLURM_STATE_MAP = {
    "PD": JobStatus.PENDING,     # pending in queue
    "R": JobStatus.RUNNING,      # running
    "CG": JobStatus.RUNNING,     # completing — still winding down, treat as running
    "CD": JobStatus.COMPLETED,   # completed
    "F": JobStatus.FAILED,       # failed
    "CA": JobStatus.CANCELLED,   # cancelled
    "TO": JobStatus.FAILED,      # timeout — treated as failure here
    "OOM": JobStatus.FAILED,     # out-of-memory — treated as failure here
}
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def _query_slurm(job_ids: list[str]) -> dict[str, JobStatus]:
    """Return {job_id: JobStatus} for a batch of SLURM job ids.

    Best-effort: any squeue failure (missing binary, timeout, bad exit)
    yields an empty mapping rather than raising.
    """
    if not job_ids:
        return {}
    cmd = ["squeue", "--jobs", ",".join(job_ids), "--format=%i %t", "--noheader"]
    try:
        output = subprocess.run(
            cmd, capture_output=True, text=True, timeout=15,
        ).stdout
        statuses: dict[str, JobStatus] = {}
        for line in output.strip().splitlines():
            fields = line.split()
            if len(fields) < 2:
                continue
            statuses[fields[0]] = _SLURM_STATE_MAP.get(fields[1], JobStatus.UNKNOWN)
        return statuses
    except Exception:
        return {}
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def submit_sbatch(script_text: str, script_path: str) -> str | None:
    """Write script_text to script_path, submit via sbatch, return job_id.

    Returns None when submission fails or when no numeric job id can be
    parsed out of sbatch's output ("Submitted batch job 12345").
    """
    with open(script_path, "w") as f:
        f.write(script_text)
    try:
        proc = subprocess.run(
            ["sbatch", script_path],
            capture_output=True, text=True, timeout=30,
        )
        # Pick the first purely-numeric token from sbatch's stdout.
        digits = [tok for tok in proc.stdout.split() if tok.isdigit()]
        if digits:
            return digits[0]
    except Exception:
        pass
    return None
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
# ---------------------------------------------------------------------------
|
| 125 |
+
# Direct (subprocess) monitor
|
| 126 |
+
# ---------------------------------------------------------------------------
|
| 127 |
+
|
| 128 |
+
def submit_direct(
    cmd: list[str],
    working_dir: str,
    env: dict | None = None,
    on_finish: Callable[[int], None] | None = None,
) -> subprocess.Popen:
    """Launch a job as a subprocess and optionally call on_finish(returncode)."""
    import os
    merged_env = os.environ.copy()
    merged_env.update(env or {})

    proc = subprocess.Popen(
        cmd,
        cwd=working_dir,
        env=merged_env,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
    )

    if on_finish is not None:
        # Notify from a daemon thread once the process exits.
        def _notify():
            proc.wait()
            on_finish(proc.returncode)

        threading.Thread(target=_notify, daemon=True).start()

    return proc
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
# ---------------------------------------------------------------------------
|
| 159 |
+
# JobMonitor — unified tracker
|
| 160 |
+
# ---------------------------------------------------------------------------
|
| 161 |
+
|
| 162 |
+
class JobMonitor:
    """
    Tracks a collection of JobRecords.

    Modes:
      - "slurm"  : refresh() polls squeue to update statuses.
      - "direct" : watcher threads started in add_direct() update
                   statuses; refresh() is a no-op.

    Usage (SLURM):
        monitor = JobMonitor(mode="slurm")
        record = monitor.add(job_id="12345", label="Task1|gemini|raw|cot", ...)
        monitor.refresh()   # updates statuses from squeue
        monitor.wait_all()  # blocks until all done

    Usage (direct):
        monitor = JobMonitor(mode="direct")
        proc = submit_direct(cmd, wdir, env)
        record = monitor.add_direct(proc, label=..., ...)
        monitor.wait_all()
    """

    def __init__(self, mode: str = "slurm"):
        assert mode in ("slurm", "direct")
        self.mode = mode
        self._records: dict[str, JobRecord] = {}      # job_id → JobRecord
        self._procs: dict[str, subprocess.Popen] = {} # direct-mode processes
        self._lock = threading.Lock()                 # guards the two dicts and status writes

    # -- adding jobs --------------------------------------------------------

    def add(
        self,
        job_id: str,
        label: str,
        task_id: str,
        model: str,
        output_dir: str = "",
        log_out: str = "",
        log_err: str = "",
    ) -> JobRecord:
        """Register a job (e.g. a freshly submitted SLURM job) and return its record."""
        record = JobRecord(
            job_id=job_id, label=label,
            task_id=task_id, model=model,
            output_dir=output_dir, log_out=log_out, log_err=log_err,
        )
        with self._lock:
            self._records[job_id] = record
        return record

    def add_direct(
        self,
        proc: subprocess.Popen,
        label: str,
        task_id: str,
        model: str,
        output_dir: str = "",
    ) -> JobRecord:
        """Register an already-launched subprocess.

        A daemon thread waits on the process and marks the record
        COMPLETED (exit code 0) or FAILED (anything else).
        """
        job_id = str(proc.pid)
        record = self.add(
            job_id=job_id, label=label,
            task_id=task_id, model=model, output_dir=output_dir,
        )
        record.status = JobStatus.RUNNING
        with self._lock:
            self._procs[job_id] = proc

        def _monitor():
            proc.wait()
            with self._lock:
                record.status = (
                    JobStatus.COMPLETED if proc.returncode == 0
                    else JobStatus.FAILED
                )
                record.finished_at = datetime.now()

        threading.Thread(target=_monitor, daemon=True).start()
        return record

    # -- status refreshing --------------------------------------------------

    def refresh(self) -> None:
        """Update statuses from SLURM (no-op for direct mode)."""
        if self.mode != "slurm":
            return
        with self._lock:
            active_ids = [
                jid for jid, r in self._records.items()
                if r.status in (JobStatus.PENDING, JobStatus.RUNNING)
            ]
        if not active_ids:
            return
        statuses = _query_slurm(active_ids)
        with self._lock:
            for jid, status in statuses.items():
                if jid in self._records:
                    old = self._records[jid].status
                    self._records[jid].status = status
                    if old != status and status in (
                        JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED
                    ):
                        self._records[jid].finished_at = datetime.now()
            # Any active job no longer appearing in squeue has left the queue
            # and is therefore finished.
            # BUGFIX: previously only RUNNING jobs were handled here, so a
            # PENDING job that started and finished between two polls stayed
            # PENDING forever and wait_all() never returned.
            # NOTE(review): _query_slurm also returns {} when squeue itself
            # fails, in which case vanished jobs are optimistically marked
            # COMPLETED — pre-existing behavior for RUNNING jobs, kept as-is.
            for jid in active_ids:
                if jid not in statuses and jid in self._records:
                    r = self._records[jid]
                    if r.status in (JobStatus.PENDING, JobStatus.RUNNING):
                        r.status = JobStatus.COMPLETED
                        r.finished_at = datetime.now()

    def wait_all(self, poll_interval: int = 30, callback: Callable | None = None) -> None:
        """Block until all jobs are in a terminal state.

        callback, if given, is invoked with summary() after every poll.
        """
        while True:
            self.refresh()
            active = self.active_jobs()
            if callback:
                callback(self.summary())
            if not active:
                break
            time.sleep(poll_interval)

    # -- queries ------------------------------------------------------------

    def active_jobs(self) -> list[JobRecord]:
        """Records that are still PENDING or RUNNING."""
        with self._lock:
            return [
                r for r in self._records.values()
                if r.status in (JobStatus.PENDING, JobStatus.RUNNING)
            ]

    def all_records(self) -> list[JobRecord]:
        """Snapshot list of every tracked record."""
        with self._lock:
            return list(self._records.values())

    def summary(self) -> dict:
        """Aggregate view: total count, per-status counts, per-record dicts."""
        records = self.all_records()
        counts: dict[str, int] = {}
        for r in records:
            counts[r.status.value] = counts.get(r.status.value, 0) + 1
        return {
            "total": len(records),
            "counts": counts,
            "records": [r.as_dict() for r in records],
        }

    def as_table(self) -> list[list[str]]:
        """Return rows suitable for a Gradio Dataframe component."""
        records = self.all_records()
        rows = []
        for r in records:
            rows.append([
                r.task_id, r.model, r.label,
                r.status.value, r.elapsed(),
                r.submitted_at.strftime("%H:%M:%S"),
            ])
        return rows
|
| 313 |
+
|
| 314 |
+
# Column headers matching the row layout produced by JobMonitor.as_table().
TABLE_HEADERS = ["Task", "Model", "Label", "Status", "Elapsed", "Started"]
|
pipeline/results_loader.py
ADDED
|
@@ -0,0 +1,402 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
results_loader.py
|
| 3 |
+
-----------------
|
| 4 |
+
Scans experiment output directories and assembles results into pandas
|
| 5 |
+
DataFrames ready for display in the Gradio leaderboard.
|
| 6 |
+
|
| 7 |
+
Output directory conventions (from experiments.yaml):
|
| 8 |
+
Task 1: <output_base>/<model>/<fmt>_input_<strat>/
|
| 9 |
+
→ results_{grid_size}x{grid_size}_k{k}.csv OR summary.json
|
| 10 |
+
Task 2: <output_base>/<model>/point_reuse_q3q0_<strat>/
|
| 11 |
+
→ proximity_comparison_results.csv
|
| 12 |
+
Task 3: <output_base>/<model>/orthogonal_corners_to_center_<strat>/
|
| 13 |
+
→ results.csv OR summary_stats.json
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
from __future__ import annotations

import json
import os
import re
from pathlib import Path

import pandas as pd
import yaml
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# ---------------------------------------------------------------------------
|
| 27 |
+
# Config helpers
|
| 28 |
+
# ---------------------------------------------------------------------------
|
| 29 |
+
|
| 30 |
+
def load_config(config_path: str | Path) -> dict:
    """Load the experiments YAML config into a dict.

    Opens with an explicit UTF-8 encoding so parsing does not depend on
    the platform's default locale encoding.
    """
    with open(config_path, encoding="utf-8") as f:
        return yaml.safe_load(f)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _model_display(cfg: dict, model_id: str) -> str:
    """Human-readable name for model_id, falling back to the id itself."""
    entry = cfg["models"].get(model_id, {})
    return entry.get("display_name", model_id)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# ---------------------------------------------------------------------------
|
| 40 |
+
# Task 1 results
|
| 41 |
+
# ---------------------------------------------------------------------------
|
| 42 |
+
|
| 43 |
+
def load_maze_navigation_results(cfg: dict, repo_root: Path) -> pd.DataFrame:
    """
    Scan Task 1 output dirs and return a DataFrame with columns:
        model, display_name, input_format, prompt_strategy, grid_size, k_shot, accuracy
    """
    task_cfg = cfg["maze_navigation"]
    results_root = repo_root / task_cfg["output_base"]
    collected: list[dict] = []

    for model_id, meta in cfg["models"].items():
        display = meta["display_name"]
        # Output dirs use a sanitized model id (dots/dashes → underscores).
        sanitized = model_id.replace(".", "_").replace("-", "_")
        for fmt in task_cfg["input_formats"]:
            for strat in task_cfg["prompt_strategies"]:
                run_dir = results_root / sanitized / f"{fmt}_input_{strat}"
                if not run_dir.exists():
                    continue
                # Prefer the aggregated summary JSON when present.
                summary = run_dir / "summary.json"
                if summary.exists():
                    _parse_task1_summary(summary, collected, model_id, display, fmt, strat)
                    continue
                # Otherwise fall back to per-grid CSV files.
                for csv_path in sorted(run_dir.glob("*.csv")):
                    _parse_task1_csv(csv_path, collected, model_id, display, fmt, strat)

    if collected:
        return pd.DataFrame(collected)
    return pd.DataFrame(columns=[
        "model", "display_name", "input_format", "prompt_strategy",
        "grid_size", "k_shot", "accuracy"
    ])
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def _parse_task1_summary(path: Path, rows: list, model_id, display, fmt, strat):
    """Parse one Task-1 summary.json and append accuracy rows.

    Expected layouts:
        {grid_key: {k_shot: accuracy, ...}, ...}   # per-k dict
        {grid_key: accuracy, ...}                  # flat value (k_shot = 0)

    grid_key may look like "5", "5x5", or "grid_5"; the first integer
    found in the key is taken as the grid size.
    Best-effort: any failure leaves rows as-is for this file.
    """
    try:
        with open(path) as f:
            data = json.load(f)
        for grid_key, k_dict in data.items():
            # BUGFIX: the old parse stripped "x" before splitting on "_",
            # turning a key like "5x5" into int("55") == 55. Extract the
            # first integer in the key instead.
            match = re.search(r"\d+", str(grid_key))
            if match is None:
                continue
            grid_size = int(match.group())
            if isinstance(k_dict, dict):
                for k, acc in k_dict.items():
                    rows.append({
                        "model": model_id, "display_name": display,
                        "input_format": fmt, "prompt_strategy": strat,
                        "grid_size": grid_size, "k_shot": int(k),
                        "accuracy": float(acc),
                    })
            elif isinstance(k_dict, (int, float)):
                rows.append({
                    "model": model_id, "display_name": display,
                    "input_format": fmt, "prompt_strategy": strat,
                    "grid_size": grid_size, "k_shot": 0,
                    "accuracy": float(k_dict),
                })
    except Exception:
        pass
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def _parse_task1_csv(path: Path, rows: list, model_id, display, fmt, strat):
    """Extract per-grid accuracies from one Task-1 results CSV into rows.

    grid_size / k_shot hints are recovered from the filename stem
    (e.g. "results_5x5_k3") unless the CSV has a grid_size column.
    Best-effort: any failure leaves rows as-is for this file.
    """
    try:
        df = pd.read_csv(path)

        # Recover grid_size / k_shot hints from the filename stem.
        file_grid = None
        file_k = 0
        for token in path.stem.split("_"):
            if token.startswith("k") and token[1:].isdigit():
                file_k = int(token[1:])
            if "x" in token:
                try:
                    file_grid = int(token.split("x")[0])
                except ValueError:
                    pass

        base = {
            "model": model_id, "display_name": display,
            "input_format": fmt, "prompt_strategy": strat,
        }
        if "grid_size" in df.columns:
            # Column wins over filename: one row per grid size present.
            for gs, group in df.groupby("grid_size"):
                rows.append({
                    **base,
                    "grid_size": int(gs), "k_shot": file_k,
                    "accuracy": _df_accuracy(group),
                })
        elif file_grid is not None:
            rows.append({
                **base,
                "grid_size": file_grid, "k_shot": file_k,
                "accuracy": _df_accuracy(df),
            })
    except Exception:
        pass
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def _df_accuracy(df: pd.DataFrame) -> float:
    """Mean of the first recognized correctness column, NaN if none exist."""
    candidates = ("is_correct", "exact_match", "correct", "accuracy")
    found = next((c for c in candidates if c in df.columns), None)
    if found is None:
        return float("nan")
    return float(df[found].mean())
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
# ---------------------------------------------------------------------------
|
| 150 |
+
# Task 2 results
|
| 151 |
+
# ---------------------------------------------------------------------------
|
| 152 |
+
|
| 153 |
+
def load_point_reuse_results(cfg: dict, repo_root: Path) -> pd.DataFrame:
    """
    Collect Task-2 (point reuse) results into a long-format DataFrame.

    Columns: model, display_name, prompt_strategy, grid_size,
    question_idx, accuracy. question_idx is -1 when a CSV carries no
    per-question column. Unreadable files are skipped silently.
    """
    task_cfg = cfg["point_reuse"]
    results_base = repo_root / task_cfg["output_base"]
    records: list[dict] = []

    for model_id, meta in cfg["models"].items():
        display = meta["display_name"]
        for strat, strat_cfg in task_cfg["prompt_strategies"].items():
            run_dir = (
                results_base
                / model_id.replace(".", "_").replace("-", "_")
                / f"point_reuse_q3q0_{strat}"
            )
            if not run_dir.exists():
                # Also try the pattern used by existing scripts
                run_dir = results_base / model_id / f"proximity_comparison_point_reuse_last_first_same_{strat_cfg['prompt_type']}"
                if not run_dir.exists():
                    continue

            for csv_path in list(run_dir.glob("*.csv")):
                try:
                    frame = pd.read_csv(csv_path)
                    if "grid_size" not in frame.columns:
                        continue
                    q_col = next(
                        (c for c in ("question_idx", "question_index", "q_idx") if c in frame.columns),
                        None,
                    )
                    for gs, gframe in frame.groupby("grid_size"):
                        common = {
                            "model": model_id, "display_name": display,
                            "prompt_strategy": strat,
                            "grid_size": int(gs),
                        }
                        if q_col:
                            for qi, qframe in gframe.groupby(q_col):
                                records.append({
                                    **common,
                                    "question_idx": int(qi),
                                    "accuracy": _df_accuracy(qframe),
                                })
                        else:
                            # No per-question breakdown in this file.
                            records.append({
                                **common,
                                "question_idx": -1,
                                "accuracy": _df_accuracy(gframe),
                            })
                except Exception:
                    pass

    if not records:
        return pd.DataFrame(columns=[
            "model", "display_name", "prompt_strategy",
            "grid_size", "question_idx", "accuracy"
        ])
    return pd.DataFrame(records)
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
# ---------------------------------------------------------------------------
|
| 216 |
+
# Task 3 results
|
| 217 |
+
# ---------------------------------------------------------------------------
|
| 218 |
+
|
| 219 |
+
def load_compositional_distance_results(cfg: dict, repo_root: Path) -> pd.DataFrame:
    """
    Collect Task-3 (compositional distance) results.

    Columns: model, display_name, prompt_strategy, grid_size,
    question_idx, accuracy, delta. "delta" is filled only on Q2 rows,
    as Q2 accuracy minus the mean of Q0/Q1 accuracy.

    Prefers summary_stats.json; falls back to per-run CSV files.
    """
    task = cfg["compositional_distance"]
    base = repo_root / task["output_base"]
    rows: list[dict] = []

    for model_id, model_meta in cfg["models"].items():
        display = model_meta["display_name"]
        for strat, strat_cfg in task["prompt_strategies"].items():
            tag = f"orthogonal_{task['corner_pattern']}_{strat}"
            subdir = (
                base
                / model_id.replace(".", "_").replace("-", "_")
                / tag
            )
            if not subdir.exists():
                continue

            # Prefer summary_stats.json when present and parseable.
            stats_file = subdir / "summary_stats.json"
            if stats_file.exists():
                try:
                    with open(stats_file) as f:
                        data = json.load(f)
                    _parse_task3_stats(data, rows, model_id, display, strat)
                    continue
                except Exception:
                    pass

            # Fall back to results CSVs.
            for csv_file in sorted(subdir.glob("*.csv")):
                try:
                    df = pd.read_csv(csv_file)
                    if "grid_size" not in df.columns:
                        continue
                    q_col = next(
                        (c for c in ("question_idx", "question_index") if c in df.columns),
                        None,
                    )
                    for gs, gdf in df.groupby("grid_size"):
                        if not q_col:
                            continue
                        q_accs: dict[int, float] = {}
                        for qi, qdf in gdf.groupby(q_col):
                            acc = _df_accuracy(qdf)
                            q_accs[int(qi)] = acc
                            # BUG FIX: append one row per question. Previously
                            # the append sat after this loop, so only the last
                            # question of each grid size was recorded
                            # (inconsistent with the Task-2 loader and
                            # _parse_task3_stats).
                            rows.append({
                                "model": model_id, "display_name": display,
                                "prompt_strategy": strat,
                                "grid_size": int(gs),
                                "question_idx": int(qi),
                                "accuracy": acc,
                                "delta": float("nan"),
                            })
                        # Compute delta for Q2 vs avg(Q0, Q1)
                        if 0 in q_accs and 1 in q_accs and 2 in q_accs:
                            delta = q_accs[2] - (q_accs[0] + q_accs[1]) / 2
                            for r in rows:
                                if (r["model"] == model_id and
                                        r["prompt_strategy"] == strat and
                                        r["grid_size"] == int(gs) and
                                        r["question_idx"] == 2):
                                    r["delta"] = round(delta, 4)
                except Exception:
                    pass

    if not rows:
        return pd.DataFrame(columns=[
            "model", "display_name", "prompt_strategy",
            "grid_size", "question_idx", "accuracy", "delta"
        ])
    return pd.DataFrame(rows)
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def _parse_task3_stats(data: dict, rows: list, model_id, display, strat):
|
| 296 |
+
"""Parse summary_stats.json for task3."""
|
| 297 |
+
try:
|
| 298 |
+
by_q = data.get("accuracy_by_question", data.get("per_question", {}))
|
| 299 |
+
by_gs = data.get("accuracy_by_grid_size", {})
|
| 300 |
+
for gs_key, gs_data in by_gs.items():
|
| 301 |
+
try:
|
| 302 |
+
gs = int(str(gs_key).replace("x", "").split("_")[0])
|
| 303 |
+
except ValueError:
|
| 304 |
+
continue
|
| 305 |
+
if isinstance(gs_data, dict):
|
| 306 |
+
q_accs = {}
|
| 307 |
+
for qi_key, acc in gs_data.items():
|
| 308 |
+
try:
|
| 309 |
+
qi = int(qi_key)
|
| 310 |
+
q_accs[qi] = float(acc)
|
| 311 |
+
rows.append({
|
| 312 |
+
"model": model_id, "display_name": display,
|
| 313 |
+
"prompt_strategy": strat,
|
| 314 |
+
"grid_size": gs, "question_idx": qi,
|
| 315 |
+
"accuracy": float(acc), "delta": float("nan"),
|
| 316 |
+
})
|
| 317 |
+
except (ValueError, TypeError):
|
| 318 |
+
pass
|
| 319 |
+
if 0 in q_accs and 1 in q_accs and 2 in q_accs:
|
| 320 |
+
delta = q_accs[2] - (q_accs[0] + q_accs[1]) / 2
|
| 321 |
+
for r in rows:
|
| 322 |
+
if (r["model"] == model_id and r["prompt_strategy"] == strat
|
| 323 |
+
and r["grid_size"] == gs and r["question_idx"] == 2):
|
| 324 |
+
r["delta"] = round(delta, 4)
|
| 325 |
+
except Exception:
|
| 326 |
+
pass
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
# ---------------------------------------------------------------------------
|
| 330 |
+
# Leaderboard aggregators
|
| 331 |
+
# ---------------------------------------------------------------------------
|
| 332 |
+
|
| 333 |
+
def maze_navigation_leaderboard(df: pd.DataFrame, k_shot: int = 0) -> pd.DataFrame:
    """
    Build the Task-1 leaderboard at one k-shot level.

    Rows are models; each column is an "<input_format>_<prompt_strategy>"
    combination holding mean accuracy, rounded to 3 decimals. Returns an
    empty frame when there is no data at the requested k_shot.
    """
    if df.empty:
        return pd.DataFrame()
    at_k = df[df["k_shot"] == k_shot]
    if at_k.empty:
        return pd.DataFrame()
    table = at_k.pivot_table(
        index=["display_name"],
        columns=["input_format", "prompt_strategy"],
        values="accuracy",
        aggfunc="mean",
    )
    # Flatten the (format, strategy) MultiIndex into single labels.
    table.columns = [f"{f}_{s}" for f, s in table.columns]
    table = table.reset_index().rename(columns={"display_name": "Model"})
    return table.round(3)
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
def point_reuse_leaderboard(df: pd.DataFrame) -> pd.DataFrame:
    """
    Task-2 leaderboard: mean Q0 and Q3 accuracy per model, plus the
    Q3−Q0 difference (Q3 repeats Q0, so this measures consistency).
    """
    if df.empty:
        return pd.DataFrame()
    q0_acc = (
        df.loc[df["question_idx"] == 0]
        .groupby("display_name")["accuracy"].mean()
        .rename("Q0 acc")
    )
    q3_acc = (
        df.loc[df["question_idx"] == 3]
        .groupby("display_name")["accuracy"].mean()
        .rename("Q3 acc")
    )
    board = pd.concat([q0_acc, q3_acc], axis=1).reset_index()
    board = board.rename(columns={"display_name": "Model"})
    board["Q3-Q0 diff"] = (board["Q3 acc"] - board["Q0 acc"]).round(3)
    return board.round(3)
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
def compositional_distance_leaderboard(df: pd.DataFrame) -> pd.DataFrame:
    """
    Task-3 leaderboard: per-model mean Q0/Q1/Q2 accuracy plus the
    compositionality gap Δ = Q2 − avg(Q0, Q1); Δ is NaN when any
    component is missing.
    """
    if df.empty:
        return pd.DataFrame()

    records = []
    for model_name, per_model in df.groupby("display_name"):
        def _mean_acc(idx):
            # Mean accuracy of one question index (NaN when absent).
            return per_model.loc[per_model["question_idx"] == idx, "accuracy"].mean()

        q0, q1, q2 = _mean_acc(0), _mean_acc(1), _mean_acc(2)
        if pd.isna(q0) or pd.isna(q1) or pd.isna(q2):
            gap = float("nan")
        else:
            gap = q2 - (q0 + q1) / 2
        records.append({
            "Model": model_name,
            "Q0 (A→M)": round(q0, 3),
            "Q1 (D→M)": round(q1, 3),
            "Q2 (B→C)": round(q2, 3),
            "Δ Q2 vs avg(Q0,Q1)": round(gap, 3),
        })
    return pd.DataFrame(records)
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
# ---------------------------------------------------------------------------
|
| 391 |
+
# Full results loader (called by app.py)
|
| 392 |
+
# ---------------------------------------------------------------------------
|
| 393 |
+
|
| 394 |
+
def load_all_results(config_path: str | Path) -> dict[str, pd.DataFrame]:
    """Load results for every task, keyed by task id."""
    cfg = load_config(config_path)
    # pipeline/configs/.. → llm-maze-solver/
    repo_root = Path(config_path).parent.parent.parent
    loaders = {
        "maze_navigation": load_maze_navigation_results,
        "point_reuse": load_point_reuse_results,
        "compositional_distance": load_compositional_distance_results,
    }
    return {task_name: loader(cfg, repo_root) for task_name, loader in loaders.items()}
|
pipeline/task_builder.py
ADDED
|
@@ -0,0 +1,328 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
task_builder.py
|
| 3 |
+
---------------
|
| 4 |
+
Translates experiments.yaml into concrete shell commands (direct or sbatch).
|
| 5 |
+
|
| 6 |
+
Each public function returns a list of ExperimentJob dataclasses, one per
|
| 7 |
+
(model × format × prompt_strategy × grid_sizes) combination. The caller
|
| 8 |
+
decides whether to run them directly or wrap them in sbatch.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import os
|
| 14 |
+
import itertools
|
| 15 |
+
from dataclasses import dataclass, field
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
from typing import Any
|
| 18 |
+
|
| 19 |
+
import yaml
|
| 20 |
+
|
| 21 |
+
# ---------------------------------------------------------------------------
|
| 22 |
+
# Data structures
|
| 23 |
+
# ---------------------------------------------------------------------------
|
| 24 |
+
|
| 25 |
+
@dataclass
class ExperimentJob:
    """One concrete, runnable experiment unit produced by a job builder.

    Field order is part of the construction interface — do not reorder.
    """
    task_id: str              # task family, e.g. "maze_navigation"
    model: str                # model identifier, e.g. "gemini-2.5-flash"
    label: str                # human-readable label for logs and UI
    working_dir: Path         # directory to cd into before launching
    python_cmd: list[str]     # argv: [python, script.py, --arg, value, ...]
    api_key_env: str          # name of the env var that must hold the API key
    output_dir: Path          # destination directory for result files
    sbatch_cfg: dict          # SLURM knobs: mem, time, cpus, partition, log_dir
    grid_sizes: list[int]     # grid sizes swept by this job (display/filtering)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# ---------------------------------------------------------------------------
|
| 40 |
+
# Config loader
|
| 41 |
+
# ---------------------------------------------------------------------------
|
| 42 |
+
|
| 43 |
+
def load_config(config_path: str | Path) -> dict:
    """Read and parse the experiments YAML configuration into a dict."""
    with open(config_path) as handle:
        parsed = yaml.safe_load(handle)
    return parsed
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _repo_root(config_path: Path) -> Path:
|
| 49 |
+
"""pipeline/configs/experiments.yaml → llm-maze-solver/"""
|
| 50 |
+
return config_path.parent.parent.parent
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# ---------------------------------------------------------------------------
|
| 54 |
+
# Internal helpers
|
| 55 |
+
# ---------------------------------------------------------------------------
|
| 56 |
+
|
| 57 |
+
def _merge_sbatch(defaults: dict, override: dict) -> dict:
|
| 58 |
+
merged = dict(defaults)
|
| 59 |
+
merged.update(override)
|
| 60 |
+
return merged
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def _grid_str(grid_sizes: list[int]) -> str:
|
| 64 |
+
return ",".join(str(g) for g in grid_sizes)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def _output_subdir(base: str, model: str, tag: str) -> str:
|
| 68 |
+
"""Produce a deterministic output subdirectory path."""
|
| 69 |
+
return f"{base}/{model.replace('.', '_').replace('-', '_')}/{tag}"
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
# ---------------------------------------------------------------------------
|
| 73 |
+
# Maze Navigation
|
| 74 |
+
# ---------------------------------------------------------------------------
|
| 75 |
+
|
| 76 |
+
def build_maze_navigation_jobs(
    cfg: dict,
    models: list[str] | None = None,
    grid_sizes: list[int] | None = None,
    input_formats: list[str] | None = None,
    prompt_strategies: list[str] | None = None,
    config_path: Path = None,
) -> list[ExperimentJob]:
    """
    Build jobs for Maze Navigation (planning, k-shot).

    One job per (model × input_format × prompt_strategy); each job
    sweeps all selected grid sizes internally. Filters default to
    everything declared in the config; unknown model names are skipped.
    """
    task = cfg["maze_navigation"]
    defaults = cfg["defaults"]
    model_cfg = cfg["models"]

    selected_models = models or list(model_cfg.keys())
    selected_formats = input_formats or task["input_formats"]
    selected_strategies = prompt_strategies or list(task["prompt_strategies"].keys())
    selected_grids = grid_sizes or task["grid_sizes"]

    repo = _repo_root(config_path) if config_path else Path(".")
    script = repo / task["script"]
    wdir = repo / task["working_dir"]

    # BUG FIX: k_shots was passed into the argv without conversion; YAML may
    # yield a list or an int, which breaks subprocess argv (every other
    # numeric arg is str()-ed). Normalise to a comma-separated string.
    k_shots = task["k_shots"]
    if isinstance(k_shots, (list, tuple)):
        k_shots = ",".join(str(k) for k in k_shots)
    else:
        k_shots = str(k_shots)

    jobs: list[ExperimentJob] = []

    for model, fmt, strat in itertools.product(
        selected_models, selected_formats, selected_strategies
    ):
        if model not in model_cfg:
            continue
        strat_cfg = task["prompt_strategies"][strat]
        tag = f"{fmt}_input_{strat}"
        out_dir = repo / _output_subdir(task["output_base"], model, tag)

        cmd = [
            "python", str(script),
            "--model_name", model,
            "--input_format", fmt,
            "--k_shots", k_shots,
            "--n_test_mazes", str(defaults["n_test_mazes"]),
            "--test_grid_sizes", _grid_str(selected_grids),
            "--maze_type", task["maze_type"],
            "--seed", str(defaults["seed"]),
            "--output_dir", str(out_dir),
        ]
        # Strategy-specific boolean flags come straight from the config.
        cmd.extend(strat_cfg["flags"])
        if task.get("visualize"):
            cmd.append("--visualize")

        jobs.append(ExperimentJob(
            task_id="maze_navigation",
            model=model,
            label=f"Maze Navigation | {model} | {fmt} | {strat}",
            working_dir=wdir,
            python_cmd=cmd,
            api_key_env=model_cfg[model]["api_key_env"],
            output_dir=out_dir,
            sbatch_cfg=_merge_sbatch(defaults["sbatch"], task.get("sbatch", {})),
            grid_sizes=selected_grids,
        ))

    return jobs
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
# ---------------------------------------------------------------------------
|
| 141 |
+
# Sequential Reasoning with Point Reuse (Q3 = Q0)
|
| 142 |
+
# ---------------------------------------------------------------------------
|
| 143 |
+
|
| 144 |
+
def build_point_reuse_jobs(
    cfg: dict,
    models: list[str] | None = None,
    grid_sizes: list[int] | None = None,
    prompt_strategies: list[str] | None = None,
    config_path: Path = None,
) -> list[ExperimentJob]:
    """Build jobs for Sequential Reasoning with Point Reuse (Q3=Q0)."""
    task = cfg["point_reuse"]
    defaults = cfg["defaults"]
    model_cfg = cfg["models"]

    chosen_models = models or list(model_cfg.keys())
    chosen_strats = prompt_strategies or list(task["prompt_strategies"].keys())
    chosen_grids = grid_sizes or task["grid_sizes"]

    repo = _repo_root(config_path) if config_path else Path(".")
    script_path = repo / task["script"]
    workdir = repo / task["working_dir"]

    jobs: list[ExperimentJob] = []

    for model, strat in itertools.product(chosen_models, chosen_strats):
        if model not in model_cfg:
            continue
        strat_cfg = task["prompt_strategies"][strat]
        tag = f"point_reuse_q3q0_{strat}"
        run_dir = repo / _output_subdir(task["output_base"], model, tag)

        cmd = [
            "python", str(script_path),
            "--model_name", model,
            "--input_format", task["input_format"],
            "--strategy", task["strategy"],
            "--reuse_pattern", task["reuse_pattern"],
            "--prompt_type", strat_cfg["prompt_type"],
            "--n_questions_per_maze", str(task["n_questions_per_maze"]),
            "--n_test_mazes", str(defaults["n_test_mazes"]),
            "--test_grid_sizes", _grid_str(chosen_grids),
            "--output_dir", str(run_dir),
        ]
        # Optional boolean switches driven by the task config.
        for option, flag in (
            ("sequential_questions", "--sequential_questions"),
            ("visualize", "--visualize"),
            ("save_details", "--save_details"),
        ):
            if task.get(option):
                cmd.append(flag)

        jobs.append(ExperimentJob(
            task_id="point_reuse",
            model=model,
            label=f"Point Reuse | {model} | {strat}",
            working_dir=workdir,
            python_cmd=cmd,
            api_key_env=model_cfg[model]["api_key_env"],
            output_dir=run_dir,
            sbatch_cfg=_merge_sbatch(defaults["sbatch"], task.get("sbatch", {})),
            grid_sizes=chosen_grids,
        ))

    return jobs
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
# ---------------------------------------------------------------------------
|
| 208 |
+
# Compositional Distance Comparison
|
| 209 |
+
# ---------------------------------------------------------------------------
|
| 210 |
+
|
| 211 |
+
def build_compositional_distance_jobs(
    cfg: dict,
    models: list[str] | None = None,
    grid_sizes: list[int] | None = None,
    prompt_strategies: list[str] | None = None,
    config_path: Path = None,
) -> list[ExperimentJob]:
    """Build jobs for Compositional Distance Comparison (corners-to-center)."""
    task = cfg["compositional_distance"]
    defaults = cfg["defaults"]
    model_cfg = cfg["models"]

    chosen_models = models or list(model_cfg.keys())
    chosen_strats = prompt_strategies or list(task["prompt_strategies"].keys())
    chosen_grids = grid_sizes or task["grid_sizes"]

    repo = _repo_root(config_path) if config_path else Path(".")
    script_path = repo / task["script"]
    workdir = repo / task["working_dir"]

    jobs: list[ExperimentJob] = []

    for model, strat in itertools.product(chosen_models, chosen_strats):
        if model not in model_cfg:
            continue
        strat_cfg = task["prompt_strategies"][strat]
        tag = f"orthogonal_{task['corner_pattern']}_{strat}"
        run_dir = repo / _output_subdir(task["output_base"], model, tag)

        cmd = [
            "python", str(script_path),
            "--model_name", model,
            "--input_format", task["input_format"],
            "--strategy", task["strategy"],
            "--corner_pattern", task["corner_pattern"],
            "--prompt_type", strat_cfg["prompt_type"],
            "--n_questions_per_maze", str(task["n_questions_per_maze"]),
            "--n_test_mazes", str(defaults["n_test_mazes"]),
            "--test_grid_sizes", _grid_str(chosen_grids),
            "--output_dir", str(run_dir),
        ]
        # Optional boolean switches driven by the task config.
        for option, flag in (
            ("visualize", "--visualize"),
            ("save_details", "--save_details"),
        ):
            if task.get(option):
                cmd.append(flag)

        jobs.append(ExperimentJob(
            task_id="compositional_distance",
            model=model,
            label=f"Compositional Distance | {model} | {strat}",
            working_dir=workdir,
            python_cmd=cmd,
            api_key_env=model_cfg[model]["api_key_env"],
            output_dir=run_dir,
            sbatch_cfg=_merge_sbatch(defaults["sbatch"], task.get("sbatch", {})),
            grid_sizes=chosen_grids,
        ))

    return jobs
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
# ---------------------------------------------------------------------------
|
| 273 |
+
# Unified builder
|
| 274 |
+
# ---------------------------------------------------------------------------
|
| 275 |
+
|
| 276 |
+
def build_all_jobs(
    cfg: dict,
    tasks: list[str] | None = None,
    models: list[str] | None = None,
    grid_sizes: list[int] | None = None,
    input_formats: list[str] | None = None,
    prompt_strategies: list[str] | None = None,
    config_path: Path = None,
) -> list[ExperimentJob]:
    """Build jobs for all requested tasks (defaults to all three)."""
    wanted = tasks or ["maze_navigation", "point_reuse", "compositional_distance"]
    shared = dict(
        models=models,
        grid_sizes=grid_sizes,
        prompt_strategies=prompt_strategies,
        config_path=config_path,
    )
    out: list[ExperimentJob] = []
    if "maze_navigation" in wanted:
        # input_formats is a Task-1-only filter.
        out.extend(build_maze_navigation_jobs(cfg, input_formats=input_formats, **shared))
    if "point_reuse" in wanted:
        out.extend(build_point_reuse_jobs(cfg, **shared))
    if "compositional_distance" in wanted:
        out.extend(build_compositional_distance_jobs(cfg, **shared))
    return out
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
# ---------------------------------------------------------------------------
|
| 304 |
+
# sbatch script generator
|
| 305 |
+
# ---------------------------------------------------------------------------
|
| 306 |
+
|
| 307 |
+
def make_sbatch_script(job: ExperimentJob, log_dir: Path) -> str:
    """
    Render the text of an sbatch submission script for *job*.

    Creates log_dir if missing; stdout/stderr go to per-label, per-jobid
    files there. The python command is shell-quoted token-by-token so
    arguments and paths containing spaces or shell metacharacters survive
    (shlex.quote leaves plain tokens unchanged).
    """
    import shlex  # local import: keeps the quoting fix self-contained

    s = job.sbatch_cfg
    log_dir.mkdir(parents=True, exist_ok=True)
    safe_label = job.label.replace(" ", "_").replace("|", "").replace("/", "_")

    # One quoted token per continuation line for readable scripts.
    command = " \\\n    ".join(shlex.quote(str(tok)) for tok in job.python_cmd)

    lines = [
        "#!/bin/bash",
        f"#SBATCH -c {s.get('cpus', 2)}",
        f"#SBATCH -t {s.get('time', '10:00:00')}",
        f"#SBATCH -p {s.get('partition', 'short')}",
        f"#SBATCH --mem={s.get('mem', '8G')}",
        f"#SBATCH -o {log_dir}/{safe_label}_%j.out",
        f"#SBATCH -e {log_dir}/{safe_label}_%j.err",
        "",
        f"# {job.label}",
        f"export {job.api_key_env}=${{{job.api_key_env}}}",
        "",
        f"cd {shlex.quote(str(job.working_dir))}",
        command,
    ]
    return "\n".join(lines) + "\n"
|
requirements.txt
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SpatialBench — pipeline dependencies
|
| 2 |
+
# Install with: pip install -r requirements.txt
|
| 3 |
+
|
| 4 |
+
# Gradio UI (HuggingFace Space entrypoint)
|
| 5 |
+
gradio>=4.20.0
|
| 6 |
+
|
| 7 |
+
# Plotting
|
| 8 |
+
plotly>=5.18.0
|
| 9 |
+
|
| 10 |
+
# Data
|
| 11 |
+
pandas>=2.0.0
|
| 12 |
+
numpy>=1.24.0
|
| 13 |
+
|
| 14 |
+
# Config parsing
|
| 15 |
+
PyYAML>=6.0
|
| 16 |
+
|
| 17 |
+
# LLM API clients
|
| 18 |
+
openai>=1.14.0
|
| 19 |
+
anthropic>=0.25.0
|
| 20 |
+
google-generativeai>=0.5.0
|
| 21 |
+
|
| 22 |
+
# (DeepSeek uses the OpenAI-compatible client — no extra package needed)
|
| 23 |
+
|
| 24 |
+
# Sentence embeddings for reasoning quality analysis
|
| 25 |
+
sentence-transformers>=2.6.0
|
| 26 |
+
|
| 27 |
+
# ROUGE for reasoning quality analysis
|
| 28 |
+
rouge-score>=0.1.2
|
| 29 |
+
|
| 30 |
+
# Environment variable loading
|
| 31 |
+
python-dotenv>=1.0.0
|
run_experiments.py
ADDED
|
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
run_experiments.py
|
| 4 |
+
------------------
|
| 5 |
+
CLI orchestrator for SpatialBench experiments.
|
| 6 |
+
|
| 7 |
+
Run on the cluster with SLURM:
|
| 8 |
+
python run_experiments.py --tasks maze_navigation point_reuse compositional_distance --mode slurm
|
| 9 |
+
|
| 10 |
+
Run directly (uses API keys, no SLURM required):
|
| 11 |
+
python run_experiments.py --tasks maze_navigation --models gemini-2.5-flash --mode direct
|
| 12 |
+
|
| 13 |
+
Dry-run (print commands without executing):
|
| 14 |
+
python run_experiments.py --tasks maze_navigation point_reuse compositional_distance --dry-run
|
| 15 |
+
|
| 16 |
+
Filter experiments:
|
| 17 |
+
python run_experiments.py --tasks maze_navigation \\
|
| 18 |
+
--models gemini-2.5-flash claude-haiku-4-5 \\
|
| 19 |
+
--grid-sizes 5 6 7 \\
|
| 20 |
+
--formats raw \\
|
| 21 |
+
--strategies cot reasoning
|
| 22 |
+
|
| 23 |
+
Show status of running SLURM jobs:
|
| 24 |
+
python run_experiments.py --status
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
from __future__ import annotations
|
| 28 |
+
|
| 29 |
+
import argparse
|
| 30 |
+
import json
|
| 31 |
+
import os
|
| 32 |
+
import subprocess
|
| 33 |
+
import sys
|
| 34 |
+
import tempfile
|
| 35 |
+
import time
|
| 36 |
+
from pathlib import Path
|
| 37 |
+
|
| 38 |
+
# Load .env if present (before importing pipeline modules)
_env_file = Path(__file__).parent / ".env"
if _env_file.exists():
    for _raw in _env_file.read_text().splitlines():
        _raw = _raw.strip()
        # Skip blanks, comments, and lines without a KEY=VALUE shape.
        if not _raw or _raw.startswith("#") or "=" not in _raw:
            continue
        _key, _, _val = _raw.partition("=")
        # setdefault: real environment variables win over .env entries.
        os.environ.setdefault(_key.strip(), _val.strip())
|
| 47 |
+
|
| 48 |
+
from pipeline.task_builder import (
|
| 49 |
+
load_config, build_all_jobs, make_sbatch_script, ExperimentJob,
|
| 50 |
+
)
|
| 51 |
+
from pipeline.job_monitor import JobMonitor, submit_sbatch, submit_direct
|
| 52 |
+
|
| 53 |
+
# Paths derived from this script's location on disk.
CONFIG_PATH = Path(__file__).parent / "configs" / "experiments.yaml"
REPO_ROOT = CONFIG_PATH.parents[2]  # llm-maze-solver/
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
# ---------------------------------------------------------------------------
|
| 58 |
+
# Helpers
|
| 59 |
+
# ---------------------------------------------------------------------------
|
| 60 |
+
|
| 61 |
+
def _check_api_key(job: ExperimentJob) -> bool:
|
| 62 |
+
val = os.environ.get(job.api_key_env, "")
|
| 63 |
+
if not val:
|
| 64 |
+
print(f" [WARN] {job.api_key_env} not set — skipping: {job.label}")
|
| 65 |
+
return False
|
| 66 |
+
return True
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def _print_job(job: ExperimentJob) -> None:
    """Pretty-print a job's label, abbreviated command, and directories."""
    cmd_preview = " ".join(job.python_cmd[:4])
    print(f"\n  {job.label}")
    print(f"    cmd : {cmd_preview} ...")
    print(f"    wdir: {job.working_dir}")
    print(f"    out : {job.output_dir}")
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# ---------------------------------------------------------------------------
|
| 77 |
+
# Run modes
|
| 78 |
+
# ---------------------------------------------------------------------------
|
| 79 |
+
|
| 80 |
+
def run_slurm(jobs: list[ExperimentJob], monitor: JobMonitor, dry_run: bool) -> None:
    """Submit each job to SLURM via a generated sbatch script.

    In dry-run mode the job summary and script text are printed instead of
    submitting. Successfully submitted jobs are registered with *monitor*.
    """
    log_dir = REPO_ROOT / "maze-solver" / "eval_llm_logs"
    log_dir.mkdir(parents=True, exist_ok=True)

    for job in jobs:
        # Skip jobs whose provider API key is missing from the environment.
        if not _check_api_key(job):
            continue

        batch_script = make_sbatch_script(job, log_dir)

        if dry_run:
            _print_job(job)
            print("    --- sbatch script ---")
            print(batch_script)
            continue

        # Persist the script (delete=False) so sbatch can read it from disk.
        tmp = tempfile.NamedTemporaryFile(
            mode="w", suffix=".sh", prefix="spatialbench_",
            dir=log_dir, delete=False
        )
        with tmp:
            tmp.write(batch_script)
            script_path = tmp.name

        job_id = submit_sbatch(batch_script, script_path)
        if not job_id:
            print(f"  [ERROR] Failed to submit: {job.label}")
            continue

        monitor.add(
            job_id=job_id,
            label=job.label,
            task_id=job.task_id,
            model=job.model,
            output_dir=str(job.output_dir),
            log_out=str(log_dir / f"{job_id}.out"),
            log_err=str(log_dir / f"{job_id}.err"),
        )
        print(f"  Submitted {job.label} → SLURM job {job_id}")
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def run_direct(jobs: list[ExperimentJob], monitor: JobMonitor, dry_run: bool) -> None:
    """Run each job as a local subprocess instead of submitting to SLURM."""
    for job in jobs:
        if not _check_api_key(job):
            continue

        if dry_run:
            _print_job(job)
            print(f"    cmd: {' '.join(job.python_cmd)}\n")
            continue

        # Forward only this job's API key (value taken from our environment).
        key_env = {job.api_key_env: os.environ[job.api_key_env]}
        job.output_dir.mkdir(parents=True, exist_ok=True)

        print(f"  Starting: {job.label}")
        child = submit_direct(
            cmd=job.python_cmd,
            working_dir=str(job.working_dir),
            env=key_env,
        )
        monitor.add_direct(
            proc=child,
            label=job.label,
            task_id=job.task_id,
            model=job.model,
            output_dir=str(job.output_dir),
        )
        # Small gap to avoid hammering APIs simultaneously
        time.sleep(2)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
# ---------------------------------------------------------------------------
|
| 147 |
+
# Status display
|
| 148 |
+
# ---------------------------------------------------------------------------
|
| 149 |
+
|
| 150 |
+
def show_status(monitor: JobMonitor) -> None:
    """Refresh *monitor* and print per-status counts plus one line per job."""
    monitor.refresh()
    info = monitor.summary()

    print(f"\nTotal jobs: {info['total']}")
    for state, n in info["counts"].items():
        print(f"  {state:12s}: {n}")
    print()

    for rec in info["records"]:
        print(f"  [{rec['status']:9s}] {rec['label']:<60s} elapsed: {rec['elapsed']}")
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
# ---------------------------------------------------------------------------
|
| 162 |
+
# Argument parsing
|
| 163 |
+
# ---------------------------------------------------------------------------
|
| 164 |
+
|
| 165 |
+
def parse_args() -> argparse.Namespace:
    """Parse command-line options for the experiment orchestrator."""
    ap = argparse.ArgumentParser(
        description="SpatialBench experiment orchestrator",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )

    # --- experiment filters -----------------------------------------------
    ap.add_argument(
        "--tasks",
        nargs="+",
        default=["maze_navigation", "point_reuse", "compositional_distance"],
        choices=["maze_navigation", "point_reuse", "compositional_distance"],
        help="Which tasks to run (default: all three)",
    )
    ap.add_argument(
        "--models",
        nargs="+",
        default=None,
        help="Model IDs to run (default: all models in config)",
    )
    ap.add_argument(
        "--grid-sizes",
        nargs="+",
        type=int,
        default=None,
        dest="grid_sizes",
        help="Grid sizes to evaluate, e.g. --grid-sizes 5 6 7 (default: per-task config)",
    )
    ap.add_argument(
        "--formats",
        nargs="+",
        default=None,
        choices=["raw", "visual"],
        help="Input formats for Task 1 (default: both raw and visual)",
    )
    ap.add_argument(
        "--strategies",
        nargs="+",
        default=None,
        choices=["base", "cot", "reasoning"],
        help="Prompt strategies (default: all)",
    )

    # --- execution control --------------------------------------------------
    ap.add_argument(
        "--mode",
        default="slurm",
        choices=["slurm", "direct"],
        help="Execution mode: 'slurm' submits sbatch jobs, 'direct' runs inline (default: slurm)",
    )
    ap.add_argument(
        "--dry-run",
        action="store_true",
        help="Print commands without executing them",
    )
    ap.add_argument(
        "--no-wait",
        action="store_true",
        help="Return immediately after submission (don't poll for completion)",
    )

    # --- status / monitoring ------------------------------------------------
    ap.add_argument(
        "--status",
        action="store_true",
        help="Query and display SLURM job status (requires --job-ids or a running monitor)",
    )
    ap.add_argument(
        "--job-ids",
        nargs="+",
        default=None,
        help="SLURM job IDs to check status for (used with --status)",
    )
    ap.add_argument(
        "--config",
        default=str(CONFIG_PATH),
        help=f"Path to experiments.yaml (default: {CONFIG_PATH})",
    )
    ap.add_argument(
        "--poll-interval",
        type=int,
        default=60,
        dest="poll_interval",
        help="Seconds between SLURM status polls when waiting (default: 60)",
    )

    return ap.parse_args()
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
# ---------------------------------------------------------------------------
|
| 231 |
+
# Main
|
| 232 |
+
# ---------------------------------------------------------------------------
|
| 233 |
+
|
| 234 |
+
def main() -> None:
    """CLI entry point: build experiment jobs, dispatch them, report results.

    Flow: parse args → status-only shortcut → build filtered job list →
    submit via the chosen mode → optionally wait for completion and print
    a final success/failure summary.
    """
    args = parse_args()
    cfg = load_config(args.config)

    # Status-only mode
    # A fresh monitor is created here; without --job-ids it has no records,
    # so the display will be empty.
    if args.status:
        monitor = JobMonitor(mode="slurm")
        if args.job_ids:
            for jid in args.job_ids:
                monitor.add(job_id=jid, label=jid, task_id="?", model="?")
        show_status(monitor)
        return

    # Build jobs
    # Filters left as None fall back to per-task/config defaults inside
    # build_all_jobs (see task_builder).
    jobs = build_all_jobs(
        cfg=cfg,
        tasks=args.tasks,
        models=args.models,
        grid_sizes=args.grid_sizes,
        input_formats=args.formats,
        prompt_strategies=args.strategies,
        config_path=Path(args.config),
    )

    if not jobs:
        print("No jobs matched the requested filters.")
        return

    # Echo the effective run configuration before launching anything.
    print(f"\nSpatialBench — {len(jobs)} job(s) to run")
    print(f"  mode    : {args.mode}")
    print(f"  tasks   : {args.tasks}")
    print(f"  models  : {args.models or 'all'}")
    print(f"  grids   : {args.grid_sizes or 'per-task default'}")
    print(f"  formats : {args.formats or 'per-task default'}")
    print(f"  strategies: {args.strategies or 'all'}")
    print(f"  dry-run : {args.dry_run}")
    print()

    monitor = JobMonitor(mode=args.mode)

    if args.mode == "slurm":
        run_slurm(jobs, monitor, dry_run=args.dry_run)
    else:
        run_direct(jobs, monitor, dry_run=args.dry_run)

    # Dry runs never submit; --no-wait submits but returns immediately.
    if args.dry_run or args.no_wait:
        if not args.dry_run:
            print(f"\nSubmitted {len(monitor.all_records())} job(s). Use --status to check progress.")
        return

    # Wait for completion
    print("\nWaiting for jobs to complete...")

    def _progress(summary: dict) -> None:
        # Periodic progress line, e.g. "  [12:34:56] running: 3 | completed: 1".
        counts = summary["counts"]
        parts = [f"{s}: {n}" for s, n in counts.items()]
        print(f"  [{time.strftime('%H:%M:%S')}] {' | '.join(parts)}")

    monitor.wait_all(poll_interval=args.poll_interval, callback=_progress)

    # Final summary
    summary = monitor.summary()
    print(f"\nDone. {summary['counts'].get('completed', 0)} completed, "
          f"{summary['counts'].get('failed', 0)} failed.")

    failed = [r for r in summary["records"] if r["status"] == "failed"]
    if failed:
        print("\nFailed jobs:")
        for r in failed:
            print(f"  {r['label']} (job_id={r['job_id']})")
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
# Script entry point: run the orchestrator when executed directly.
if __name__ == "__main__":
    main()
|