Spaces:
Sleeping
Sleeping
| """Run a (non-LLM) baseline agent against the in-process environment. | |
| Usage: | |
| python -m scripts.run_agent --agent heuristic --scenario easy_diphoton_160 --seed 7 | |
| python -m scripts.run_agent --agent oracle --difficulty hard --episodes 5 | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| from dataclasses import asdict | |
| from typing import Any, Dict, List | |
| from server.environment import CERNCollisionEnvironment | |
| from scripts.baseline_agents import ( | |
| HeuristicAgent, | |
| OracleAgent, | |
| RandomAgent, | |
| ) | |
# Maps the CLI --agent choice to its implementing class; the dict's keys
# double as the argparse `choices` list in main().
AGENT_REGISTRY = {
    "random": RandomAgent,
    "heuristic": HeuristicAgent,
    "oracle": OracleAgent,
}
def run_episode(
    *,
    agent_name: str,
    difficulty: str | None,
    scenario: str | None,
    seed: int,
    max_steps: int,
    verbose: bool,
) -> Dict[str, Any]:
    """Run one full episode of *agent_name* in a fresh environment.

    Args:
        agent_name: Key into ``AGENT_REGISTRY`` ("random", "heuristic", "oracle").
        difficulty: Forwarded to ``env.reset``; may be ``None``.
        scenario: Forwarded to ``env.reset``; may be ``None``.
        seed: Environment seed (also seeds the random agent).
        max_steps: Step cap passed to the environment constructor.
        verbose: When true, print a per-step trace to stdout.

    Returns:
        A JSON-serializable summary dict with per-step log, final state
        flags, and the environment's hidden ground truth.
    """
    env = CERNCollisionEnvironment(max_steps=max_steps)
    obs = env.reset(seed=seed, scenario=scenario, difficulty=difficulty)
    agent_cls = AGENT_REGISTRY[agent_name]
    # Only the random agent accepts a seed; the others are deterministic.
    if agent_name == "random":
        agent = agent_cls(seed=seed)
    else:
        agent = agent_cls()
    if agent_name == "oracle":
        # The oracle deliberately cheats: it reads the hidden truth up front.
        agent.truth = env.hidden_truth()
    agent.reset()
    total_reward = 0.0
    step_log: List[Dict[str, Any]] = []
    while not obs.done:
        action = agent.act(obs)
        obs = env.step(action)
        # Normalize once per step: obs.reward may be None, and the rest of
        # this loop must treat that as 0.0 consistently.
        reward = float(obs.reward or 0.0)
        total_reward += reward
        if verbose:
            # Bug fix: format the normalized reward. Formatting obs.reward
            # directly with :+.3f raised TypeError whenever it was None.
            print(
                f" step {obs.step_index:2d} {action.action_type.value:24s} "
                f"rew={reward:+.3f} done={obs.done}"
            )
        step_log.append(
            {
                "step": obs.step_index,
                "action": action.action_type.value,
                "reward": reward,
                "violations": obs.rule_violations,
            }
        )
    summary = {
        "agent": agent_name,
        "scenario": env.state.scenario_name,
        "difficulty": env.state.difficulty,
        "seed": seed,
        "total_reward": total_reward,
        "cumulative_reward": float(env.state.cumulative_reward),
        "terminal_reward": env.state.terminal_reward,
        "discovered": env.state.discovered,
        "correct_mass": env.state.correct_mass,
        "correct_channel": env.state.correct_channel,
        "correct_spin": env.state.correct_spin,
        "steps": len(step_log),
        "truth": env.hidden_truth(),
        "log": step_log,
    }
    return summary
def main() -> None:
    """CLI entry point: run baseline episodes and print a one-line summary each."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--agent", choices=list(AGENT_REGISTRY), default="heuristic")
    cli.add_argument("--scenario", default=None)
    cli.add_argument("--difficulty", choices=["easy", "medium", "hard"], default=None)
    cli.add_argument("--seed", type=int, default=0)
    cli.add_argument("--episodes", type=int, default=1)
    cli.add_argument("--max-steps", type=int, default=40)
    cli.add_argument("--out", default=None, help="Optional path to dump JSON results")
    cli.add_argument("--quiet", action="store_true")
    opts = cli.parse_args()

    # Per-step tracing is only useful (and readable) for a single episode.
    trace = (not opts.quiet) and opts.episodes == 1

    rollouts: List[Dict[str, Any]] = []
    for episode_no in range(1, opts.episodes + 1):
        summary = run_episode(
            agent_name=opts.agent,
            difficulty=opts.difficulty,
            scenario=opts.scenario,
            seed=opts.seed + (episode_no - 1),  # distinct seed per episode
            max_steps=opts.max_steps,
            verbose=trace,
        )
        rollouts.append(summary)
        print(
            f"[{episode_no}/{opts.episodes}] agent={opts.agent} "
            f"scenario={summary['scenario']} reward={summary['total_reward']:+.3f} "
            f"discovered={summary['discovered']} correct_mass={summary['correct_mass']} "
            f"correct_channel={summary['correct_channel']}"
        )

    if opts.out:
        with open(opts.out, "w") as f:
            json.dump(rollouts, f, indent=2, default=str)
        print(f"saved → {opts.out}")
| if __name__ == "__main__": | |
| main() | |