#!/usr/bin/env python3
"""
run_experiments.py
------------------
CLI orchestrator for SpatialBench experiments.

Run on the cluster with SLURM:

    python run_experiments.py --tasks maze_navigation point_reuse compositional_distance --mode slurm

Run directly (uses API keys, no SLURM required):

    python run_experiments.py --tasks maze_navigation --models gemini-2.5-flash --mode direct

Dry-run (print commands without executing):

    python run_experiments.py --tasks maze_navigation point_reuse compositional_distance --dry-run

Filter experiments:

    python run_experiments.py --tasks maze_navigation \\
        --models gemini-2.5-flash claude-haiku-4-5 \\
        --grid-sizes 5 6 7 \\
        --formats raw \\
        --strategies cot reasoning

Show status of running SLURM jobs:

    python run_experiments.py --status
"""
from __future__ import annotations

import argparse
import os
import tempfile
import time
from pathlib import Path

# Load .env if present (before importing pipeline modules)
_env_file = Path(__file__).parent / ".env"
if _env_file.exists():
    with open(_env_file) as _f:
        for _line in _f:
            _line = _line.strip()
            if _line and not _line.startswith("#") and "=" in _line:
                _k, _v = _line.split("=", 1)
                os.environ.setdefault(_k.strip(), _v.strip())

from pipeline.task_builder import (
    load_config,
    build_all_jobs,
    make_sbatch_script,
    ExperimentJob,
)
from pipeline.job_monitor import JobMonitor, submit_sbatch, submit_direct

CONFIG_PATH = Path(__file__).parent / "configs" / "experiments.yaml"
REPO_ROOT = CONFIG_PATH.parent.parent.parent  # llm-maze-solver/


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _check_api_key(job: ExperimentJob) -> bool:
    val = os.environ.get(job.api_key_env, "")
    if not val:
        print(f"  [WARN] {job.api_key_env} not set — skipping: {job.label}")
        return False
    return True


def _print_job(job: ExperimentJob) -> None:
    print(f"\n  {job.label}")
    print(f"    cmd : {' '.join(job.python_cmd[:4])} ...")
    print(f"    wdir: {job.working_dir}")
    print(f"    out : {job.output_dir}")


# ---------------------------------------------------------------------------
# Run modes
# ---------------------------------------------------------------------------

def run_slurm(jobs: list[ExperimentJob], monitor: JobMonitor, dry_run: bool) -> None:
    log_dir = REPO_ROOT / "maze-solver" / "eval_llm_logs"
    log_dir.mkdir(parents=True, exist_ok=True)
    for job in jobs:
        if not _check_api_key(job):
            continue
        script_text = make_sbatch_script(job, log_dir)
        if dry_run:
            _print_job(job)
            print("    --- sbatch script ---")
            print(script_text)
            continue
        # Persist the generated sbatch script so SLURM can read it after submission
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".sh", prefix="spatialbench_", dir=log_dir, delete=False
        ) as tmp:
            tmp.write(script_text)
            script_path = tmp.name
        job_id = submit_sbatch(script_path)
        if job_id:
            monitor.add(
                job_id=job_id,
                label=job.label,
                task_id=job.task_id,
                model=job.model,
                output_dir=str(job.output_dir),
                log_out=str(log_dir / f"{job_id}.out"),
                log_err=str(log_dir / f"{job_id}.err"),
            )
            print(f"  Submitted {job.label} → SLURM job {job_id}")
        else:
            print(f"  [ERROR] Failed to submit: {job.label}")
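
# For orientation: make_sbatch_script (pipeline.task_builder) is assumed to
# emit a script roughly shaped like the sketch below. Illustrative only; the
# real directives, partition, and resource requests live in task_builder, not
# here. The %j.out / %j.err pattern is what makes the log paths registered
# with the monitor above (f"{job_id}.out" / f"{job_id}.err") line up.
#
#   #!/bin/bash
#   #SBATCH --job-name=spatialbench_<label>
#   #SBATCH --output=<log_dir>/%j.out
#   #SBATCH --error=<log_dir>/%j.err
#   cd <working_dir>
#   <python_cmd ...>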


def run_direct(jobs: list[ExperimentJob], monitor: JobMonitor, dry_run: bool) -> None:
    for job in jobs:
        if not _check_api_key(job):
            continue
        if dry_run:
            _print_job(job)
            print(f"    cmd: {' '.join(job.python_cmd)}\n")
            continue
        env_patch = {job.api_key_env: os.environ[job.api_key_env]}
        job.output_dir.mkdir(parents=True, exist_ok=True)
        print(f"  Starting: {job.label}")
        proc = submit_direct(
            cmd=job.python_cmd,
            working_dir=str(job.working_dir),
            env=env_patch,
        )
        monitor.add_direct(
            proc=proc,
            label=job.label,
            task_id=job.task_id,
            model=job.model,
            output_dir=str(job.output_dir),
        )
        # Small gap to avoid hammering APIs simultaneously
        time.sleep(2)


# ---------------------------------------------------------------------------
# Status display
# ---------------------------------------------------------------------------

def show_status(monitor: JobMonitor) -> None:
    monitor.refresh()
    summary = monitor.summary()
    print(f"\nTotal jobs: {summary['total']}")
    for status, count in summary["counts"].items():
        print(f"  {status:12s}: {count}")
    print()
    for r in summary["records"]:
        print(f"  [{r['status']:9s}] {r['label']:<60s} elapsed: {r['elapsed']}")


# ---------------------------------------------------------------------------
# Argument parsing
# ---------------------------------------------------------------------------

def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="SpatialBench experiment orchestrator",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    parser.add_argument(
        "--tasks",
        nargs="+",
        default=["maze_navigation", "point_reuse", "compositional_distance"],
        choices=["maze_navigation", "point_reuse", "compositional_distance"],
        help="Which tasks to run (default: all three)",
    )
    parser.add_argument(
        "--models",
        nargs="+",
        default=None,
        help="Model IDs to run (default: all models in config)",
    )
    parser.add_argument(
        "--grid-sizes",
        nargs="+",
        type=int,
        default=None,
        help="Grid sizes to evaluate, e.g. --grid-sizes 5 6 7 (default: per-task config)",
    )
    parser.add_argument(
        "--formats",
        nargs="+",
        default=None,
        choices=["raw", "visual"],
        help="Input formats for Task 1 (default: both raw and visual)",
    )
    parser.add_argument(
        "--strategies",
        nargs="+",
        default=None,
        choices=["base", "cot", "reasoning"],
        help="Prompt strategies (default: all)",
    )
    parser.add_argument(
        "--mode",
        default="slurm",
        choices=["slurm", "direct"],
        help="Execution mode: 'slurm' submits sbatch jobs, 'direct' runs inline (default: slurm)",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Print commands without executing them",
    )
    parser.add_argument(
        "--no-wait",
        action="store_true",
        help="Return immediately after submission (don't poll for completion)",
    )
    parser.add_argument(
        "--status",
        action="store_true",
        help="Query and display status of previously submitted SLURM jobs (pass --job-ids)",
    )
    parser.add_argument(
        "--job-ids",
        nargs="+",
        default=None,
        help="SLURM job IDs to check status for (used with --status)",
    )
    parser.add_argument(
        "--config",
        default=str(CONFIG_PATH),
        help=f"Path to experiments.yaml (default: {CONFIG_PATH})",
    )
    parser.add_argument(
        "--poll-interval",
        type=int,
        default=60,
        help="Seconds between SLURM status polls when waiting (default: 60)",
    )
    return parser.parse_args()
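
# Rough shape of configs/experiments.yaml as consumed by build_all_jobs. An
# illustrative sketch only: field names beyond the filter dimensions visible
# in this file (tasks, models, grid_sizes, input_formats, prompt_strategies)
# are assumptions, not the real schema; consult configs/experiments.yaml.
#
#   models:
#     - id: gemini-2.5-flash
#       api_key_env: GEMINI_API_KEY
#   tasks:
#     maze_navigation:
#       grid_sizes: [5, 6, 7]
#       input_formats: [raw, visual]
#       prompt_strategies: [base, cot, reasoning]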


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main() -> None:
    args = parse_args()
    cfg = load_config(args.config)

    # Status-only mode
    if args.status:
        monitor = JobMonitor(mode="slurm")
        if args.job_ids:
            for jid in args.job_ids:
                monitor.add(job_id=jid, label=jid, task_id="?", model="?")
        show_status(monitor)
        return

    # Build jobs
    jobs = build_all_jobs(
        cfg=cfg,
        tasks=args.tasks,
        models=args.models,
        grid_sizes=args.grid_sizes,
        input_formats=args.formats,
        prompt_strategies=args.strategies,
        config_path=Path(args.config),
    )
    if not jobs:
        print("No jobs matched the requested filters.")
        return

    print(f"\nSpatialBench — {len(jobs)} job(s) to run")
    print(f"  mode      : {args.mode}")
    print(f"  tasks     : {args.tasks}")
    print(f"  models    : {args.models or 'all'}")
    print(f"  grids     : {args.grid_sizes or 'per-task default'}")
    print(f"  formats   : {args.formats or 'per-task default'}")
    print(f"  strategies: {args.strategies or 'all'}")
    print(f"  dry-run   : {args.dry_run}")
    print()

    monitor = JobMonitor(mode=args.mode)
    if args.mode == "slurm":
        run_slurm(jobs, monitor, dry_run=args.dry_run)
    else:
        run_direct(jobs, monitor, dry_run=args.dry_run)

    if args.dry_run or args.no_wait:
        if not args.dry_run:
            print(f"\nSubmitted {len(monitor.all_records())} job(s). "
                  f"Use --status to check progress.")
        return

    # Wait for completion
    print("\nWaiting for jobs to complete...")

    def _progress(summary: dict) -> None:
        counts = summary["counts"]
        parts = [f"{s}: {n}" for s, n in counts.items()]
        print(f"  [{time.strftime('%H:%M:%S')}] {' | '.join(parts)}")

    monitor.wait_all(poll_interval=args.poll_interval, callback=_progress)

    # Final summary
    summary = monitor.summary()
    print(f"\nDone. {summary['counts'].get('completed', 0)} completed, "
          f"{summary['counts'].get('failed', 0)} failed.")
    failed = [r for r in summary["records"] if r["status"] == "failed"]
    if failed:
        print("\nFailed jobs:")
        for r in failed:
            print(f"  {r['label']} (job_id={r['job_id']})")


if __name__ == "__main__":
    main()
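
# Sample console output while waiting (the line format is defined by
# _progress above; status names come from JobMonitor.summary()["counts"].
# Only "completed" and "failed" are confirmed by this file; "running" and
# the timestamps below are assumed for illustration):
#
#   [14:03:27] running: 2 | completed: 5 | failed: 1
#   [14:04:27] running: 1 | completed: 6 | failed: 1
#
#   Done. 6 completed, 1 failed.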