File size: 8,094 Bytes
"""Transition engine for the latent simulation state.

Given a (latent_state, action, generated_output) triple, produces the next
latent state plus the deltas needed for the agent-visible observation. The
``TransitionEngine`` does **not** generate randomness directly; it consumes
artifacts from the ``OutputGenerator``. Note that ``TransitionEngine.step``
updates the state it is given in place and returns that same object.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict
from models import (
ActionType,
ExperimentAction,
IntermediateOutput,
OutputType,
)
from .latent_state import FullLatentState
# Default cost of every action type. Each row of the table below is
# (millions of USD, days, compute hours); the comprehension expands a row into
# the ``{"musd", "days", "compute"}`` mapping stored under the action type.
ACTION_COSTS: Dict[ActionType, Dict[str, float]] = {
    action_type: {"musd": musd, "days": days, "compute": compute}
    for action_type, (musd, days, compute) in (
        (ActionType.CONFIGURE_BEAM, (0.10, 0.5, 0.1)),
        (ActionType.ALLOCATE_LUMINOSITY, (0.05, 0.2, 0.0)),
        (ActionType.SET_TRIGGER, (0.05, 0.1, 0.0)),
        # main cost is in luminosity
        (ActionType.COLLECT_COLLISIONS, (0.00, 0.0, 1.0)),
        (ActionType.CALIBRATE_DETECTOR, (0.20, 1.0, 1.5)),
        (ActionType.RECONSTRUCT_TRACKS, (0.15, 0.8, 5.0)),
        (ActionType.SELECT_CHANNEL, (0.00, 0.05, 0.0)),
        (ActionType.BUILD_INVARIANT_MASS, (0.05, 0.3, 1.0)),
        (ActionType.SUBTRACT_BACKGROUND, (0.05, 0.3, 0.5)),
        (ActionType.FIT_RESONANCE, (0.10, 0.4, 0.5)),
        (ActionType.SCAN_BUMP, (0.05, 0.2, 0.5)),
        (ActionType.MEASURE_ANGULAR, (0.10, 0.4, 0.5)),
        (ActionType.ESTIMATE_SIGNIFICANCE, (0.05, 0.1, 0.2)),
        (ActionType.REQUEST_SYSTEMATICS, (0.30, 1.5, 1.0)),
        (ActionType.REQUEST_THEORY_REVIEW, (0.05, 0.5, 0.0)),
        (ActionType.SUBMIT_DISCOVERY_CLAIM, (0.0, 0.1, 0.0)),
    )
}
def compute_action_cost(action: ExperimentAction, output: IntermediateOutput) -> Dict[str, float]:
    """Return the realised cost of ``action`` as a dictionary.

    The result has the keys ``musd``, ``days``, ``compute_hours`` and
    ``luminosity_fb``. Base values are looked up in ``ACTION_COSTS`` (all
    zero when the action type has no entry). For ``COLLECT_COLLISIONS`` the
    output's ``luminosity_fb`` is reported and its ``cost_musd`` /
    ``time_days`` are added on top of the base values.
    """
    zero_cost = {"musd": 0.0, "days": 0.0, "compute": 0.0}
    defaults = ACTION_COSTS.get(action.action_type, zero_cost)

    realised = {
        "musd": float(defaults.get("musd", 0.0)),
        "days": float(defaults.get("days", 0.0)),
        "compute_hours": float(defaults.get("compute", 0.0)),
        "luminosity_fb": 0.0,
    }

    # ``output.data`` may be None; treat that as an empty payload.
    payload = output.data or {}
    if action.action_type == ActionType.COLLECT_COLLISIONS:
        realised["luminosity_fb"] = float(payload.get("luminosity_fb", 0.0))
        realised["musd"] += float(payload.get("cost_musd", 0.0))
        realised["days"] += float(payload.get("time_days", 0.0))

    return realised
@dataclass
class TransitionResult:
    """Outcome of one ``TransitionEngine.step`` call."""

    # Latent state after the action has been applied.
    next_state: FullLatentState
    # Cost charged for the action, in the dictionary form produced by
    # ``compute_action_cost``.
    realised_cost: Dict[str, float]
class TransitionEngine:
    """Applies an action's output to evolve the latent state.

    The engine keeps no state of its own. ``step`` charges the realised cost
    of the action to ``state.resources`` and, when the output reports
    success, applies the action-specific effects before incrementing
    ``state.step_count``.
    """

    def step(
        self,
        state: FullLatentState,
        action: ExperimentAction,
        output: IntermediateOutput,
    ) -> TransitionResult:
        """Advance ``state`` by one step and return it with the realised cost.

        Args:
            state: Latent state to evolve; it is modified in place.
            action: The action that was executed.
            output: Generated output for the action. ``output.data`` may be
                empty or ``None``.

        Returns:
            A ``TransitionResult`` whose ``next_state`` is the same object as
            ``state`` and whose ``realised_cost`` is the cost that was charged.
        """
        # We mutate the live state in place, then return it. This is fine
        # because the environment owns the only reference.
        cost = compute_action_cost(action, output)
        state.resources.budget_used_musd += cost["musd"]
        state.resources.time_used_days += cost["days"]
        state.resources.compute_hours_used += cost["compute_hours"]
        state.resources.luminosity_used_fb += cost["luminosity_fb"]

        # A failed action still consumes resources and a step, but leaves the
        # rest of the state untouched.
        if output.success:
            self._apply_effects(state, action.action_type, output.data or {})

        state.step_count += 1
        return TransitionResult(next_state=state, realised_cost=cost)

    def _apply_effects(self, state: FullLatentState, action_type: ActionType, data: dict) -> None:
        """Apply the state changes of a successful action of ``action_type``.

        Action types that match none of the branches leave ``state`` unchanged.
        """
        progress = state.progress
        if action_type == ActionType.CONFIGURE_BEAM:
            state.selected_beam_energy = data.get("beam_energy")
            progress.beam_configured = True
        elif action_type == ActionType.ALLOCATE_LUMINOSITY:
            progress.luminosity_allocated = True
        elif action_type == ActionType.SET_TRIGGER:
            state.selected_trigger = data.get("trigger")
            progress.trigger_set = True
        elif action_type == ActionType.COLLECT_COLLISIONS:
            self._record_collisions(state, data)
        elif action_type == ActionType.CALIBRATE_DETECTOR:
            self._apply_calibration(state, data)
        elif action_type == ActionType.RECONSTRUCT_TRACKS:
            progress.tracks_reconstructed = True
            state.detector.tracker_aligned = True
        elif action_type == ActionType.SELECT_CHANNEL:
            channel = data.get("channel")
            # Keep the previous selection when the output names no channel.
            if channel:
                state.selected_channel = channel
            progress.channel_selected = True
        elif action_type == ActionType.BUILD_INVARIANT_MASS:
            progress.invariant_mass_built = True
        elif action_type == ActionType.SUBTRACT_BACKGROUND:
            progress.background_subtracted = True
        elif action_type == ActionType.FIT_RESONANCE:
            self._record_fit(state, data)
        elif action_type == ActionType.SCAN_BUMP:
            progress.bump_scanned = True
            candidate_mass = float(data.get("candidate_mass_gev", 0.0))
            if candidate_mass > 0:
                state.candidate_masses_gev.append(candidate_mass)
                state.candidate_significances.append(0.0)
        elif action_type == ActionType.MEASURE_ANGULAR:
            progress.angular_measured = True
        elif action_type == ActionType.ESTIMATE_SIGNIFICANCE:
            self._record_significance(state, data)
        elif action_type == ActionType.REQUEST_SYSTEMATICS:
            progress.systematics_requested = True
            state.detector.energy_scale_uncertainty *= 0.6
            state.detector.luminosity_uncertainty *= 0.7
        elif action_type == ActionType.REQUEST_THEORY_REVIEW:
            progress.theory_review_requested = True
        elif action_type == ActionType.SUBMIT_DISCOVERY_CLAIM:
            progress.claim_submitted = True

    @staticmethod
    def _record_collisions(state: FullLatentState, data: dict) -> None:
        """Accumulate event counts and remember the channel / beam energy used."""
        # Parse both counts before touching the state so malformed data cannot
        # leave the counters partially updated.
        n_signal = int(data.get("n_signal_candidates", 0))
        n_background = int(data.get("n_background_estimate", 0))

        progress = state.progress
        progress.collisions_collected = True
        progress.n_events_collected += n_signal + n_background
        progress.n_signal_candidates += n_signal
        progress.n_background_estimate += n_background
        # Missing or empty values keep whatever was recorded previously.
        progress.best_channel = data.get("channel") or progress.best_channel
        progress.best_beam_energy = data.get("beam_energy") or progress.best_beam_energy

    @staticmethod
    def _apply_calibration(state: FullLatentState, data: dict) -> None:
        """Mark the detector calibrated and scale its resolution."""
        state.progress.detector_calibrated = True
        detector = state.detector
        detector.detector_calibrated = True
        improvement = float(data.get("resolution_improvement", 0.0))
        # The resolution is scaled by (1 - improvement) and floored at 0.05.
        detector.detector_resolution_gev = max(
            0.05,
            detector.detector_resolution_gev * (1.0 - improvement),
        )

    @staticmethod
    def _record_fit(state: FullLatentState, data: dict) -> None:
        """Record a resonance fit; only a positive fitted mass updates the candidates."""
        state.progress.resonance_fitted = True
        mass = float(data.get("fit_mass_gev", 0.0))
        width = float(data.get("fit_width_gev", 0.0))
        if mass > 0:
            state.candidate_masses_gev.append(mass)
            # Placeholder significance; a later ESTIMATE_SIGNIFICANCE action
            # overwrites the last entry.
            state.candidate_significances.append(0.0)
            state.progress.best_fit_mass_gev = mass
            state.progress.best_fit_width_gev = width

    @staticmethod
    def _record_significance(state: FullLatentState, data: dict) -> None:
        """Track the best significance seen and attach it to the latest candidate."""
        progress = state.progress
        progress.significance_estimated = True
        significance = float(data.get("significance_sigma", 0.0))
        progress.best_significance_sigma = max(
            progress.best_significance_sigma or 0.0, significance
        )
        if state.candidate_significances:
            state.candidate_significances[-1] = significance
# Names exported by a star-import of this module.
__all__ = [
    "ACTION_COSTS", "TransitionEngine",
    "TransitionResult", "compute_action_cost",
]
|