File size: 4,159 Bytes
f28409b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
"""Run a (non-LLM) baseline agent against the in-process environment.

Usage:

    python -m scripts.run_agent --agent heuristic --scenario easy_diphoton_160 --seed 7
    python -m scripts.run_agent --agent oracle --difficulty hard --episodes 5

Pass ``--out results.json`` to also save the per-episode summaries as JSON.
"""

from __future__ import annotations

import argparse
import json
from dataclasses import asdict
from typing import Any, Dict, List

from server.environment import CERNCollisionEnvironment
from scripts.baseline_agents import (
    HeuristicAgent,
    OracleAgent,
    RandomAgent,
)


# CLI name (the ``--agent`` choices) -> baseline agent class to instantiate.
AGENT_REGISTRY = dict(
    random=RandomAgent,
    heuristic=HeuristicAgent,
    oracle=OracleAgent,
)


def run_episode(
    *,
    agent_name: str,
    difficulty: str | None,
    scenario: str | None,
    seed: int,
    max_steps: int,
    verbose: bool,
) -> Dict[str, Any]:
    """Play one episode with a baseline agent and return a JSON-friendly summary.

    A fresh ``CERNCollisionEnvironment`` is built for every call, reset with
    ``seed``/``scenario``/``difficulty``, and stepped with the agent's actions
    until an observation reports ``done``.

    Args:
        agent_name: Key into ``AGENT_REGISTRY`` (``"random"``, ``"heuristic"``
            or ``"oracle"``).
        difficulty: Difficulty forwarded to ``env.reset``; may be ``None``.
        scenario: Scenario name forwarded to ``env.reset``; may be ``None``.
        seed: Seed for ``env.reset``; the random agent is also built with it.
        max_steps: Step budget passed to the environment constructor.
        verbose: If true, print one line per step.

    Returns:
        A dict holding the agent/scenario identifiers, the summed step rewards
        (``total_reward``), the environment's own reward bookkeeping, the final
        ``discovered``/``correct_*`` flags from ``env.state``, the hidden truth,
        the step count and a per-step ``log`` (action, reward, violations).

    Raises:
        KeyError: If ``agent_name`` is not a key of ``AGENT_REGISTRY``.
    """
    env = CERNCollisionEnvironment(max_steps=max_steps)
    obs = env.reset(seed=seed, scenario=scenario, difficulty=difficulty)

    agent_cls = AGENT_REGISTRY[agent_name]
    # Only the random agent is constructed with a seed.
    agent = agent_cls(seed=seed) if agent_name == "random" else agent_cls()
    if agent_name == "oracle":
        # The oracle baseline is handed the environment's hidden truth.
        agent.truth = env.hidden_truth()
    agent.reset()

    total_reward = 0.0
    step_log: List[Dict[str, Any]] = []
    while not obs.done:
        action = agent.act(obs)
        obs = env.step(action)
        # obs.reward may be None; normalize it once so the running total, the
        # verbose line and the log agree, and ``:+.3f`` never receives None
        # (formatting None with a numeric spec raises TypeError).
        reward = float(obs.reward or 0.0)
        total_reward += reward
        if verbose:
            print(
                f"  step {obs.step_index:2d}  {action.action_type.value:24s} "
                f"rew={reward:+.3f}  done={obs.done}"
            )
        step_log.append(
            {
                "step": obs.step_index,
                "action": action.action_type.value,
                "reward": reward,
                "violations": obs.rule_violations,
            }
        )

    state = env.state
    return {
        "agent": agent_name,
        "scenario": state.scenario_name,
        "difficulty": state.difficulty,
        "seed": seed,
        "total_reward": total_reward,
        "cumulative_reward": float(state.cumulative_reward),
        "terminal_reward": state.terminal_reward,
        "discovered": state.discovered,
        "correct_mass": state.correct_mass,
        "correct_channel": state.correct_channel,
        "correct_spin": state.correct_spin,
        "steps": len(step_log),
        "truth": env.hidden_truth(),
        "log": step_log,
    }


def main(argv: List[str] | None = None) -> None:
    """Parse CLI arguments, run the requested episodes and report each one.

    Episode ``i`` (0-based) runs with seed ``--seed + i``. Per-step output is
    printed only for a single-episode run without ``--quiet``. With ``--out``,
    the list of episode summaries is written to that path as JSON.

    Args:
        argv: Arguments to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, so a plain ``main()`` call behaves as before.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--agent",
        choices=list(AGENT_REGISTRY),
        default="heuristic",
        help="Baseline agent to run.",
    )
    parser.add_argument("--scenario", default=None, help="Scenario name to load.")
    parser.add_argument(
        "--difficulty",
        choices=["easy", "medium", "hard"],
        default=None,
        help="Difficulty passed to the environment reset.",
    )
    parser.add_argument(
        "--seed", type=int, default=0, help="Seed of the first episode."
    )
    parser.add_argument(
        "--episodes", type=int, default=1, help="Number of episodes to run."
    )
    parser.add_argument(
        "--max-steps", type=int, default=40, help="Step budget per episode."
    )
    parser.add_argument("--out", default=None, help="Optional path to dump JSON results")
    parser.add_argument(
        "--quiet", action="store_true", help="Suppress per-step output."
    )
    args = parser.parse_args(argv)
    if args.episodes < 1:
        # Previously this silently ran nothing (and could write "[]" to --out).
        parser.error("--episodes must be at least 1")

    rollouts: List[Dict[str, Any]] = []
    for ep in range(args.episodes):
        # Consecutive seeds keep multi-episode runs reproducible.
        seed = args.seed + ep
        summary = run_episode(
            agent_name=args.agent,
            difficulty=args.difficulty,
            scenario=args.scenario,
            seed=seed,
            max_steps=args.max_steps,
            verbose=not args.quiet and args.episodes == 1,
        )
        rollouts.append(summary)
        print(
            f"[{ep + 1}/{args.episodes}] agent={args.agent} "
            f"scenario={summary['scenario']} reward={summary['total_reward']:+.3f} "
            f"discovered={summary['discovered']} correct_mass={summary['correct_mass']} "
            f"correct_channel={summary['correct_channel']}"
        )

    if args.out:
        # default=str stringifies any non-JSON values (e.g. the hidden truth).
        with open(args.out, "w", encoding="utf-8") as f:
            json.dump(rollouts, f, indent=2, default=str)
        print(f"saved → {args.out}")


# Script entry point: only runs the CLI when executed directly (e.g. via
# ``python -m scripts.run_agent``), not when this module is imported.
if __name__ == "__main__":
    main()