cernenv / server /simulator /transition.py
anugrah55's picture
Update CERNenv Space
2b0bffa verified
"""Pure-function transition engine.
Given a (latent_state, action, generated_output) triple, produces the next
latent state plus the deltas needed for the agent-visible observation. The
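``TransitionEngine`` does **not** generate randomness directly; it consumes
artifacts from the ``OutputGenerator``. Despite the "pure-function" label, the
engine updates the latent state it is given in place and returns that same
object.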
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict
from models import (
ActionType,
ExperimentAction,
IntermediateOutput,
OutputType,
)
from .latent_state import FullLatentState
# Fixed baseline charge for every action type. Each entry maps to a dict with
# the keys "musd" (millions of USD), "days" and "compute" (compute hours).
# Table columns below: action type, musd, days, compute.
ACTION_COSTS: Dict[ActionType, Dict[str, float]] = {
    kind: {"musd": musd, "days": days, "compute": compute}
    for kind, musd, days, compute in (
        (ActionType.CONFIGURE_BEAM, 0.10, 0.5, 0.1),
        (ActionType.ALLOCATE_LUMINOSITY, 0.05, 0.2, 0.0),
        (ActionType.SET_TRIGGER, 0.05, 0.1, 0.0),
        # main cost is in luminosity
        (ActionType.COLLECT_COLLISIONS, 0.00, 0.0, 1.0),
        (ActionType.CALIBRATE_DETECTOR, 0.20, 1.0, 1.5),
        (ActionType.RECONSTRUCT_TRACKS, 0.15, 0.8, 5.0),
        (ActionType.SELECT_CHANNEL, 0.00, 0.05, 0.0),
        (ActionType.BUILD_INVARIANT_MASS, 0.05, 0.3, 1.0),
        (ActionType.SUBTRACT_BACKGROUND, 0.05, 0.3, 0.5),
        (ActionType.FIT_RESONANCE, 0.10, 0.4, 0.5),
        (ActionType.SCAN_BUMP, 0.05, 0.2, 0.5),
        (ActionType.MEASURE_ANGULAR, 0.10, 0.4, 0.5),
        (ActionType.ESTIMATE_SIGNIFICANCE, 0.05, 0.1, 0.2),
        (ActionType.REQUEST_SYSTEMATICS, 0.30, 1.5, 1.0),
        (ActionType.REQUEST_THEORY_REVIEW, 0.05, 0.5, 0.0),
        (ActionType.SUBMIT_DISCOVERY_CLAIM, 0.0, 0.1, 0.0),
    )
}
def compute_action_cost(action: ExperimentAction, output: IntermediateOutput) -> Dict[str, float]:
    """Work out what a single action actually cost.

    The baseline comes from ``ACTION_COSTS`` (all zeros for an action type
    that has no entry). For ``COLLECT_COLLISIONS`` the run-dependent figures
    reported in ``output.data`` are layered on top: ``luminosity_fb`` becomes
    the luminosity charge, while ``cost_musd`` and ``time_days`` are added to
    the baseline money and time. Missing keys count as zero.

    Args:
        action: The action that was executed; only ``action_type`` is read.
        output: The generated output; only ``data`` is read, and a falsy
            ``data`` is treated as an empty mapping.

    Returns:
        A dict with the float entries ``"musd"``, ``"days"``,
        ``"compute_hours"`` and ``"luminosity_fb"``.
    """
    baseline = ACTION_COSTS.get(action.action_type, {})
    realised = {
        "musd": float(baseline.get("musd", 0.0)),
        "days": float(baseline.get("days", 0.0)),
        "compute_hours": float(baseline.get("compute", 0.0)),
        "luminosity_fb": 0.0,
    }

    payload = output.data or {}
    if action.action_type == ActionType.COLLECT_COLLISIONS:
        # Data taking is the only action whose cost depends on the run itself.
        realised["luminosity_fb"] = float(payload.get("luminosity_fb", 0.0))
        realised["musd"] += float(payload.get("cost_musd", 0.0))
        realised["days"] += float(payload.get("time_days", 0.0))

    return realised
@dataclass
class TransitionResult:
    """What ``TransitionEngine.step`` hands back after applying one action."""

    # Latent state after the action; ``step`` returns the very object it was
    # given (mutated in place), not a copy.
    next_state: FullLatentState
    # Cost charged for the action as produced by ``compute_action_cost``:
    # keys "musd", "days", "compute_hours" and "luminosity_fb".
    realised_cost: Dict[str, float]
class TransitionEngine:
    """Applies an action's output to evolve the latent state.

    The engine holds no state of its own and draws no random numbers; every
    value it writes comes from the ``IntermediateOutput`` passed to ``step``.
    """

    def step(
        self,
        state: FullLatentState,
        action: ExperimentAction,
        output: IntermediateOutput,
    ) -> TransitionResult:
        """Charge the action's cost and apply its effects to ``state``.

        ``state`` is mutated in place and the same object is returned inside
        the result; this relies on the environment owning the only reference.
        Resources are charged and ``step_count`` is advanced whether or not
        the output succeeded. Progress flags, selections and detector updates
        are applied only when ``output.success`` is truthy.

        Args:
            state: Latent state to evolve (modified in place).
            action: The action that was executed.
            output: The generated output for that action.

        Returns:
            A ``TransitionResult`` carrying ``state`` and the realised cost.
        """
        cost = compute_action_cost(action, output)
        self._charge(state, cost)

        if output.success:
            self._apply_effects(state, action.action_type, output.data or {})

        state.step_count += 1
        return TransitionResult(next_state=state, realised_cost=cost)

    @staticmethod
    def _charge(state: FullLatentState, cost: Dict[str, float]) -> None:
        """Add one action's realised cost to the running resource totals."""
        used = state.resources
        used.budget_used_musd += cost["musd"]
        used.time_used_days += cost["days"]
        used.compute_hours_used += cost["compute_hours"]
        used.luminosity_used_fb += cost["luminosity_fb"]

    @staticmethod
    def _apply_effects(state: FullLatentState, kind: ActionType, data: dict) -> None:
        """Apply the state changes of a successful action of type ``kind``."""
        progress = state.progress

        if kind == ActionType.CONFIGURE_BEAM:
            state.selected_beam_energy = data.get("beam_energy")
            progress.beam_configured = True
        elif kind == ActionType.ALLOCATE_LUMINOSITY:
            progress.luminosity_allocated = True
        elif kind == ActionType.SET_TRIGGER:
            state.selected_trigger = data.get("trigger")
            progress.trigger_set = True
        elif kind == ActionType.COLLECT_COLLISIONS:
            TransitionEngine._record_collisions(state, data)
        elif kind == ActionType.CALIBRATE_DETECTOR:
            TransitionEngine._apply_calibration(state, data)
        elif kind == ActionType.RECONSTRUCT_TRACKS:
            progress.tracks_reconstructed = True
            state.detector.tracker_aligned = True
        elif kind == ActionType.SELECT_CHANNEL:
            channel = data.get("channel")
            if channel:
                # Keep the previous selection when no channel is reported.
                state.selected_channel = channel
            # NOTE(review): the flag is set even without a reported channel,
            # mirroring the beam/trigger branches -- confirm this is intended.
            progress.channel_selected = True
        elif kind == ActionType.BUILD_INVARIANT_MASS:
            progress.invariant_mass_built = True
        elif kind == ActionType.SUBTRACT_BACKGROUND:
            progress.background_subtracted = True
        elif kind == ActionType.FIT_RESONANCE:
            TransitionEngine._record_fit(state, data)
        elif kind == ActionType.SCAN_BUMP:
            progress.bump_scanned = True
            candidate_mass = float(data.get("candidate_mass_gev", 0.0))
            if candidate_mass > 0:
                # The two candidate lists are appended together so they stay
                # index-aligned; the significance is filled in later.
                state.candidate_masses_gev.append(candidate_mass)
                state.candidate_significances.append(0.0)
        elif kind == ActionType.MEASURE_ANGULAR:
            progress.angular_measured = True
        elif kind == ActionType.ESTIMATE_SIGNIFICANCE:
            TransitionEngine._record_significance(state, data)
        elif kind == ActionType.REQUEST_SYSTEMATICS:
            progress.systematics_requested = True
            # Each request shrinks both uncertainties multiplicatively.
            state.detector.energy_scale_uncertainty *= 0.6
            state.detector.luminosity_uncertainty *= 0.7
        elif kind == ActionType.REQUEST_THEORY_REVIEW:
            progress.theory_review_requested = True
        elif kind == ActionType.SUBMIT_DISCOVERY_CLAIM:
            progress.claim_submitted = True

    @staticmethod
    def _record_collisions(state: FullLatentState, data: dict) -> None:
        """Accumulate event counts and remember the channel/energy used."""
        progress = state.progress
        progress.collisions_collected = True

        n_signal = int(data.get("n_signal_candidates", 0))
        n_background = int(data.get("n_background_estimate", 0))
        progress.n_events_collected += n_signal + n_background
        progress.n_signal_candidates += n_signal
        progress.n_background_estimate += n_background

        # A missing (or falsy) value leaves the previous "best" untouched.
        progress.best_channel = data.get("channel") or progress.best_channel
        progress.best_beam_energy = (
            data.get("beam_energy") or progress.best_beam_energy
        )

    @staticmethod
    def _apply_calibration(state: FullLatentState, data: dict) -> None:
        """Mark the detector calibrated and tighten its resolution."""
        state.progress.detector_calibrated = True
        state.detector.detector_calibrated = True

        improvement = float(data.get("resolution_improvement", 0.0))
        # The resolution is scaled by (1 - improvement), floored at 0.05.
        state.detector.detector_resolution_gev = max(
            0.05,
            state.detector.detector_resolution_gev * (1.0 - improvement),
        )

    @staticmethod
    def _record_fit(state: FullLatentState, data: dict) -> None:
        """Record a resonance fit; only a positive mass counts as a candidate."""
        progress = state.progress
        progress.resonance_fitted = True

        mass = float(data.get("fit_mass_gev", 0.0))
        width = float(data.get("fit_width_gev", 0.0))
        # NOTE(review): "fit_mass_unc_gev" is not consumed -- nothing written
        # by this module has a field for it.
        if mass > 0:
            state.candidate_masses_gev.append(mass)
            state.candidate_significances.append(0.0)
            # NOTE(review): best-fit values are only overwritten by a positive
            # mass, so a degenerate fit cannot clobber an earlier result --
            # confirm this matches the intended behaviour.
            progress.best_fit_mass_gev = mass
            progress.best_fit_width_gev = width

    @staticmethod
    def _record_significance(state: FullLatentState, data: dict) -> None:
        """Track the best significance and attach it to the newest candidate."""
        progress = state.progress
        progress.significance_estimated = True

        sigma = float(data.get("significance_sigma", 0.0))
        # ``best_significance_sigma`` may be falsy (e.g. unset); treat as 0.
        progress.best_significance_sigma = max(
            progress.best_significance_sigma or 0.0, sigma
        )
        if state.candidate_significances:
            state.candidate_significances[-1] = sigma
# Public names of this module: the cost table, the cost helper, and the
# engine together with its result type.
__all__ = [
    "ACTION_COSTS",
    "TransitionEngine",
    "TransitionResult",
    "compute_action_cost",
]