"""``CERNCollisionEnvironment``: orchestrates simulator + rules + rewards. This is the OpenEnv-compatible ``Environment`` that the FastAPI app exposes. It owns one episode at a time: reset(seed) → builds a fresh latent state from a sampled scenario. step(action) → validates → generates noisy output → updates state → computes reward → builds the agent observation. """ from __future__ import annotations import logging import uuid from typing import Any, List, Optional from openenv.core.env_server import Environment, State from models import ( AGENT_ENVIRONMENT_RULES, ActionType, CollisionObservation, DiscoveryClaim, ExperimentAction, IntermediateOutput, OutputType, PipelineStepRecord, ResourceUsage, TaskSpec, build_agent_system_prompt, ) from server.rewards import ( RewardWeights, compute_step_reward, compute_terminal_reward, ) from server.rules import RulesEngine, ViolationCode from server.simulator import ( NoiseModel, OutputGenerator, TransitionEngine, compute_action_cost, ) from server.simulator.latent_state import FullLatentState from server.tasks import sample_scenario, Scenario logger = logging.getLogger(__name__) # ── State container ────────────────────────────────────────────────────── class CernState(State): """OpenEnv State subclass: includes hidden truth & runtime stats.""" scenario_name: Optional[str] = None difficulty: Optional[str] = None episode_done: bool = False cumulative_reward: float = 0.0 terminal_reward: Optional[float] = None discovered: Optional[bool] = None correct_mass: Optional[bool] = None correct_channel: Optional[bool] = None correct_spin: Optional[bool] = None truth_mass_gev: Optional[float] = None truth_channel: Optional[str] = None # ── Environment ────────────────────────────────────────────────────────── class CERNCollisionEnvironment(Environment[ExperimentAction, CollisionObservation, CernState]): """LHC particle-discovery POMDP environment.""" SUPPORTS_CONCURRENT_SESSIONS = True def __init__( self, *, max_steps: int = 40, default_difficulty: Optional[str] = None, default_scenario_name: Optional[str] = None, reward_weights: Optional[RewardWeights] = None, ) -> None: super().__init__() self.max_steps = max_steps self.default_difficulty = default_difficulty self.default_scenario_name = default_scenario_name self.reward_weights = reward_weights or RewardWeights() self._state = CernState() self._scenario: Optional[Scenario] = None self._latent: Optional[FullLatentState] = None self._task: Optional[TaskSpec] = None self._noise: Optional[NoiseModel] = None self._output_gen: Optional[OutputGenerator] = None self._transition: Optional[TransitionEngine] = None self._rules: Optional[RulesEngine] = None self._history: List[PipelineStepRecord] = [] self._all_outputs: List[IntermediateOutput] = [] # ── Environment API ──────────────────────────────────────────────── @property def state(self) -> CernState: return self._state def reset( self, seed: Optional[int] = None, episode_id: Optional[str] = None, **kwargs: Any, ) -> CollisionObservation: difficulty = kwargs.get("difficulty") or self.default_difficulty scenario_name = kwargs.get("scenario") or self.default_scenario_name scenario = sample_scenario( difficulty=difficulty, name=scenario_name, seed=seed, ) self._scenario = scenario self._latent = scenario.fresh_latent() self._task = scenario.task if seed is not None: self._latent.rng_seed = int(seed) self._noise = NoiseModel(seed=self._latent.rng_seed) self._output_gen = OutputGenerator(self._noise) self._transition = TransitionEngine() self._rules = RulesEngine( mass_search_window_gev=tuple(self._task.mass_search_window_gev), ) self._history = [] self._all_outputs = [] self._state = CernState( episode_id=episode_id or f"ep-{uuid.uuid4().hex[:8]}", step_count=0, scenario_name=scenario.name, difficulty=scenario.difficulty, episode_done=False, cumulative_reward=0.0, truth_mass_gev=self._latent.particle.mass_gev, truth_channel=self._latent.particle.primary_channel, ) obs = self._build_observation( latest_output=None, done=False, reward=0.0, step_breakdown={}, rule_violations=[], ) return obs def step( self, action: ExperimentAction, timeout_s: Optional[float] = None, **kwargs: Any, ) -> CollisionObservation: if self._latent is None: self.reset() if self._state.episode_done: return self._build_terminal_observation(reason="episode already complete") assert self._rules is not None assert self._output_gen is not None assert self._transition is not None prev_state = self._latent.model_copy(deep=True) rule_result = self._rules.validate(action, self._latent) if not rule_result.allowed: output = IntermediateOutput( output_type=OutputType.FAILURE_REPORT, step_index=self._state.step_count, success=False, quality_score=0.0, summary="Action rejected: " + "; ".join(rule_result.messages), warnings=rule_result.messages, ) else: output = self._output_gen.generate( action=action, state=self._latent, step_index=self._state.step_count, ) # Apply transition (state mutation + cost accounting) if rule_result.allowed: self._transition.step(self._latent, action, output) else: cost = compute_action_cost(action, output) self._latent.resources.budget_used_musd += cost["musd"] self._latent.resources.time_used_days += cost["days"] self._latent.step_count += 1 self._all_outputs.append(output) cost = compute_action_cost(action, output) record = PipelineStepRecord( step_index=self._state.step_count, action_type=action.action_type, method=action.method, parameters=action.parameters, output_summary=output.summary, output_type=output.output_type, success=output.success, quality_score=float(output.quality_score), cost_musd=float(cost["musd"]), luminosity_cost_fb=float(cost["luminosity_fb"]), time_cost_days=float(cost["days"]), ) self._history.append(record) step_reward = compute_step_reward( action=action, output=output, state_before=prev_state, state_after=self._latent, rule_result=rule_result, weights=self.reward_weights, ) self._state.cumulative_reward += step_reward.reward self._state.step_count += 1 terminal_now = ( action.action_type == ActionType.SUBMIT_DISCOVERY_CLAIM and rule_result.allowed ) time_up = ( self._state.step_count >= self.max_steps or self._latent.resources.budget_exhausted or self._latent.resources.time_exhausted ) terminal_reward_value = 0.0 if terminal_now: claim = self._claim_from_action(action) term = compute_terminal_reward( state=self._latent, claim=claim, weights=self.reward_weights, ) terminal_reward_value = term.reward self._state.cumulative_reward += terminal_reward_value self._state.terminal_reward = terminal_reward_value self._state.discovered = term.discovered self._state.correct_mass = term.correct_mass self._state.correct_channel = term.correct_channel self._state.correct_spin = term.correct_spin done = terminal_now or time_up if done: self._state.episode_done = True observation = self._build_observation( latest_output=output, done=done, reward=step_reward.reward + terminal_reward_value, step_breakdown=step_reward.breakdown.components, rule_violations=[ *(v.value for v in rule_result.violations), *(v.value for v in rule_result.soft_violations), ], ) return observation # ── Helpers ──────────────────────────────────────────────────────── def _claim_from_action(self, action: ExperimentAction) -> DiscoveryClaim: raw = action.parameters.get("claim") or {} try: return DiscoveryClaim(**raw) except Exception as exc: # pragma: no cover - defensive logger.warning("Malformed claim, defaulting: %s", exc) return DiscoveryClaim() def _build_terminal_observation(self, reason: str) -> CollisionObservation: obs = self._build_observation( latest_output=IntermediateOutput( output_type=OutputType.FAILURE_REPORT, step_index=self._state.step_count, success=False, summary=reason, ), done=True, reward=0.0, step_breakdown={}, rule_violations=["episode_terminated"], ) return obs def _build_observation( self, *, latest_output: Optional[IntermediateOutput], done: bool, reward: float, step_breakdown: dict, rule_violations: list, ) -> CollisionObservation: assert self._latent is not None assert self._task is not None res = self._latent.resources usage = ResourceUsage( budget_used_musd=res.budget_used_musd, budget_remaining_musd=res.budget_remaining, luminosity_used_fb=res.luminosity_used_fb, luminosity_remaining_fb=res.luminosity_remaining, time_used_days=res.time_used_days, time_remaining_days=res.time_remaining, compute_hours_used=res.compute_hours_used, ) obs = CollisionObservation( done=done, reward=float(reward), task=self._task, step_index=self._state.step_count, pipeline_history=list(self._history), available_channels=self._task.available_channels, available_triggers=self._task.available_triggers, available_tools=self._task.available_tools, resource_usage=usage, latest_output=latest_output, all_outputs=list(self._all_outputs), candidate_masses_gev=list(self._latent.candidate_masses_gev), candidate_significances=list(self._latent.candidate_significances), selected_channel=self._latent.selected_channel, selected_beam_energy=self._latent.selected_beam_energy, cumulative_significance=float( self._latent.progress.best_significance_sigma or 0.0 ), uncertainty_summary={ "energy_scale_unc_gev": self._latent.detector.energy_scale_uncertainty, "luminosity_unc": self._latent.detector.luminosity_uncertainty, "resolution_gev": self._latent.detector.detector_resolution_gev, }, rule_violations=rule_violations, step_reward_breakdown=dict(step_breakdown), ) return obs # ── Convenience for diagnostics ──────────────────────────────────── def hidden_truth(self) -> Optional[dict]: """Reveal the hidden particle (debug / evaluation only).""" if self._latent is None: return None return self._latent.particle.model_dump() __all__ = [ "CernState", "CERNCollisionEnvironment", "AGENT_ENVIRONMENT_RULES", "build_agent_system_prompt", ]