# cernenv-trainer — server/environment.py
# sft+reward-fix (verified commit 70b06db, by anugrahhu)
"""``CERNCollisionEnvironment``: orchestrates simulator + rules + rewards.
This is the OpenEnv-compatible ``Environment`` that the FastAPI app exposes.
It owns one episode at a time:
reset(seed) → builds a fresh latent state from a sampled scenario.
step(action) → validates → generates noisy output → updates state →
computes reward → builds the agent observation.
"""
from __future__ import annotations
import logging
import uuid
from typing import Any, List, Optional
from openenv.core.env_server import Environment
from models import (
AGENT_ENVIRONMENT_RULES,
ActionType,
CernState,
CollisionObservation,
DiscoveryClaim,
ExperimentAction,
IntermediateOutput,
OutputType,
PipelineStepRecord,
ResourceUsage,
TaskSpec,
build_agent_system_prompt,
)
from server.rewards import (
RewardWeights,
compute_step_reward,
compute_terminal_reward,
)
from server.rules import RulesEngine, ViolationCode
from server.simulator import (
NoiseModel,
OutputGenerator,
TransitionEngine,
compute_action_cost,
)
from server.simulator.latent_state import FullLatentState
from server.tasks import sample_scenario, Scenario
logger = logging.getLogger(__name__)
# ── Environment ──────────────────────────────────────────────────────────
class CERNCollisionEnvironment(Environment[ExperimentAction, CollisionObservation, CernState]):
    """LHC particle-discovery POMDP environment.

    Owns one episode at a time:

    * ``reset(seed)`` samples a scenario and builds a fresh hidden latent state.
    * ``step(action)`` validates the action against the rules engine, generates
      a noisy intermediate output, applies the transition + cost accounting,
      computes rewards, and returns the next partially-observable
      ``CollisionObservation``.
    """

    # NOTE(review): each instance holds exactly one episode's mutable state;
    # this flag presumably means the hosting server allocates one instance per
    # session — confirm against the OpenEnv server wiring.
    SUPPORTS_CONCURRENT_SESSIONS = True

    def __init__(
        self,
        *,
        max_steps: int = 40,
        default_difficulty: Optional[str] = None,
        default_scenario_name: Optional[str] = None,
        reward_weights: Optional[RewardWeights] = None,
    ) -> None:
        """Create an environment shell; no episode exists until ``reset``.

        Args:
            max_steps: Hard cap on steps per episode before forced termination.
            default_difficulty: Fallback difficulty when ``reset`` receives none.
            default_scenario_name: Fallback scenario when ``reset`` receives none.
            reward_weights: Reward-shaping weights; defaults to ``RewardWeights()``.
        """
        super().__init__()
        self.max_steps = max_steps
        self.default_difficulty = default_difficulty
        self.default_scenario_name = default_scenario_name
        self.reward_weights = reward_weights or RewardWeights()
        # Agent-visible episode bookkeeping, exposed through the ``state`` property.
        self._state = CernState()
        # Per-episode machinery; all (re)built by ``reset``.
        self._scenario: Optional[Scenario] = None
        self._latent: Optional[FullLatentState] = None
        self._task: Optional[TaskSpec] = None
        self._noise: Optional[NoiseModel] = None
        self._output_gen: Optional[OutputGenerator] = None
        self._transition: Optional[TransitionEngine] = None
        self._rules: Optional[RulesEngine] = None
        self._history: List[PipelineStepRecord] = []
        self._all_outputs: List[IntermediateOutput] = []

    # ── Environment API ────────────────────────────────────────────────
    @property
    def state(self) -> CernState:
        """Current episode bookkeeping (step count, rewards, outcome flags)."""
        return self._state

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> CollisionObservation:
        """Start a fresh episode and return the initial observation.

        Keyword args ``difficulty`` / ``scenario`` override the instance
        defaults for this episode only.
        """
        difficulty = kwargs.get("difficulty") or self.default_difficulty
        scenario_name = kwargs.get("scenario") or self.default_scenario_name
        scenario = sample_scenario(
            difficulty=difficulty,
            name=scenario_name,
            seed=seed,
        )
        self._scenario = scenario
        self._latent = scenario.fresh_latent()
        self._task = scenario.task
        # An explicit seed overrides whatever the scenario sampled, so a
        # rollout is reproducible end-to-end (noise model included).
        if seed is not None:
            self._latent.rng_seed = int(seed)
        self._noise = NoiseModel(seed=self._latent.rng_seed)
        self._output_gen = OutputGenerator(self._noise)
        self._transition = TransitionEngine()
        self._rules = RulesEngine(
            mass_search_window_gev=tuple(self._task.mass_search_window_gev),
        )
        self._history = []
        self._all_outputs = []
        self._state = CernState(
            episode_id=episode_id or f"ep-{uuid.uuid4().hex[:8]}",
            step_count=0,
            scenario_name=scenario.name,
            difficulty=scenario.difficulty,
            episode_done=False,
            cumulative_reward=0.0,
            truth_mass_gev=self._latent.particle.mass_gev,
            truth_channel=self._latent.particle.primary_channel,
        )
        return self._build_observation(
            latest_output=None,
            done=False,
            reward=0.0,
            step_breakdown={},
            rule_violations=[],
        )

    def step(
        self,
        action: ExperimentAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> CollisionObservation:
        """Apply one action and return the next observation.

        ``timeout_s`` is accepted for OpenEnv API compatibility but is a
        no-op for this environment: each ``step`` is pure-compute (numpy
        ops on a small latent state, sub-millisecond) and cannot hang.
        The episode-level "sandbox" enforced here is *resource* exhaustion
        — budget (M$), integrated luminosity (fb⁻¹), and wall-time-days —
        which is checked at the bottom of this method and terminates the
        episode via ``done=True`` when any limit is crossed. That is the
        meaningful timeout for an LHC-discovery rollout.

        If ``timeout_s`` is non-None we log it once at debug level so
        callers can confirm their value is being received without changing
        any runtime behaviour.
        """
        if timeout_s is not None:
            logger.debug(
                "step() received timeout_s=%.3fs (informational; "
                "actual cutoff is resource-exhaustion based)",
                float(timeout_s),
            )
        if self._latent is None:
            # Defensive: allow step() before an explicit reset().
            self.reset()
        if self._state.episode_done:
            return self._build_terminal_observation(reason="episode already complete")
        assert self._rules is not None
        assert self._output_gen is not None
        assert self._transition is not None

        # Snapshot the latent state so reward shaping can compare before/after.
        prev_state = self._latent.model_copy(deep=True)
        rule_result = self._rules.validate(action, self._latent)
        if rule_result.allowed:
            output = self._output_gen.generate(
                action=action,
                state=self._latent,
                step_index=self._state.step_count,
            )
        else:
            output = IntermediateOutput(
                output_type=OutputType.FAILURE_REPORT,
                step_index=self._state.step_count,
                success=False,
                quality_score=0.0,
                summary="Action rejected: " + "; ".join(rule_result.messages),
                warnings=rule_result.messages,
            )

        # Cost depends only on (action, output), so compute it exactly once.
        # (Previously it was computed in the rejected branch and then a
        # second time for the history record below.)
        cost = compute_action_cost(action, output)

        # Apply transition (state mutation + cost accounting).
        if rule_result.allowed:
            self._transition.step(self._latent, action, output)
        else:
            # Rejected actions still burn budget and wall time (no beam
            # luminosity is charged) and still advance the latent clock.
            self._latent.resources.budget_used_musd += cost["musd"]
            self._latent.resources.time_used_days += cost["days"]
            self._latent.step_count += 1

        self._all_outputs.append(output)
        record = PipelineStepRecord(
            step_index=self._state.step_count,
            action_type=action.action_type,
            method=action.method,
            parameters=action.parameters,
            output_summary=output.summary,
            output_type=output.output_type,
            success=output.success,
            quality_score=float(output.quality_score),
            cost_musd=float(cost["musd"]),
            luminosity_cost_fb=float(cost["luminosity_fb"]),
            time_cost_days=float(cost["days"]),
        )
        self._history.append(record)

        step_reward = compute_step_reward(
            action=action,
            output=output,
            state_before=prev_state,
            state_after=self._latent,
            rule_result=rule_result,
            weights=self.reward_weights,
            history=self._history[:-1],  # exclude the record we just appended
        )
        self._state.cumulative_reward += step_reward.reward
        self._state.step_count += 1

        # A rules-approved discovery claim always ends the episode.
        terminal_now = (
            action.action_type == ActionType.SUBMIT_DISCOVERY_CLAIM
            and rule_result.allowed
        )
        # Step-cap or resource exhaustion also ends the episode.
        time_up = (
            self._state.step_count >= self.max_steps
            or self._latent.resources.budget_exhausted
            or self._latent.resources.time_exhausted
        )
        terminal_reward_value = 0.0
        if terminal_now:
            claim = self._claim_from_action(action)
            term = compute_terminal_reward(
                state=self._latent,
                claim=claim,
                weights=self.reward_weights,
            )
            terminal_reward_value = self._record_terminal_outcome(term)
        elif time_up:
            # Fix #1: if the episode runs out of steps/budget/time and the
            # agent never even *attempted* a SUBMIT_DISCOVERY_CLAIM, levy a
            # flat no-claim penalty so claim-avoidance can no longer
            # dominate the per-step shaping reward (the v1 reward hack).
            ever_claimed = any(
                rec.action_type == ActionType.SUBMIT_DISCOVERY_CLAIM
                for rec in self._history
            )
            if not ever_claimed:
                term = compute_terminal_reward(
                    state=self._latent,
                    claim=None,
                    weights=self.reward_weights,
                )
                terminal_reward_value = self._record_terminal_outcome(term)

        done = terminal_now or time_up
        if done:
            self._state.episode_done = True
        return self._build_observation(
            latest_output=output,
            done=done,
            reward=step_reward.reward + terminal_reward_value,
            step_breakdown=step_reward.breakdown.components,
            rule_violations=[
                *(v.value for v in rule_result.violations),
                *(v.value for v in rule_result.soft_violations),
            ],
        )

    # ── Helpers ────────────────────────────────────────────────────────
    def _record_terminal_outcome(self, term: Any) -> float:
        """Fold a terminal-reward result into episode state; return its reward.

        Extracted because the claim-submitted and time-up branches of
        ``step`` previously duplicated this bookkeeping verbatim.
        """
        self._state.cumulative_reward += term.reward
        self._state.terminal_reward = term.reward
        self._state.discovered = term.discovered
        self._state.correct_mass = term.correct_mass
        self._state.correct_channel = term.correct_channel
        self._state.correct_spin = term.correct_spin
        return term.reward

    def _claim_from_action(self, action: ExperimentAction) -> DiscoveryClaim:
        """Parse the agent's claim payload; fall back to an empty claim."""
        raw = action.parameters.get("claim") or {}
        try:
            return DiscoveryClaim(**raw)
        except Exception as exc:  # pragma: no cover - defensive
            logger.warning("Malformed claim, defaulting: %s", exc)
            return DiscoveryClaim()

    def _build_terminal_observation(self, reason: str) -> CollisionObservation:
        """Observation returned when ``step`` is called on a finished episode."""
        return self._build_observation(
            latest_output=IntermediateOutput(
                output_type=OutputType.FAILURE_REPORT,
                step_index=self._state.step_count,
                success=False,
                summary=reason,
            ),
            done=True,
            reward=0.0,
            step_breakdown={},
            rule_violations=["episode_terminated"],
        )

    def _build_observation(
        self,
        *,
        latest_output: Optional[IntermediateOutput],
        done: bool,
        reward: float,
        step_breakdown: dict,
        rule_violations: list,
    ) -> CollisionObservation:
        """Assemble the agent-facing (partial) observation from current state."""
        assert self._latent is not None
        assert self._task is not None
        res = self._latent.resources
        usage = ResourceUsage(
            budget_used_musd=res.budget_used_musd,
            budget_remaining_musd=res.budget_remaining,
            luminosity_used_fb=res.luminosity_used_fb,
            luminosity_remaining_fb=res.luminosity_remaining,
            time_used_days=res.time_used_days,
            time_remaining_days=res.time_remaining,
            compute_hours_used=res.compute_hours_used,
        )
        return CollisionObservation(
            done=done,
            reward=float(reward),
            task=self._task,
            step_index=self._state.step_count,
            pipeline_history=list(self._history),
            available_channels=self._task.available_channels,
            available_triggers=self._task.available_triggers,
            available_tools=self._task.available_tools,
            resource_usage=usage,
            latest_output=latest_output,
            all_outputs=list(self._all_outputs),
            candidate_masses_gev=list(self._latent.candidate_masses_gev),
            candidate_significances=list(self._latent.candidate_significances),
            selected_channel=self._latent.selected_channel,
            selected_beam_energy=self._latent.selected_beam_energy,
            cumulative_significance=float(
                self._latent.progress.best_significance_sigma or 0.0
            ),
            uncertainty_summary={
                "energy_scale_unc_gev": self._latent.detector.energy_scale_uncertainty,
                "luminosity_unc": self._latent.detector.luminosity_uncertainty,
                "resolution_gev": self._latent.detector.detector_resolution_gev,
            },
            rule_violations=rule_violations,
            step_reward_breakdown=dict(step_breakdown),
        )

    # ── Convenience for diagnostics ────────────────────────────────────
    def hidden_truth(self) -> Optional[dict]:
        """Reveal the hidden particle (debug / evaluation only)."""
        if self._latent is None:
            return None
        return self._latent.particle.model_dump()
# Public API of this module: the environment class plus the state model and
# prompt-building helpers re-exported for server/client convenience.
__all__ = [
    "CernState",
    "CERNCollisionEnvironment",
    "AGENT_ENVIRONMENT_RULES",
    "build_agent_system_prompt",
]