"""Run a (non-LLM) baseline agent against the in-process environment.
Usage:
python -m scripts.run_agent --agent heuristic --scenario easy_diphoton_160 --seed 7
python -m scripts.run_agent --agent oracle --difficulty hard --episodes 5
"""

from __future__ import annotations

import argparse
import json
from typing import Any, Dict, List

from server.environment import CERNCollisionEnvironment
from scripts.baseline_agents import (
    HeuristicAgent,
    OracleAgent,
    RandomAgent,
)
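
# Registry mapping --agent CLI names to their baseline agent classes.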
AGENT_REGISTRY = {
    "random": RandomAgent,
    "heuristic": HeuristicAgent,
    "oracle": OracleAgent,
}


def run_episode(
    *,
    agent_name: str,
    difficulty: str | None,
    scenario: str | None,
    seed: int,
    max_steps: int,
    verbose: bool,
) -> Dict[str, Any]:
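    """Roll out one episode with the chosen baseline agent and return a summary dict."""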
    env = CERNCollisionEnvironment(max_steps=max_steps)
    obs = env.reset(seed=seed, scenario=scenario, difficulty=difficulty)

    agent_cls = AGENT_REGISTRY[agent_name]
    if agent_name == "random":
        agent = agent_cls(seed=seed)
    else:
        agent = agent_cls()
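    # The oracle agent is handed the environment's hidden ground truth before it acts.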
    if agent_name == "oracle":
        agent.truth = env.hidden_truth()
    agent.reset()

    total_reward = 0.0
    step_log: List[Dict[str, Any]] = []
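
    # Act until the environment reports a terminal observation.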
    while not obs.done:
        action = agent.act(obs)
        obs = env.step(action)
        reward = float(obs.reward or 0.0)
        total_reward += reward
        if verbose:
            print(
                f" step {obs.step_index:2d} {action.action_type.value:24s} "
                f"rew={reward:+.3f} done={obs.done}"
            )
        step_log.append(
            {
                "step": obs.step_index,
                "action": action.action_type.value,
                "reward": reward,
                "violations": obs.rule_violations,
            }
        )
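
    # Episode summary: the agent's outcome plus the hidden truth for comparison.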
    summary = {
        "agent": agent_name,
        "scenario": env.state.scenario_name,
        "difficulty": env.state.difficulty,
        "seed": seed,
        "total_reward": total_reward,
        "cumulative_reward": float(env.state.cumulative_reward),
        "terminal_reward": env.state.terminal_reward,
        "discovered": env.state.discovered,
        "correct_mass": env.state.correct_mass,
        "correct_channel": env.state.correct_channel,
        "correct_spin": env.state.correct_spin,
        "steps": len(step_log),
        "truth": env.hidden_truth(),
        "log": step_log,
    }
    return summary


def main() -> None:
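    """Parse CLI arguments, run the requested episodes, and optionally dump JSON results."""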
    parser = argparse.ArgumentParser()
    parser.add_argument("--agent", choices=list(AGENT_REGISTRY), default="heuristic")
    parser.add_argument("--scenario", default=None)
    parser.add_argument("--difficulty", choices=["easy", "medium", "hard"], default=None)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--episodes", type=int, default=1)
    parser.add_argument("--max-steps", type=int, default=40)
    parser.add_argument("--out", default=None, help="Optional path to dump JSON results")
    parser.add_argument("--quiet", action="store_true")
    args = parser.parse_args()
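
    # One summary per episode; the seed is offset by the episode index so runs stay reproducible.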
    rollouts: List[Dict[str, Any]] = []
    for ep in range(args.episodes):
        seed = args.seed + ep
        summary = run_episode(
            agent_name=args.agent,
            difficulty=args.difficulty,
            scenario=args.scenario,
            seed=seed,
            max_steps=args.max_steps,
            verbose=(not args.quiet and args.episodes == 1),
        )
        rollouts.append(summary)
        print(
            f"[{ep + 1}/{args.episodes}] agent={args.agent} "
            f"scenario={summary['scenario']} reward={summary['total_reward']:+.3f} "
            f"discovered={summary['discovered']} correct_mass={summary['correct_mass']} "
            f"correct_channel={summary['correct_channel']}"
        )
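
    # Optionally write the collected rollouts (including per-step logs) to a JSON file.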
    if args.out:
        with open(args.out, "w") as f:
            json.dump(rollouts, f, indent=2, default=str)
        print(f"saved → {args.out}")


if __name__ == "__main__":
    main()