"""Run a (non-LLM) baseline agent against the in-process environment. Usage: python -m scripts.run_agent --agent heuristic --scenario easy_diphoton_160 --seed 7 python -m scripts.run_agent --agent oracle --difficulty hard --episodes 5 """ from __future__ import annotations import argparse import json from dataclasses import asdict from typing import Any, Dict, List from server.environment import CERNCollisionEnvironment from scripts.baseline_agents import ( HeuristicAgent, OracleAgent, RandomAgent, ) AGENT_REGISTRY = { "random": RandomAgent, "heuristic": HeuristicAgent, "oracle": OracleAgent, } def run_episode( *, agent_name: str, difficulty: str | None, scenario: str | None, seed: int, max_steps: int, verbose: bool, ) -> Dict[str, Any]: env = CERNCollisionEnvironment(max_steps=max_steps) obs = env.reset(seed=seed, scenario=scenario, difficulty=difficulty) agent_cls = AGENT_REGISTRY[agent_name] if agent_name == "random": agent = agent_cls(seed=seed) else: agent = agent_cls() if agent_name == "oracle": agent.truth = env.hidden_truth() agent.reset() total_reward = 0.0 step_log: List[Dict[str, Any]] = [] while not obs.done: action = agent.act(obs) obs = env.step(action) total_reward += float(obs.reward or 0.0) if verbose: print( f" step {obs.step_index:2d} {action.action_type.value:24s} " f"rew={obs.reward:+.3f} done={obs.done}" ) step_log.append( { "step": obs.step_index, "action": action.action_type.value, "reward": float(obs.reward or 0.0), "violations": obs.rule_violations, } ) summary = { "agent": agent_name, "scenario": env.state.scenario_name, "difficulty": env.state.difficulty, "seed": seed, "total_reward": total_reward, "cumulative_reward": float(env.state.cumulative_reward), "terminal_reward": env.state.terminal_reward, "discovered": env.state.discovered, "correct_mass": env.state.correct_mass, "correct_channel": env.state.correct_channel, "correct_spin": env.state.correct_spin, "steps": len(step_log), "truth": env.hidden_truth(), "log": step_log, } return summary def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("--agent", choices=list(AGENT_REGISTRY), default="heuristic") parser.add_argument("--scenario", default=None) parser.add_argument("--difficulty", choices=["easy", "medium", "hard"], default=None) parser.add_argument("--seed", type=int, default=0) parser.add_argument("--episodes", type=int, default=1) parser.add_argument("--max-steps", type=int, default=40) parser.add_argument("--out", default=None, help="Optional path to dump JSON results") parser.add_argument("--quiet", action="store_true") args = parser.parse_args() rollouts: List[Dict[str, Any]] = [] for ep in range(args.episodes): seed = args.seed + ep summary = run_episode( agent_name=args.agent, difficulty=args.difficulty, scenario=args.scenario, seed=seed, max_steps=args.max_steps, verbose=not args.quiet and args.episodes == 1, ) rollouts.append(summary) print( f"[{ep + 1}/{args.episodes}] agent={args.agent} " f"scenario={summary['scenario']} reward={summary['total_reward']:+.3f} " f"discovered={summary['discovered']} correct_mass={summary['correct_mass']} " f"correct_channel={summary['correct_channel']}" ) if args.out: with open(args.out, "w") as f: json.dump(rollouts, f, indent=2, default=str) print(f"saved → {args.out}") if __name__ == "__main__": main()