"""Pure-function transition engine. Given a (latent_state, action, generated_output) triple, produces the next latent state plus the deltas needed for the agent-visible observation. The ``TransitionEngine`` does **not** generate randomness directly; it consumes artifacts from the ``OutputGenerator``. """ from __future__ import annotations from dataclasses import dataclass from typing import Dict from models import ( ActionType, ExperimentAction, IntermediateOutput, OutputType, ) from .latent_state import FullLatentState # Per-action default cost in (millions of USD, days, compute hours) ACTION_COSTS: Dict[ActionType, Dict[str, float]] = { ActionType.CONFIGURE_BEAM: {"musd": 0.10, "days": 0.5, "compute": 0.1}, ActionType.ALLOCATE_LUMINOSITY: {"musd": 0.05, "days": 0.2, "compute": 0.0}, ActionType.SET_TRIGGER: {"musd": 0.05, "days": 0.1, "compute": 0.0}, ActionType.COLLECT_COLLISIONS: {"musd": 0.00, "days": 0.0, "compute": 1.0}, # main cost is in luminosity ActionType.CALIBRATE_DETECTOR: {"musd": 0.20, "days": 1.0, "compute": 1.5}, ActionType.RECONSTRUCT_TRACKS: {"musd": 0.15, "days": 0.8, "compute": 5.0}, ActionType.SELECT_CHANNEL: {"musd": 0.00, "days": 0.05, "compute": 0.0}, ActionType.BUILD_INVARIANT_MASS: {"musd": 0.05, "days": 0.3, "compute": 1.0}, ActionType.SUBTRACT_BACKGROUND: {"musd": 0.05, "days": 0.3, "compute": 0.5}, ActionType.FIT_RESONANCE: {"musd": 0.10, "days": 0.4, "compute": 0.5}, ActionType.SCAN_BUMP: {"musd": 0.05, "days": 0.2, "compute": 0.5}, ActionType.MEASURE_ANGULAR: {"musd": 0.10, "days": 0.4, "compute": 0.5}, ActionType.ESTIMATE_SIGNIFICANCE: {"musd": 0.05, "days": 0.1, "compute": 0.2}, ActionType.REQUEST_SYSTEMATICS: {"musd": 0.30, "days": 1.5, "compute": 1.0}, ActionType.REQUEST_THEORY_REVIEW: {"musd": 0.05, "days": 0.5, "compute": 0.0}, ActionType.SUBMIT_DISCOVERY_CLAIM:{"musd": 0.0, "days": 0.1, "compute": 0.0}, } def compute_action_cost(action: ExperimentAction, output: IntermediateOutput) -> Dict[str, float]: """Return realised (musd, days, compute_hours, luminosity_fb) for this action.""" base = ACTION_COSTS.get(action.action_type, {"musd": 0.0, "days": 0.0, "compute": 0.0}) musd = float(base.get("musd", 0.0)) days = float(base.get("days", 0.0)) compute = float(base.get("compute", 0.0)) lumi_fb = 0.0 data = output.data or {} if action.action_type == ActionType.COLLECT_COLLISIONS: lumi_fb = float(data.get("luminosity_fb", 0.0)) musd += float(data.get("cost_musd", 0.0)) days += float(data.get("time_days", 0.0)) return { "musd": musd, "days": days, "compute_hours": compute, "luminosity_fb": lumi_fb, } @dataclass class TransitionResult: next_state: FullLatentState realised_cost: Dict[str, float] class TransitionEngine: """Applies an action's output to evolve the latent state.""" def step( self, state: FullLatentState, action: ExperimentAction, output: IntermediateOutput, ) -> TransitionResult: # We mutate the live state in place, then return it. This is fine # because the environment owns the only reference. cost = compute_action_cost(action, output) state.resources.budget_used_musd += cost["musd"] state.resources.time_used_days += cost["days"] state.resources.compute_hours_used += cost["compute_hours"] state.resources.luminosity_used_fb += cost["luminosity_fb"] if not output.success: state.step_count += 1 return TransitionResult(next_state=state, realised_cost=cost) a = action.action_type data = output.data or {} if a == ActionType.CONFIGURE_BEAM: beam = data.get("beam_energy") # latent_state.selected_beam_energy is typed Optional[str] and # CollisionObservation re-validates it as a str; LLM completions # sometimes emit numeric beam_energy (e.g. 13.0), which would # later fail Pydantic string validation in _build_observation. # Coerce to str at the source so all downstream consumers # (latent state, observation, output_generator) see a string. state.selected_beam_energy = str(beam) if beam is not None else None state.progress.beam_configured = True elif a == ActionType.ALLOCATE_LUMINOSITY: state.progress.luminosity_allocated = True elif a == ActionType.SET_TRIGGER: trig = data.get("trigger") state.selected_trigger = trig state.progress.trigger_set = True elif a == ActionType.COLLECT_COLLISIONS: state.progress.collisions_collected = True state.progress.n_events_collected += int( data.get("n_signal_candidates", 0) ) + int(data.get("n_background_estimate", 0)) state.progress.n_signal_candidates += int(data.get("n_signal_candidates", 0)) state.progress.n_background_estimate += int(data.get("n_background_estimate", 0)) state.progress.best_channel = data.get("channel") or state.progress.best_channel _be = data.get("beam_energy") state.progress.best_beam_energy = ( (str(_be) if _be is not None else None) or state.progress.best_beam_energy ) elif a == ActionType.CALIBRATE_DETECTOR: state.progress.detector_calibrated = True state.detector.detector_calibrated = True improvement = float(data.get("resolution_improvement", 0.0)) state.detector.detector_resolution_gev = max( 0.05, state.detector.detector_resolution_gev * (1.0 - improvement), ) elif a == ActionType.RECONSTRUCT_TRACKS: state.progress.tracks_reconstructed = True state.detector.tracker_aligned = True elif a == ActionType.SELECT_CHANNEL: channel = data.get("channel") if channel: state.selected_channel = channel state.progress.channel_selected = True elif a == ActionType.BUILD_INVARIANT_MASS: state.progress.invariant_mass_built = True elif a == ActionType.SUBTRACT_BACKGROUND: state.progress.background_subtracted = True elif a == ActionType.FIT_RESONANCE: state.progress.resonance_fitted = True m = float(data.get("fit_mass_gev", 0.0)) unc = float(data.get("fit_mass_unc_gev", 0.0)) w = float(data.get("fit_width_gev", 0.0)) if m > 0: state.candidate_masses_gev.append(m) state.candidate_significances.append(0.0) state.progress.best_fit_mass_gev = m state.progress.best_fit_width_gev = w elif a == ActionType.SCAN_BUMP: state.progress.bump_scanned = True cm = float(data.get("candidate_mass_gev", 0.0)) if cm > 0: state.candidate_masses_gev.append(cm) state.candidate_significances.append(0.0) elif a == ActionType.MEASURE_ANGULAR: state.progress.angular_measured = True elif a == ActionType.ESTIMATE_SIGNIFICANCE: state.progress.significance_estimated = True sig = float(data.get("significance_sigma", 0.0)) state.progress.best_significance_sigma = max( state.progress.best_significance_sigma or 0.0, sig ) if state.candidate_significances: state.candidate_significances[-1] = sig elif a == ActionType.REQUEST_SYSTEMATICS: state.progress.systematics_requested = True state.detector.energy_scale_uncertainty *= 0.6 state.detector.luminosity_uncertainty *= 0.7 elif a == ActionType.REQUEST_THEORY_REVIEW: state.progress.theory_review_requested = True elif a == ActionType.SUBMIT_DISCOVERY_CLAIM: state.progress.claim_submitted = True state.step_count += 1 return TransitionResult(next_state=state, realised_cost=cost) __all__ = [ "ACTION_COSTS", "TransitionEngine", "TransitionResult", "compute_action_cost", ]