| """``CERNCollisionEnvironment``: orchestrates simulator + rules + rewards.
|
|
|
| This is the OpenEnv-compatible ``Environment`` that the FastAPI app exposes.
|
| It owns one episode at a time:
|
|
|
| reset(seed) → builds a fresh latent state from a sampled scenario.
|
| step(action) → validates → generates noisy output → updates state →
|
| computes reward → builds the agent observation.
|
| """
|
|
|
| from __future__ import annotations
|
|
|
| import logging
|
| import uuid
|
| from typing import Any, List, Optional
|
|
|
| from openenv.core.env_server import Environment, State
|
|
|
| from models import (
|
| AGENT_ENVIRONMENT_RULES,
|
| ActionType,
|
| CollisionObservation,
|
| DiscoveryClaim,
|
| ExperimentAction,
|
| IntermediateOutput,
|
| OutputType,
|
| PipelineStepRecord,
|
| ResourceUsage,
|
| TaskSpec,
|
| build_agent_system_prompt,
|
| )
|
|
|
| from server.rewards import (
|
| RewardWeights,
|
| compute_step_reward,
|
| compute_terminal_reward,
|
| )
|
| from server.rules import RulesEngine, ViolationCode
|
| from server.simulator import (
|
| NoiseModel,
|
| OutputGenerator,
|
| TransitionEngine,
|
| compute_action_cost,
|
| )
|
| from server.simulator.latent_state import FullLatentState
|
| from server.tasks import sample_scenario, Scenario
|
|
|
|
|
# Module-level logger, stdlib convention (named after this module).
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
class CernState(State):
    """OpenEnv State subclass: includes hidden truth & runtime stats.

    The base ``State`` presumably supplies ``episode_id`` and ``step_count``
    (both are set in ``reset()``) — TODO confirm against openenv.
    """

    # Episode metadata; None until reset() populates them from the scenario.
    scenario_name: Optional[str] = None
    difficulty: Optional[str] = None
    # Lifecycle flag and running reward total, updated on every step().
    episode_done: bool = False
    cumulative_reward: float = 0.0
    # Terminal-claim outcome; these stay None unless the agent submits a
    # discovery claim that gets scored by compute_terminal_reward().
    terminal_reward: Optional[float] = None
    discovered: Optional[bool] = None
    correct_mass: Optional[bool] = None
    correct_channel: Optional[bool] = None
    correct_spin: Optional[bool] = None
    # Hidden ground truth copied from the latent state for evaluation
    # tooling; never included in the agent-visible observation.
    truth_mass_gev: Optional[float] = None
    truth_channel: Optional[str] = None
|
|
|
|
|
|
|
|
|
|
|
class CERNCollisionEnvironment(Environment[ExperimentAction, CollisionObservation, CernState]):
    """LHC particle-discovery POMDP environment.

    Owns one episode at a time: ``reset`` samples a scenario and builds the
    hidden latent state plus the per-episode engines (noise, output
    generation, transitions, rules); ``step`` validates the agent's action,
    generates a noisy intermediate output, advances the latent state,
    scores the step (and terminal claim, if any), and returns the partial
    observation the agent is allowed to see.
    """

    # Advertises to the server layer that instances may be used across
    # concurrent sessions — presumably one instance per session; verify
    # against the FastAPI app wiring.
    SUPPORTS_CONCURRENT_SESSIONS = True

    def __init__(
        self,
        *,
        max_steps: int = 40,
        default_difficulty: Optional[str] = None,
        default_scenario_name: Optional[str] = None,
        reward_weights: Optional[RewardWeights] = None,
    ) -> None:
        """Create an environment shell; no episode exists until ``reset``.

        Args:
            max_steps: Hard cap on agent steps before the episode ends.
            default_difficulty: Fallback difficulty when ``reset`` receives
                no ``difficulty`` kwarg.
            default_scenario_name: Fallback scenario when ``reset`` receives
                no ``scenario`` kwarg.
            reward_weights: Reward-shaping weights; defaults to
                ``RewardWeights()``.
        """
        super().__init__()
        self.max_steps = max_steps
        self.default_difficulty = default_difficulty
        self.default_scenario_name = default_scenario_name
        self.reward_weights = reward_weights or RewardWeights()

        # Per-episode machinery — all of this is rebuilt by reset().
        self._state = CernState()
        self._scenario: Optional[Scenario] = None
        self._latent: Optional[FullLatentState] = None  # hidden ground truth
        self._task: Optional[TaskSpec] = None
        self._noise: Optional[NoiseModel] = None
        self._output_gen: Optional[OutputGenerator] = None
        self._transition: Optional[TransitionEngine] = None
        self._rules: Optional[RulesEngine] = None
        self._history: List[PipelineStepRecord] = []
        self._all_outputs: List[IntermediateOutput] = []

    @property
    def state(self) -> CernState:
        """Current episode state (includes hidden truth; server-side only)."""
        return self._state

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> CollisionObservation:
        """Start a fresh episode and return the initial observation.

        Recognized kwargs: ``difficulty`` and ``scenario`` (both optional);
        they override the constructor defaults for this episode only.

        Args:
            seed: Optional RNG seed; also overrides the sampled scenario's
                own ``rng_seed`` so the noise model is reproducible.
            episode_id: Optional external episode id; a short uuid-based id
                is generated when omitted.
        """
        difficulty = kwargs.get("difficulty") or self.default_difficulty
        scenario_name = kwargs.get("scenario") or self.default_scenario_name

        scenario = sample_scenario(
            difficulty=difficulty,
            name=scenario_name,
            seed=seed,
        )
        self._scenario = scenario
        self._latent = scenario.fresh_latent()
        self._task = scenario.task
        # An explicit seed takes precedence over the scenario's own seed so
        # that seeding reset() makes the whole episode reproducible.
        if seed is not None:
            self._latent.rng_seed = int(seed)
        self._noise = NoiseModel(seed=self._latent.rng_seed)
        self._output_gen = OutputGenerator(self._noise)
        self._transition = TransitionEngine()
        self._rules = RulesEngine(
            mass_search_window_gev=tuple(self._task.mass_search_window_gev),
        )
        self._history = []
        self._all_outputs = []

        self._state = CernState(
            episode_id=episode_id or f"ep-{uuid.uuid4().hex[:8]}",
            step_count=0,
            scenario_name=scenario.name,
            difficulty=scenario.difficulty,
            episode_done=False,
            cumulative_reward=0.0,
            # Hidden truth mirrored into the state for evaluation tooling.
            truth_mass_gev=self._latent.particle.mass_gev,
            truth_channel=self._latent.particle.primary_channel,
        )

        obs = self._build_observation(
            latest_output=None,
            done=False,
            reward=0.0,
            step_breakdown={},
            rule_violations=[],
        )
        return obs

    def step(
        self,
        action: ExperimentAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> CollisionObservation:
        """Advance the episode by one agent action.

        Pipeline: validate against the rules engine → generate an output
        (a FAILURE_REPORT for rejected actions) → advance the latent state
        (or charge costs directly on rejection) → record history → compute
        the step reward, plus the terminal reward when an allowed discovery
        claim is submitted → build the observation.

        Stepping an already-finished episode returns a terminal failure
        observation without mutating state. ``timeout_s`` is accepted for
        interface compatibility but unused here.
        """
        # Lazily start an episode so a bare step() on a fresh env works.
        if self._latent is None:
            self.reset()
        if self._state.episode_done:
            return self._build_terminal_observation(reason="episode already complete")

        # reset() guarantees these; the asserts narrow the Optionals.
        assert self._rules is not None
        assert self._output_gen is not None
        assert self._transition is not None

        # Deep-copy the pre-step latent state so the reward function can
        # compare before/after.
        prev_state = self._latent.model_copy(deep=True)
        rule_result = self._rules.validate(action, self._latent)

        if not rule_result.allowed:
            # Rejected actions still produce an output the agent can read.
            output = IntermediateOutput(
                output_type=OutputType.FAILURE_REPORT,
                step_index=self._state.step_count,
                success=False,
                quality_score=0.0,
                summary="Action rejected: " + "; ".join(rule_result.messages),
                warnings=rule_result.messages,
            )
        else:
            output = self._output_gen.generate(
                action=action,
                state=self._latent,
                step_index=self._state.step_count,
            )

        if rule_result.allowed:
            # The transition engine mutates the latent state in place
            # (presumably including cost accounting and step_count — TODO
            # confirm against TransitionEngine.step).
            self._transition.step(self._latent, action, output)
        else:
            # Rejected actions bypass the transition engine but still cost
            # budget and time (not luminosity), and still consume a step.
            cost = compute_action_cost(action, output)
            self._latent.resources.budget_used_musd += cost["musd"]
            self._latent.resources.time_used_days += cost["days"]
            self._latent.step_count += 1

        self._all_outputs.append(output)
        # NOTE(review): on the rejected path this recomputes the cost
        # already computed above — harmless only if compute_action_cost is
        # deterministic in (action, output); confirm and consider hoisting.
        cost = compute_action_cost(action, output)
        record = PipelineStepRecord(
            step_index=self._state.step_count,
            action_type=action.action_type,
            method=action.method,
            parameters=action.parameters,
            output_summary=output.summary,
            output_type=output.output_type,
            success=output.success,
            quality_score=float(output.quality_score),
            cost_musd=float(cost["musd"]),
            luminosity_cost_fb=float(cost["luminosity_fb"]),
            time_cost_days=float(cost["days"]),
        )
        self._history.append(record)

        step_reward = compute_step_reward(
            action=action,
            output=output,
            state_before=prev_state,
            state_after=self._latent,
            rule_result=rule_result,
            weights=self.reward_weights,
        )

        self._state.cumulative_reward += step_reward.reward
        self._state.step_count += 1

        # An *allowed* discovery claim ends the episode with a scored
        # terminal reward; a rejected claim does not terminate.
        terminal_now = (
            action.action_type == ActionType.SUBMIT_DISCOVERY_CLAIM
            and rule_result.allowed
        )
        # Running out of steps, budget, or time ends the episode without
        # any terminal-claim scoring.
        time_up = (
            self._state.step_count >= self.max_steps
            or self._latent.resources.budget_exhausted
            or self._latent.resources.time_exhausted
        )

        terminal_reward_value = 0.0
        if terminal_now:
            claim = self._claim_from_action(action)
            term = compute_terminal_reward(
                state=self._latent,
                claim=claim,
                weights=self.reward_weights,
            )
            terminal_reward_value = term.reward
            self._state.cumulative_reward += terminal_reward_value
            self._state.terminal_reward = terminal_reward_value
            self._state.discovered = term.discovered
            self._state.correct_mass = term.correct_mass
            self._state.correct_channel = term.correct_channel
            self._state.correct_spin = term.correct_spin

        done = terminal_now or time_up
        if done:
            self._state.episode_done = True

        # The per-step reward and terminal reward are surfaced together in
        # this observation's reward field.
        observation = self._build_observation(
            latest_output=output,
            done=done,
            reward=step_reward.reward + terminal_reward_value,
            step_breakdown=step_reward.breakdown.components,
            rule_violations=[
                *(v.value for v in rule_result.violations),
                *(v.value for v in rule_result.soft_violations),
            ],
        )
        return observation

    def _claim_from_action(self, action: ExperimentAction) -> DiscoveryClaim:
        """Parse the ``claim`` payload from a submit-claim action.

        Best-effort by design: a missing or malformed payload degrades to a
        default-constructed ``DiscoveryClaim`` (logged at WARNING) rather
        than failing the step, since the payload is agent-supplied input.
        """
        raw = action.parameters.get("claim") or {}
        try:
            return DiscoveryClaim(**raw)
        except Exception as exc:  # broad on purpose: untrusted agent input
            logger.warning("Malformed claim, defaulting: %s", exc)
            return DiscoveryClaim()

    def _build_terminal_observation(self, reason: str) -> CollisionObservation:
        """Observation returned when stepping an already-finished episode.

        Carries a zero reward, a FAILURE_REPORT output explaining *reason*,
        and an ``episode_terminated`` rule-violation marker.
        """
        obs = self._build_observation(
            latest_output=IntermediateOutput(
                output_type=OutputType.FAILURE_REPORT,
                step_index=self._state.step_count,
                success=False,
                summary=reason,
            ),
            done=True,
            reward=0.0,
            step_breakdown={},
            rule_violations=["episode_terminated"],
        )
        return obs

    def _build_observation(
        self,
        *,
        latest_output: Optional[IntermediateOutput],
        done: bool,
        reward: float,
        step_breakdown: dict,
        rule_violations: list,
    ) -> CollisionObservation:
        """Assemble the agent-visible observation from the latent state.

        Only partial information is exposed: resource usage, pipeline
        history, candidate masses/significances, and detector uncertainty
        summaries — never the hidden particle truth. Defensive copies
        (``list(...)``/``dict(...)``) prevent callers from mutating the
        environment's internal history/output lists.
        """
        assert self._latent is not None
        assert self._task is not None

        res = self._latent.resources
        usage = ResourceUsage(
            budget_used_musd=res.budget_used_musd,
            budget_remaining_musd=res.budget_remaining,
            luminosity_used_fb=res.luminosity_used_fb,
            luminosity_remaining_fb=res.luminosity_remaining,
            time_used_days=res.time_used_days,
            time_remaining_days=res.time_remaining,
            compute_hours_used=res.compute_hours_used,
        )

        obs = CollisionObservation(
            done=done,
            reward=float(reward),
            task=self._task,
            step_index=self._state.step_count,
            pipeline_history=list(self._history),
            available_channels=self._task.available_channels,
            available_triggers=self._task.available_triggers,
            available_tools=self._task.available_tools,
            resource_usage=usage,
            latest_output=latest_output,
            all_outputs=list(self._all_outputs),
            candidate_masses_gev=list(self._latent.candidate_masses_gev),
            candidate_significances=list(self._latent.candidate_significances),
            selected_channel=self._latent.selected_channel,
            selected_beam_energy=self._latent.selected_beam_energy,
            cumulative_significance=float(
                self._latent.progress.best_significance_sigma or 0.0
            ),
            uncertainty_summary={
                "energy_scale_unc_gev": self._latent.detector.energy_scale_uncertainty,
                "luminosity_unc": self._latent.detector.luminosity_uncertainty,
                "resolution_gev": self._latent.detector.detector_resolution_gev,
            },
            rule_violations=rule_violations,
            step_reward_breakdown=dict(step_breakdown),
        )
        return obs

    def hidden_truth(self) -> Optional[dict]:
        """Reveal the hidden particle (debug / evaluation only)."""
        if self._latent is None:
            return None
        return self._latent.particle.model_dump()
|
|
|
|
|
# Public API of this module. The rules text and prompt builder are
# re-exported from ``models`` for server-layer convenience.
__all__ = [
    "CernState",
    "CERNCollisionEnvironment",
    "AGENT_ENVIRONMENT_RULES",
    "build_agent_system_prompt",
]
|
|
|