cernenv-trainer / tests /test_environment.py
anugrahhu's picture
Update CERNenv Space
0a6c641 verified
"""Integration tests for ``CERNCollisionEnvironment``.
The point of these tests is not to assert specific reward magnitudes
(those depend on noise) but to confirm:
* ``reset`` / ``step`` / ``state`` follow OpenEnv's gym-style contract,
* the heuristic baseline beats the random baseline on average,
* the oracle baseline (which peeks at the truth) gets a positive
cumulative reward — i.e. the environment is *winnable*,
* the env terminates when ``max_steps`` is reached or budget runs out.
"""
from __future__ import annotations
import statistics
import pytest
from models import ActionType, ExperimentAction
from scripts.baseline_agents import HeuristicAgent, OracleAgent, RandomAgent
from server.environment import CERNCollisionEnvironment, CernState
def _run_episode(env, agent, *, seed: int, scenario: str | None = None,
                 difficulty: str | None = None) -> float:
    """Play one complete episode and return the summed reward.

    The environment is reset with ``seed`` / ``scenario`` / ``difficulty``
    and ``agent`` is then queried for an action on every observation until
    an observation reports ``done``.

    An agent whose ``name`` is ``"oracle"`` is handed the environment's
    hidden truth (via ``env.hidden_truth()``) before its ``reset`` is
    called; every other agent is simply reset.

    A step whose ``reward`` is falsy (``None`` or ``0``) contributes ``0.0``
    to the total.
    """
    observation = env.reset(seed=seed, scenario=scenario, difficulty=difficulty)

    # Only the oracle baseline is allowed to peek at the latent truth.
    wants_truth = agent.name == "oracle"
    if wants_truth:
        agent.truth = env.hidden_truth()
    agent.reset()

    total_reward = 0.0
    while True:
        if observation.done:
            return total_reward
        chosen_action = agent.act(observation)
        observation = env.step(chosen_action)
        step_reward = observation.reward or 0.0
        total_reward += float(step_reward)
# ── Gym-style contract ──────────────────────────────────────────────────
def test_reset_returns_observation_with_task():
    """``reset`` must hand back a step-0, not-yet-done observation with a task."""
    environment = CERNCollisionEnvironment(max_steps=10)
    first_obs = environment.reset(seed=1, scenario="easy_diphoton_160")

    # The task must exist and carry a non-empty problem statement.
    assert first_obs.task is not None
    assert first_obs.task.problem_statement

    # A freshly reset episode has taken no steps and is still running.
    assert first_obs.step_index == 0
    assert first_obs.done is False
def test_state_reflects_episode_progress():
    """Right after ``reset`` the ``state`` describes a fresh, unfinished episode."""
    scenario_name = "easy_diphoton_160"
    environment = CERNCollisionEnvironment(max_steps=5)
    environment.reset(seed=2, scenario=scenario_name)

    snapshot = environment.state
    assert isinstance(snapshot, CernState)
    assert snapshot.scenario_name == scenario_name
    # No step has been taken yet: not done, nothing accumulated.
    assert snapshot.episode_done is False
    assert snapshot.cumulative_reward == 0.0
def test_step_advances_step_count_and_history():
    """A single ``step`` bumps ``step_index`` to 1 and records one history entry."""
    environment = CERNCollisionEnvironment(max_steps=5)
    environment.reset(seed=3, scenario="easy_diphoton_160")

    configure_beam = ExperimentAction(
        action_type=ActionType.CONFIGURE_BEAM,
        parameters={"beam_energy": "13TeV"},
    )
    after_step = environment.step(configure_beam)

    assert after_step.step_index == 1
    assert len(after_step.pipeline_history) == 1
def test_episode_terminates_at_max_steps():
    """With ``max_steps=3`` the episode must report ``done`` within five steps.

    The loop deliberately offers more attempts (5) than ``max_steps`` (3)
    and stops as soon as an observation is flagged ``done``.
    """
    attempts = 5
    environment = CERNCollisionEnvironment(max_steps=3)
    environment.reset(seed=4, scenario="easy_diphoton_160")

    latest = None
    for _attempt in range(attempts):
        beam_action = ExperimentAction(action_type=ActionType.CONFIGURE_BEAM)
        latest = environment.step(beam_action)
        if latest.done:
            break

    assert latest is not None
    assert latest.done
# ── Baselines: heuristic ≥ random ───────────────────────────────────────
@pytest.mark.parametrize("difficulty", ["easy", "medium"])
def test_heuristic_beats_random_on_average(difficulty):
    """Mean heuristic reward must exceed mean random reward over 8 seeds.

    A failure here means either the scripted heuristic is broken or the
    reward function is rewarding nonsense; both must be caught before
    any training run.
    """

    def _score(make_agent, seed):
        # Every episode gets its own environment, created before the agent.
        fresh_env = CERNCollisionEnvironment(max_steps=20)
        return _run_episode(fresh_env, make_agent(),
                            seed=seed, difficulty=difficulty)

    baseline_scores = []
    heuristic_scores = []
    for seed in range(8):
        # Random episode first, then heuristic, for each seed in turn.
        baseline_scores.append(_score(lambda: RandomAgent(seed=seed), seed))
        heuristic_scores.append(_score(HeuristicAgent, seed))

    heuristic_mean = statistics.mean(heuristic_scores)
    baseline_mean = statistics.mean(baseline_scores)
    assert heuristic_mean > baseline_mean
def test_oracle_can_win_easy_scenario():
    """The truth-peeking oracle must score well on the easy scenario.

    Over four seeds the best episode has to exceed a cumulative reward of
    1.0 and the average has to be positive. If even the oracle cannot
    win, the environment is unwinnable and RL will stall (FAQ Q15).
    """
    oracle_scores = [
        _run_episode(
            CERNCollisionEnvironment(max_steps=20),
            OracleAgent(),
            seed=seed,
            scenario="easy_diphoton_160",
        )
        for seed in range(4)
    ]

    best_score = max(oracle_scores)
    average_score = statistics.mean(oracle_scores)
    assert best_score > 1.0
    assert average_score > 0.0
# ── API compatibility & hidden-truth invariants ─────────────────────────
def test_step_accepts_timeout_s_as_a_noop():
    """``step`` must accept ``timeout_s`` without enforcing it.

    The OpenEnv API allows ``timeout_s`` on ``step``. CERNenv accepts it
    for compatibility but treats it as informational (steps are
    sub-millisecond pure-compute; resource exhaustion is the real
    sandbox). Pinning this behaviour means a future change cannot
    silently start enforcing per-step timeouts without updating docs.
    """
    environment = CERNCollisionEnvironment(max_steps=5)
    environment.reset(seed=99, scenario="easy_diphoton_160")

    # (timeout_s value, step_index expected afterwards). The first timeout
    # is absurdly small on purpose; it must not raise or abort the step.
    cases = [(0.001, 1), (None, 2)]
    for timeout_value, expected_index in cases:
        beam_action = ExperimentAction(
            action_type=ActionType.CONFIGURE_BEAM,
            parameters={"beam_energy": "13TeV"},
        )
        result = environment.step(beam_action, timeout_s=timeout_value)
        assert result.step_index == expected_index
def test_hidden_truth_is_only_exposed_via_helper():
    """The latent truth is reachable through ``hidden_truth()`` only.

    The observation returned by ``reset`` is serialised with
    ``model_dump()`` and searched (case-insensitively) for the
    branching-ratio key of the hidden truth; it must not appear. The
    helper ``env.hidden_truth()`` must still return the truth, including
    its ``decay_branching`` entry.
    """
    env = CERNCollisionEnvironment(max_steps=4)
    obs = env.reset(seed=10, scenario="higgs_like_125")

    # The agent observation must NEVER contain the latent particle truth.
    flat = repr(obs.model_dump()).lower()

    # The actual mass value 125 might appear as a search-window number,
    # but the secret branching ratios must not leak.
    assert "branching" not in flat

    # NOTE: "cross_section_fb" is intentionally not checked. The name is
    # also a legitimate claim field, so its presence in the observation
    # does not indicate a leak. The previous assertion here had the form
    # ``X not in flat or X in flat``, which is always true and therefore
    # verified nothing; it was removed rather than kept as dead weight.

    truth = env.hidden_truth()
    assert truth is not None
    assert "decay_branching" in truth