| """Pure-function transition engine.
|
|
|
| Given a (latent_state, action, generated_output) triple, produces the next
|
| latent state plus the deltas needed for the agent-visible observation. The
|
| ``TransitionEngine`` does **not** generate randomness directly; it consumes
|
| artifacts from the ``OutputGenerator``.
|
| """
|
|
|
| from __future__ import annotations
|
|
|
| from dataclasses import dataclass
|
| from typing import Dict
|
|
|
| from models import (
|
| ActionType,
|
| ExperimentAction,
|
| IntermediateOutput,
|
| OutputType,
|
| )
|
|
|
| from .latent_state import FullLatentState
|
|
|
|
|
|
|
# Fixed per-action charges, keyed by action type. Every entry carries the keys
# "musd", "days" and "compute"; compute_action_cost() reports the last one as
# "compute_hours". COLLECT_COLLISIONS has no fixed money/time charge here
# because its run-dependent cost is taken from the output payload instead.
ACTION_COSTS: Dict[ActionType, Dict[str, float]] = {
    kind: {"musd": musd, "days": days, "compute": compute}
    for kind, (musd, days, compute) in (
        # (action type,                     (musd, days, compute))
        (ActionType.CONFIGURE_BEAM, (0.10, 0.5, 0.1)),
        (ActionType.ALLOCATE_LUMINOSITY, (0.05, 0.2, 0.0)),
        (ActionType.SET_TRIGGER, (0.05, 0.1, 0.0)),
        (ActionType.COLLECT_COLLISIONS, (0.00, 0.0, 1.0)),
        (ActionType.CALIBRATE_DETECTOR, (0.20, 1.0, 1.5)),
        (ActionType.RECONSTRUCT_TRACKS, (0.15, 0.8, 5.0)),
        (ActionType.SELECT_CHANNEL, (0.00, 0.05, 0.0)),
        (ActionType.BUILD_INVARIANT_MASS, (0.05, 0.3, 1.0)),
        (ActionType.SUBTRACT_BACKGROUND, (0.05, 0.3, 0.5)),
        (ActionType.FIT_RESONANCE, (0.10, 0.4, 0.5)),
        (ActionType.SCAN_BUMP, (0.05, 0.2, 0.5)),
        (ActionType.MEASURE_ANGULAR, (0.10, 0.4, 0.5)),
        (ActionType.ESTIMATE_SIGNIFICANCE, (0.05, 0.1, 0.2)),
        (ActionType.REQUEST_SYSTEMATICS, (0.30, 1.5, 1.0)),
        (ActionType.REQUEST_THEORY_REVIEW, (0.05, 0.5, 0.0)),
        (ActionType.SUBMIT_DISCOVERY_CLAIM, (0.0, 0.1, 0.0)),
    )
}
|
|
|
|
|
def compute_action_cost(action: ExperimentAction, output: IntermediateOutput) -> Dict[str, float]:
    """Return the realised cost of ``action`` as a dict of floats.

    The fixed part is looked up in ``ACTION_COSTS`` (all zeros when the action
    type has no entry). For ``COLLECT_COLLISIONS`` the output payload
    additionally supplies ``luminosity_fb`` and adds ``cost_musd`` /
    ``time_days`` on top of the fixed charge; a missing payload or a missing
    key counts as zero.

    Returns:
        A dict with the keys ``"musd"``, ``"days"``, ``"compute_hours"`` and
        ``"luminosity_fb"``.
    """
    fixed = ACTION_COSTS.get(action.action_type, {"musd": 0.0, "days": 0.0, "compute": 0.0})
    realised = {
        "musd": float(fixed.get("musd", 0.0)),
        "days": float(fixed.get("days", 0.0)),
        # The table calls this "compute"; callers see it as "compute_hours".
        "compute_hours": float(fixed.get("compute", 0.0)),
        "luminosity_fb": 0.0,
    }

    payload = output.data or {}
    if action.action_type == ActionType.COLLECT_COLLISIONS:
        # NOTE(review): the source lost its indentation; all three payload
        # reads are assumed to belong to this branch — confirm.
        realised["luminosity_fb"] = float(payload.get("luminosity_fb", 0.0))
        realised["musd"] += float(payload.get("cost_musd", 0.0))
        realised["days"] += float(payload.get("time_days", 0.0))

    return realised
|
|
|
|
|
@dataclass
class TransitionResult:
    """Value returned by ``TransitionEngine.step``."""

    # The state after the step. ``step`` passes back the very object it was
    # given (updated in place), not a copy.
    next_state: FullLatentState
    # Cost dict produced by ``compute_action_cost`` for the action that was
    # applied (keys: "musd", "days", "compute_hours", "luminosity_fb").
    realised_cost: Dict[str, float]
|
|
|
|
|
class TransitionEngine:
    """Applies an action's output to evolve the latent state.

    The engine keeps no state of its own. ``step`` updates the
    ``FullLatentState`` it receives *in place* and returns that same object
    inside a ``TransitionResult``; a caller that needs the previous state has
    to copy it before calling.
    """

    def step(
        self,
        state: FullLatentState,
        action: ExperimentAction,
        output: IntermediateOutput,
    ) -> TransitionResult:
        """Charge the action's cost and fold its output into ``state``.

        The cost is charged whether or not the action succeeded. The
        action-specific update only runs when ``output.success`` is truthy.
        ``state.step_count`` is incremented in both cases.

        Args:
            state: Latent state to update (mutated in place).
            action: The action that was executed; only ``action_type`` is read
                here.
            output: Result of the action; ``success`` and ``data`` are read.

        Returns:
            A ``TransitionResult`` wrapping ``state`` and the realised cost.
        """
        cost = compute_action_cost(action, output)
        used = state.resources
        used.budget_used_musd += cost["musd"]
        used.time_used_days += cost["days"]
        used.compute_hours_used += cost["compute_hours"]
        used.luminosity_used_fb += cost["luminosity_fb"]

        if output.success:
            self._apply_output(state, action.action_type, output.data or {})

        state.step_count += 1
        return TransitionResult(next_state=state, realised_cost=cost)

    def _apply_output(self, state: FullLatentState, kind, data: dict) -> None:
        """Run the handler of the first action type equal to ``kind``, if any."""
        for candidate, handler in self._handlers():
            # Equality (not hashing) is used on purpose so that matching works
            # exactly like a chain of ``kind == ActionType.X`` tests.
            if kind == candidate:
                handler(state, data)
                return

    def _handlers(self) -> tuple:
        """Return ``(action type, handler)`` pairs in matching order."""
        return (
            (ActionType.CONFIGURE_BEAM, self._on_configure_beam),
            (ActionType.ALLOCATE_LUMINOSITY, self._on_allocate_luminosity),
            (ActionType.SET_TRIGGER, self._on_set_trigger),
            (ActionType.COLLECT_COLLISIONS, self._on_collect_collisions),
            (ActionType.CALIBRATE_DETECTOR, self._on_calibrate_detector),
            (ActionType.RECONSTRUCT_TRACKS, self._on_reconstruct_tracks),
            (ActionType.SELECT_CHANNEL, self._on_select_channel),
            (ActionType.BUILD_INVARIANT_MASS, self._on_build_invariant_mass),
            (ActionType.SUBTRACT_BACKGROUND, self._on_subtract_background),
            (ActionType.FIT_RESONANCE, self._on_fit_resonance),
            (ActionType.SCAN_BUMP, self._on_scan_bump),
            (ActionType.MEASURE_ANGULAR, self._on_measure_angular),
            (ActionType.ESTIMATE_SIGNIFICANCE, self._on_estimate_significance),
            (ActionType.REQUEST_SYSTEMATICS, self._on_request_systematics),
            (ActionType.REQUEST_THEORY_REVIEW, self._on_request_theory_review),
            (ActionType.SUBMIT_DISCOVERY_CLAIM, self._on_submit_discovery_claim),
        )

    @staticmethod
    def _on_configure_beam(state: FullLatentState, data: dict) -> None:
        """Record the chosen beam energy (``None`` if the payload has none)."""
        state.selected_beam_energy = data.get("beam_energy")
        state.progress.beam_configured = True

    @staticmethod
    def _on_allocate_luminosity(state: FullLatentState, data: dict) -> None:
        """Flag luminosity as allocated."""
        state.progress.luminosity_allocated = True

    @staticmethod
    def _on_set_trigger(state: FullLatentState, data: dict) -> None:
        """Record the chosen trigger (``None`` if the payload has none)."""
        state.selected_trigger = data.get("trigger")
        state.progress.trigger_set = True

    @staticmethod
    def _on_collect_collisions(state: FullLatentState, data: dict) -> None:
        """Accumulate event counts and remember the latest channel/energy."""
        progress = state.progress
        progress.collisions_collected = True

        n_signal = int(data.get("n_signal_candidates", 0))
        n_background = int(data.get("n_background_estimate", 0))
        progress.n_events_collected += n_signal + n_background
        progress.n_signal_candidates += n_signal
        progress.n_background_estimate += n_background

        # A falsy payload value leaves the previously stored one untouched.
        progress.best_channel = data.get("channel") or progress.best_channel
        progress.best_beam_energy = data.get("beam_energy") or progress.best_beam_energy

    @staticmethod
    def _on_calibrate_detector(state: FullLatentState, data: dict) -> None:
        """Flag calibration and shrink the resolution, floored at 0.05."""
        state.progress.detector_calibrated = True
        detector = state.detector
        detector.detector_calibrated = True

        gain = float(data.get("resolution_improvement", 0.0))
        rescaled = detector.detector_resolution_gev * (1.0 - gain)
        detector.detector_resolution_gev = max(0.05, rescaled)

    @staticmethod
    def _on_reconstruct_tracks(state: FullLatentState, data: dict) -> None:
        """Flag tracks as reconstructed and the tracker as aligned."""
        state.progress.tracks_reconstructed = True
        state.detector.tracker_aligned = True

    @staticmethod
    def _on_select_channel(state: FullLatentState, data: dict) -> None:
        """Store the selected channel when one is given and flag the step."""
        picked = data.get("channel")
        if picked:
            state.selected_channel = picked
        # NOTE(review): the source lost its indentation; the flag is assumed
        # to be set regardless of ``picked`` (as every other action sets its
        # progress flag unconditionally) — confirm.
        state.progress.channel_selected = True

    @staticmethod
    def _on_build_invariant_mass(state: FullLatentState, data: dict) -> None:
        """Flag the invariant-mass spectrum as built."""
        state.progress.invariant_mass_built = True

    @staticmethod
    def _on_subtract_background(state: FullLatentState, data: dict) -> None:
        """Flag the background as subtracted."""
        state.progress.background_subtracted = True

    @staticmethod
    def _on_fit_resonance(state: FullLatentState, data: dict) -> None:
        """Flag the fit and, for a positive fitted mass, record a candidate."""
        progress = state.progress
        progress.resonance_fitted = True

        mass = float(data.get("fit_mass_gev", 0.0))
        # The mass uncertainty is converted (so a malformed value still
        # raises) but is not stored anywhere by this method.
        float(data.get("fit_mass_unc_gev", 0.0))
        width = float(data.get("fit_width_gev", 0.0))

        # NOTE(review): the source lost its indentation; all four updates are
        # assumed to be guarded by the positive-mass check — confirm.
        if mass > 0:
            state.candidate_masses_gev.append(mass)
            # Placeholder significance for the new candidate.
            state.candidate_significances.append(0.0)
            progress.best_fit_mass_gev = mass
            progress.best_fit_width_gev = width

    @staticmethod
    def _on_scan_bump(state: FullLatentState, data: dict) -> None:
        """Flag the scan and, for a positive candidate mass, record it."""
        state.progress.bump_scanned = True
        mass = float(data.get("candidate_mass_gev", 0.0))
        if mass > 0:
            state.candidate_masses_gev.append(mass)
            # Placeholder significance for the new candidate.
            state.candidate_significances.append(0.0)

    @staticmethod
    def _on_measure_angular(state: FullLatentState, data: dict) -> None:
        """Flag the angular measurement as done."""
        state.progress.angular_measured = True

    @staticmethod
    def _on_estimate_significance(state: FullLatentState, data: dict) -> None:
        """Keep the running best significance and tag the latest candidate."""
        progress = state.progress
        progress.significance_estimated = True

        sigma = float(data.get("significance_sigma", 0.0))
        previous_best = progress.best_significance_sigma or 0.0
        progress.best_significance_sigma = max(previous_best, sigma)

        # Overwrite the significance of the most recently added candidate.
        if state.candidate_significances:
            state.candidate_significances[-1] = sigma

    @staticmethod
    def _on_request_systematics(state: FullLatentState, data: dict) -> None:
        """Flag the request and scale both detector uncertainties down."""
        state.progress.systematics_requested = True
        state.detector.energy_scale_uncertainty *= 0.6
        state.detector.luminosity_uncertainty *= 0.7

    @staticmethod
    def _on_request_theory_review(state: FullLatentState, data: dict) -> None:
        """Flag the theory review as requested."""
        state.progress.theory_review_requested = True

    @staticmethod
    def _on_submit_discovery_claim(state: FullLatentState, data: dict) -> None:
        """Flag the discovery claim as submitted."""
        state.progress.claim_submitted = True
|
|
|
|
|
# Names exported by ``from <this module> import *``.
__all__ = ["ACTION_COSTS", "TransitionEngine", "TransitionResult", "compute_action_cost"]
|
|
|