| """``CERNCollisionEnvironment``: orchestrates simulator + rules + rewards. | |
| This is the OpenEnv-compatible ``Environment`` that the FastAPI app exposes. | |
| It owns one episode at a time: | |
| reset(seed) → builds a fresh latent state from a sampled scenario. | |
| step(action) → validates → generates noisy output → updates state → | |
| computes reward → builds the agent observation. | |
| """ | |

from __future__ import annotations

import logging
import uuid
from typing import Any, List, Optional

from openenv.core.env_server import Environment

from models import (
    AGENT_ENVIRONMENT_RULES,
    ActionType,
    CernState,
    CollisionObservation,
    DiscoveryClaim,
    ExperimentAction,
    IntermediateOutput,
    OutputType,
    PipelineStepRecord,
    ResourceUsage,
    TaskSpec,
    build_agent_system_prompt,
)
from server.rewards import (
    RewardWeights,
    compute_step_reward,
    compute_terminal_reward,
)
from server.rules import RulesEngine, ViolationCode
from server.simulator import (
    NoiseModel,
    OutputGenerator,
    TransitionEngine,
    compute_action_cost,
)
from server.simulator.latent_state import FullLatentState
from server.tasks import Scenario, sample_scenario

logger = logging.getLogger(__name__)


# ── Environment ──────────────────────────────────────────────────────────


class CERNCollisionEnvironment(Environment[ExperimentAction, CollisionObservation, CernState]):
    """LHC particle-discovery POMDP environment."""

    SUPPORTS_CONCURRENT_SESSIONS = True

    def __init__(
        self,
        *,
        max_steps: int = 40,
        default_difficulty: Optional[str] = None,
        default_scenario_name: Optional[str] = None,
        reward_weights: Optional[RewardWeights] = None,
    ) -> None:
        super().__init__()
        self.max_steps = max_steps
        self.default_difficulty = default_difficulty
        self.default_scenario_name = default_scenario_name
        self.reward_weights = reward_weights or RewardWeights()
        self._state = CernState()
        self._scenario: Optional[Scenario] = None
        self._latent: Optional[FullLatentState] = None
        self._task: Optional[TaskSpec] = None
        self._noise: Optional[NoiseModel] = None
        self._output_gen: Optional[OutputGenerator] = None
        self._transition: Optional[TransitionEngine] = None
        self._rules: Optional[RulesEngine] = None
        self._history: List[PipelineStepRecord] = []
        self._all_outputs: List[IntermediateOutput] = []
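
        # NOTE: the Optional per-episode components above stay None until
        # reset() constructs them; step() asserts on the ones it needs.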

    # ── Environment API ──────────────────────────────────────────────────

    def state(self) -> CernState:
        return self._state

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> CollisionObservation:
        difficulty = kwargs.get("difficulty") or self.default_difficulty
        scenario_name = kwargs.get("scenario") or self.default_scenario_name
        scenario = sample_scenario(
            difficulty=difficulty,
            name=scenario_name,
            seed=seed,
        )
        self._scenario = scenario
        self._latent = scenario.fresh_latent()
        self._task = scenario.task
        if seed is not None:
            self._latent.rng_seed = int(seed)
        self._noise = NoiseModel(seed=self._latent.rng_seed)
        self._output_gen = OutputGenerator(self._noise)
        self._transition = TransitionEngine()
        self._rules = RulesEngine(
            mass_search_window_gev=tuple(self._task.mass_search_window_gev),
        )
        self._history = []
        self._all_outputs = []
        self._state = CernState(
            episode_id=episode_id or f"ep-{uuid.uuid4().hex[:8]}",
            step_count=0,
            scenario_name=scenario.name,
            difficulty=scenario.difficulty,
            episode_done=False,
            cumulative_reward=0.0,
            truth_mass_gev=self._latent.particle.mass_gev,
            truth_channel=self._latent.particle.primary_channel,
        )
        obs = self._build_observation(
            latest_output=None,
            done=False,
            reward=0.0,
            step_breakdown={},
            rule_violations=[],
        )
        return obs
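
    # reset() also accepts scenario-selection kwargs; the kwarg names come
    # from the code above, but the values here are illustrative — the
    # accepted names/levels live in ``server.tasks``:
    #
    #     env.reset(seed=0, difficulty="hard", scenario="rare-resonance")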

    def step(
        self,
        action: ExperimentAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> CollisionObservation:
        """Apply one action and return the next observation.

        ``timeout_s`` is accepted for OpenEnv API compatibility but is a
        no-op for this environment: each ``step`` is pure compute (numpy
        ops on a small latent state, sub-millisecond) and cannot hang.

        The episode-level "sandbox" enforced here is *resource*
        exhaustion: the step cap (``max_steps``), the budget (M$), and
        wall-time-days are checked at the bottom of this method, and the
        episode terminates via ``done=True`` when any limit is crossed.
        Integrated luminosity (fb⁻¹) is tracked and reported but is not
        itself a hard stop. That is the meaningful timeout for an
        LHC-discovery rollout.

        If ``timeout_s`` is non-None we log it at debug level so callers
        can confirm their value is being received, without changing any
        runtime behaviour.
        """
        if timeout_s is not None:
            logger.debug(
                "step() received timeout_s=%.3fs (informational; "
                "actual cutoff is resource-exhaustion based)",
                float(timeout_s),
            )
        if self._latent is None:
            self.reset()
        if self._state.episode_done:
            return self._build_terminal_observation(reason="episode already complete")

        assert self._rules is not None
        assert self._output_gen is not None
        assert self._transition is not None

        prev_state = self._latent.model_copy(deep=True)
        rule_result = self._rules.validate(action, self._latent)
        if not rule_result.allowed:
            output = IntermediateOutput(
                output_type=OutputType.FAILURE_REPORT,
                step_index=self._state.step_count,
                success=False,
                quality_score=0.0,
                summary="Action rejected: " + "; ".join(rule_result.messages),
                warnings=rule_result.messages,
            )
        else:
            output = self._output_gen.generate(
                action=action,
                state=self._latent,
                step_index=self._state.step_count,
            )

        # Apply transition (state mutation + cost accounting)
        cost = compute_action_cost(action, output)
        if rule_result.allowed:
            self._transition.step(self._latent, action, output)
        else:
            # Rejected actions still burn money and time, but never touch
            # the physics state.
            self._latent.resources.budget_used_musd += cost["musd"]
            self._latent.resources.time_used_days += cost["days"]
            self._latent.step_count += 1
        self._all_outputs.append(output)

        record = PipelineStepRecord(
            step_index=self._state.step_count,
            action_type=action.action_type,
            method=action.method,
            parameters=action.parameters,
            output_summary=output.summary,
            output_type=output.output_type,
            success=output.success,
            quality_score=float(output.quality_score),
            cost_musd=float(cost["musd"]),
            luminosity_cost_fb=float(cost["luminosity_fb"]),
            time_cost_days=float(cost["days"]),
        )
        self._history.append(record)

        step_reward = compute_step_reward(
            action=action,
            output=output,
            state_before=prev_state,
            state_after=self._latent,
            rule_result=rule_result,
            weights=self.reward_weights,
            history=self._history[:-1],  # exclude the record we just appended
        )
        self._state.cumulative_reward += step_reward.reward
        self._state.step_count += 1

        terminal_now = (
            action.action_type == ActionType.SUBMIT_DISCOVERY_CLAIM
            and rule_result.allowed
        )
        time_up = (
            self._state.step_count >= self.max_steps
            or self._latent.resources.budget_exhausted
            or self._latent.resources.time_exhausted
        )

        terminal_reward_value = 0.0
        if terminal_now:
            claim = self._claim_from_action(action)
            term = compute_terminal_reward(
                state=self._latent,
                claim=claim,
                weights=self.reward_weights,
            )
            terminal_reward_value = term.reward
            self._state.cumulative_reward += terminal_reward_value
            self._state.terminal_reward = terminal_reward_value
            self._state.discovered = term.discovered
            self._state.correct_mass = term.correct_mass
            self._state.correct_channel = term.correct_channel
            self._state.correct_spin = term.correct_spin
        elif time_up:
            # Fix #1: if the episode runs out of steps/budget/time and the
            # agent never even *attempted* a SUBMIT_DISCOVERY_CLAIM, levy a
            # flat no-claim penalty so claim-avoidance can no longer
            # dominate the per-step shaping reward (the v1 reward hack).
            ever_claimed = any(
                rec.action_type == ActionType.SUBMIT_DISCOVERY_CLAIM
                for rec in self._history
            )
            if not ever_claimed:
                term = compute_terminal_reward(
                    state=self._latent,
                    claim=None,
                    weights=self.reward_weights,
                )
                terminal_reward_value = term.reward
                self._state.cumulative_reward += terminal_reward_value
                self._state.terminal_reward = terminal_reward_value
                self._state.discovered = term.discovered
                self._state.correct_mass = term.correct_mass
                self._state.correct_channel = term.correct_channel
                self._state.correct_spin = term.correct_spin

        done = terminal_now or time_up
        if done:
            self._state.episode_done = True
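
        # The observation's reward is the per-step shaping term plus whatever
        # terminal component fired this step; ``cumulative_reward`` on the
        # state has already absorbed both.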
        observation = self._build_observation(
            latest_output=output,
            done=done,
            reward=step_reward.reward + terminal_reward_value,
            step_breakdown=step_reward.breakdown.components,
            rule_violations=[
                *(v.value for v in rule_result.violations),
                *(v.value for v in rule_result.soft_violations),
            ],
        )
        return observation

    # ── Helpers ──────────────────────────────────────────────────────────

    def _claim_from_action(self, action: ExperimentAction) -> DiscoveryClaim:
        raw = action.parameters.get("claim") or {}
        try:
            return DiscoveryClaim(**raw)
        except Exception as exc:  # pragma: no cover - defensive
            logger.warning("Malformed claim, defaulting: %s", exc)
            return DiscoveryClaim()
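
    # Expected claim payload, sketched (field names are illustrative — the
    # authoritative schema is the ``DiscoveryClaim`` model in ``models``):
    #
    #     ExperimentAction(
    #         action_type=ActionType.SUBMIT_DISCOVERY_CLAIM,
    #         parameters={"claim": {"mass_gev": 125.0, "channel": "diphoton"}},
    #     )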

    def _build_terminal_observation(self, reason: str) -> CollisionObservation:
        obs = self._build_observation(
            latest_output=IntermediateOutput(
                output_type=OutputType.FAILURE_REPORT,
                step_index=self._state.step_count,
                success=False,
                summary=reason,
            ),
            done=True,
            reward=0.0,
            step_breakdown={},
            rule_violations=["episode_terminated"],
        )
        return obs

    def _build_observation(
        self,
        *,
        latest_output: Optional[IntermediateOutput],
        done: bool,
        reward: float,
        step_breakdown: dict[str, float],
        rule_violations: list[str],
    ) -> CollisionObservation:
        assert self._latent is not None
        assert self._task is not None

        res = self._latent.resources
        usage = ResourceUsage(
            budget_used_musd=res.budget_used_musd,
            budget_remaining_musd=res.budget_remaining,
            luminosity_used_fb=res.luminosity_used_fb,
            luminosity_remaining_fb=res.luminosity_remaining,
            time_used_days=res.time_used_days,
            time_remaining_days=res.time_remaining,
            compute_hours_used=res.compute_hours_used,
        )
        obs = CollisionObservation(
            done=done,
            reward=float(reward),
            task=self._task,
            step_index=self._state.step_count,
            pipeline_history=list(self._history),
            available_channels=self._task.available_channels,
            available_triggers=self._task.available_triggers,
            available_tools=self._task.available_tools,
            resource_usage=usage,
            latest_output=latest_output,
            all_outputs=list(self._all_outputs),
            candidate_masses_gev=list(self._latent.candidate_masses_gev),
            candidate_significances=list(self._latent.candidate_significances),
            selected_channel=self._latent.selected_channel,
            selected_beam_energy=self._latent.selected_beam_energy,
            cumulative_significance=float(
                self._latent.progress.best_significance_sigma or 0.0
            ),
            uncertainty_summary={
                "energy_scale_unc_gev": self._latent.detector.energy_scale_uncertainty,
                "luminosity_unc": self._latent.detector.luminosity_uncertainty,
                "resolution_gev": self._latent.detector.detector_resolution_gev,
            },
            rule_violations=rule_violations,
            step_reward_breakdown=dict(step_breakdown),
        )
        return obs

    # ── Convenience for diagnostics ──────────────────────────────────────

    def hidden_truth(self) -> Optional[dict]:
        """Reveal the hidden particle (debug / evaluation only)."""
        if self._latent is None:
            return None
        return self._latent.particle.model_dump()
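
    # e.g. ``env.hidden_truth()["mass_gev"]`` when scoring offline — never
    # surface this to the agent, or the POMDP collapses to full observability.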


__all__ = [
    "CernState",
    "CERNCollisionEnvironment",
    "AGENT_ENVIRONMENT_RULES",
    "build_agent_system_prompt",
]