anugrahhu commited on
Commit
7df4308
·
verified ·
1 Parent(s): 11307a1

fix: coerce beam_energy to str so CollisionObservation pydantic check accepts numeric LLM outputs

Browse files
Files changed (1) hide show
  1. server/simulator/transition.py +205 -197
server/simulator/transition.py CHANGED
@@ -1,197 +1,205 @@
1
- """Pure-function transition engine.
2
-
3
- Given a (latent_state, action, generated_output) triple, produces the next
4
- latent state plus the deltas needed for the agent-visible observation. The
5
- ``TransitionEngine`` does **not** generate randomness directly; it consumes
6
- artifacts from the ``OutputGenerator``.
7
- """
8
-
9
- from __future__ import annotations
10
-
11
- from dataclasses import dataclass
12
- from typing import Dict
13
-
14
- from models import (
15
- ActionType,
16
- ExperimentAction,
17
- IntermediateOutput,
18
- OutputType,
19
- )
20
-
21
- from .latent_state import FullLatentState
22
-
23
-
24
- # Per-action default cost in (millions of USD, days, compute hours)
25
- ACTION_COSTS: Dict[ActionType, Dict[str, float]] = {
26
- ActionType.CONFIGURE_BEAM: {"musd": 0.10, "days": 0.5, "compute": 0.1},
27
- ActionType.ALLOCATE_LUMINOSITY: {"musd": 0.05, "days": 0.2, "compute": 0.0},
28
- ActionType.SET_TRIGGER: {"musd": 0.05, "days": 0.1, "compute": 0.0},
29
- ActionType.COLLECT_COLLISIONS: {"musd": 0.00, "days": 0.0, "compute": 1.0}, # main cost is in luminosity
30
- ActionType.CALIBRATE_DETECTOR: {"musd": 0.20, "days": 1.0, "compute": 1.5},
31
- ActionType.RECONSTRUCT_TRACKS: {"musd": 0.15, "days": 0.8, "compute": 5.0},
32
- ActionType.SELECT_CHANNEL: {"musd": 0.00, "days": 0.05, "compute": 0.0},
33
- ActionType.BUILD_INVARIANT_MASS: {"musd": 0.05, "days": 0.3, "compute": 1.0},
34
- ActionType.SUBTRACT_BACKGROUND: {"musd": 0.05, "days": 0.3, "compute": 0.5},
35
- ActionType.FIT_RESONANCE: {"musd": 0.10, "days": 0.4, "compute": 0.5},
36
- ActionType.SCAN_BUMP: {"musd": 0.05, "days": 0.2, "compute": 0.5},
37
- ActionType.MEASURE_ANGULAR: {"musd": 0.10, "days": 0.4, "compute": 0.5},
38
- ActionType.ESTIMATE_SIGNIFICANCE: {"musd": 0.05, "days": 0.1, "compute": 0.2},
39
- ActionType.REQUEST_SYSTEMATICS: {"musd": 0.30, "days": 1.5, "compute": 1.0},
40
- ActionType.REQUEST_THEORY_REVIEW: {"musd": 0.05, "days": 0.5, "compute": 0.0},
41
- ActionType.SUBMIT_DISCOVERY_CLAIM:{"musd": 0.0, "days": 0.1, "compute": 0.0},
42
- }
43
-
44
-
45
- def compute_action_cost(action: ExperimentAction, output: IntermediateOutput) -> Dict[str, float]:
46
- """Return realised (musd, days, compute_hours, luminosity_fb) for this action."""
47
- base = ACTION_COSTS.get(action.action_type, {"musd": 0.0, "days": 0.0, "compute": 0.0})
48
- musd = float(base.get("musd", 0.0))
49
- days = float(base.get("days", 0.0))
50
- compute = float(base.get("compute", 0.0))
51
- lumi_fb = 0.0
52
-
53
- data = output.data or {}
54
- if action.action_type == ActionType.COLLECT_COLLISIONS:
55
- lumi_fb = float(data.get("luminosity_fb", 0.0))
56
- musd += float(data.get("cost_musd", 0.0))
57
- days += float(data.get("time_days", 0.0))
58
-
59
- return {
60
- "musd": musd,
61
- "days": days,
62
- "compute_hours": compute,
63
- "luminosity_fb": lumi_fb,
64
- }
65
-
66
-
67
- @dataclass
68
- class TransitionResult:
69
- next_state: FullLatentState
70
- realised_cost: Dict[str, float]
71
-
72
-
73
- class TransitionEngine:
74
- """Applies an action's output to evolve the latent state."""
75
-
76
- def step(
77
- self,
78
- state: FullLatentState,
79
- action: ExperimentAction,
80
- output: IntermediateOutput,
81
- ) -> TransitionResult:
82
- # We mutate the live state in place, then return it. This is fine
83
- # because the environment owns the only reference.
84
- cost = compute_action_cost(action, output)
85
- state.resources.budget_used_musd += cost["musd"]
86
- state.resources.time_used_days += cost["days"]
87
- state.resources.compute_hours_used += cost["compute_hours"]
88
- state.resources.luminosity_used_fb += cost["luminosity_fb"]
89
-
90
- if not output.success:
91
- state.step_count += 1
92
- return TransitionResult(next_state=state, realised_cost=cost)
93
-
94
- a = action.action_type
95
- data = output.data or {}
96
-
97
- if a == ActionType.CONFIGURE_BEAM:
98
- beam = data.get("beam_energy")
99
- state.selected_beam_energy = beam
100
- state.progress.beam_configured = True
101
-
102
- elif a == ActionType.ALLOCATE_LUMINOSITY:
103
- state.progress.luminosity_allocated = True
104
-
105
- elif a == ActionType.SET_TRIGGER:
106
- trig = data.get("trigger")
107
- state.selected_trigger = trig
108
- state.progress.trigger_set = True
109
-
110
- elif a == ActionType.COLLECT_COLLISIONS:
111
- state.progress.collisions_collected = True
112
- state.progress.n_events_collected += int(
113
- data.get("n_signal_candidates", 0)
114
- ) + int(data.get("n_background_estimate", 0))
115
- state.progress.n_signal_candidates += int(data.get("n_signal_candidates", 0))
116
- state.progress.n_background_estimate += int(data.get("n_background_estimate", 0))
117
- state.progress.best_channel = data.get("channel") or state.progress.best_channel
118
- state.progress.best_beam_energy = (
119
- data.get("beam_energy") or state.progress.best_beam_energy
120
- )
121
-
122
- elif a == ActionType.CALIBRATE_DETECTOR:
123
- state.progress.detector_calibrated = True
124
- state.detector.detector_calibrated = True
125
- improvement = float(data.get("resolution_improvement", 0.0))
126
- state.detector.detector_resolution_gev = max(
127
- 0.05,
128
- state.detector.detector_resolution_gev * (1.0 - improvement),
129
- )
130
-
131
- elif a == ActionType.RECONSTRUCT_TRACKS:
132
- state.progress.tracks_reconstructed = True
133
- state.detector.tracker_aligned = True
134
-
135
- elif a == ActionType.SELECT_CHANNEL:
136
- channel = data.get("channel")
137
- if channel:
138
- state.selected_channel = channel
139
- state.progress.channel_selected = True
140
-
141
- elif a == ActionType.BUILD_INVARIANT_MASS:
142
- state.progress.invariant_mass_built = True
143
-
144
- elif a == ActionType.SUBTRACT_BACKGROUND:
145
- state.progress.background_subtracted = True
146
-
147
- elif a == ActionType.FIT_RESONANCE:
148
- state.progress.resonance_fitted = True
149
- m = float(data.get("fit_mass_gev", 0.0))
150
- unc = float(data.get("fit_mass_unc_gev", 0.0))
151
- w = float(data.get("fit_width_gev", 0.0))
152
- if m > 0:
153
- state.candidate_masses_gev.append(m)
154
- state.candidate_significances.append(0.0)
155
- state.progress.best_fit_mass_gev = m
156
- state.progress.best_fit_width_gev = w
157
-
158
- elif a == ActionType.SCAN_BUMP:
159
- state.progress.bump_scanned = True
160
- cm = float(data.get("candidate_mass_gev", 0.0))
161
- if cm > 0:
162
- state.candidate_masses_gev.append(cm)
163
- state.candidate_significances.append(0.0)
164
-
165
- elif a == ActionType.MEASURE_ANGULAR:
166
- state.progress.angular_measured = True
167
-
168
- elif a == ActionType.ESTIMATE_SIGNIFICANCE:
169
- state.progress.significance_estimated = True
170
- sig = float(data.get("significance_sigma", 0.0))
171
- state.progress.best_significance_sigma = max(
172
- state.progress.best_significance_sigma or 0.0, sig
173
- )
174
- if state.candidate_significances:
175
- state.candidate_significances[-1] = sig
176
-
177
- elif a == ActionType.REQUEST_SYSTEMATICS:
178
- state.progress.systematics_requested = True
179
- state.detector.energy_scale_uncertainty *= 0.6
180
- state.detector.luminosity_uncertainty *= 0.7
181
-
182
- elif a == ActionType.REQUEST_THEORY_REVIEW:
183
- state.progress.theory_review_requested = True
184
-
185
- elif a == ActionType.SUBMIT_DISCOVERY_CLAIM:
186
- state.progress.claim_submitted = True
187
-
188
- state.step_count += 1
189
- return TransitionResult(next_state=state, realised_cost=cost)
190
-
191
-
192
- __all__ = [
193
- "ACTION_COSTS",
194
- "TransitionEngine",
195
- "TransitionResult",
196
- "compute_action_cost",
197
- ]
 
 
 
 
 
 
 
 
 
1
+ """Pure-function transition engine.
2
+
3
+ Given a (latent_state, action, generated_output) triple, produces the next
4
+ latent state plus the deltas needed for the agent-visible observation. The
5
+ ``TransitionEngine`` does **not** generate randomness directly; it consumes
6
+ artifacts from the ``OutputGenerator``.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from dataclasses import dataclass
12
+ from typing import Dict
13
+
14
+ from models import (
15
+ ActionType,
16
+ ExperimentAction,
17
+ IntermediateOutput,
18
+ OutputType,
19
+ )
20
+
21
+ from .latent_state import FullLatentState
22
+
23
+
24
+ # Per-action default cost in (millions of USD, days, compute hours)
25
+ ACTION_COSTS: Dict[ActionType, Dict[str, float]] = {
26
+ ActionType.CONFIGURE_BEAM: {"musd": 0.10, "days": 0.5, "compute": 0.1},
27
+ ActionType.ALLOCATE_LUMINOSITY: {"musd": 0.05, "days": 0.2, "compute": 0.0},
28
+ ActionType.SET_TRIGGER: {"musd": 0.05, "days": 0.1, "compute": 0.0},
29
+ ActionType.COLLECT_COLLISIONS: {"musd": 0.00, "days": 0.0, "compute": 1.0}, # main cost is in luminosity
30
+ ActionType.CALIBRATE_DETECTOR: {"musd": 0.20, "days": 1.0, "compute": 1.5},
31
+ ActionType.RECONSTRUCT_TRACKS: {"musd": 0.15, "days": 0.8, "compute": 5.0},
32
+ ActionType.SELECT_CHANNEL: {"musd": 0.00, "days": 0.05, "compute": 0.0},
33
+ ActionType.BUILD_INVARIANT_MASS: {"musd": 0.05, "days": 0.3, "compute": 1.0},
34
+ ActionType.SUBTRACT_BACKGROUND: {"musd": 0.05, "days": 0.3, "compute": 0.5},
35
+ ActionType.FIT_RESONANCE: {"musd": 0.10, "days": 0.4, "compute": 0.5},
36
+ ActionType.SCAN_BUMP: {"musd": 0.05, "days": 0.2, "compute": 0.5},
37
+ ActionType.MEASURE_ANGULAR: {"musd": 0.10, "days": 0.4, "compute": 0.5},
38
+ ActionType.ESTIMATE_SIGNIFICANCE: {"musd": 0.05, "days": 0.1, "compute": 0.2},
39
+ ActionType.REQUEST_SYSTEMATICS: {"musd": 0.30, "days": 1.5, "compute": 1.0},
40
+ ActionType.REQUEST_THEORY_REVIEW: {"musd": 0.05, "days": 0.5, "compute": 0.0},
41
+ ActionType.SUBMIT_DISCOVERY_CLAIM:{"musd": 0.0, "days": 0.1, "compute": 0.0},
42
+ }
43
+
44
+
45
+ def compute_action_cost(action: ExperimentAction, output: IntermediateOutput) -> Dict[str, float]:
46
+ """Return realised (musd, days, compute_hours, luminosity_fb) for this action."""
47
+ base = ACTION_COSTS.get(action.action_type, {"musd": 0.0, "days": 0.0, "compute": 0.0})
48
+ musd = float(base.get("musd", 0.0))
49
+ days = float(base.get("days", 0.0))
50
+ compute = float(base.get("compute", 0.0))
51
+ lumi_fb = 0.0
52
+
53
+ data = output.data or {}
54
+ if action.action_type == ActionType.COLLECT_COLLISIONS:
55
+ lumi_fb = float(data.get("luminosity_fb", 0.0))
56
+ musd += float(data.get("cost_musd", 0.0))
57
+ days += float(data.get("time_days", 0.0))
58
+
59
+ return {
60
+ "musd": musd,
61
+ "days": days,
62
+ "compute_hours": compute,
63
+ "luminosity_fb": lumi_fb,
64
+ }
65
+
66
+
67
+ @dataclass
68
+ class TransitionResult:
69
+ next_state: FullLatentState
70
+ realised_cost: Dict[str, float]
71
+
72
+
73
+ class TransitionEngine:
74
+ """Applies an action's output to evolve the latent state."""
75
+
76
+ def step(
77
+ self,
78
+ state: FullLatentState,
79
+ action: ExperimentAction,
80
+ output: IntermediateOutput,
81
+ ) -> TransitionResult:
82
+ # We mutate the live state in place, then return it. This is fine
83
+ # because the environment owns the only reference.
84
+ cost = compute_action_cost(action, output)
85
+ state.resources.budget_used_musd += cost["musd"]
86
+ state.resources.time_used_days += cost["days"]
87
+ state.resources.compute_hours_used += cost["compute_hours"]
88
+ state.resources.luminosity_used_fb += cost["luminosity_fb"]
89
+
90
+ if not output.success:
91
+ state.step_count += 1
92
+ return TransitionResult(next_state=state, realised_cost=cost)
93
+
94
+ a = action.action_type
95
+ data = output.data or {}
96
+
97
+ if a == ActionType.CONFIGURE_BEAM:
98
+ beam = data.get("beam_energy")
99
+ # latent_state.selected_beam_energy is typed Optional[str] and
100
+ # CollisionObservation re-validates it as a str; LLM completions
101
+ # sometimes emit numeric beam_energy (e.g. 13.0), which would
102
+ # later fail Pydantic string validation in _build_observation.
103
+ # Coerce to str at the source so all downstream consumers
104
+ # (latent state, observation, output_generator) see a string.
105
+ state.selected_beam_energy = str(beam) if beam is not None else None
106
+ state.progress.beam_configured = True
107
+
108
+ elif a == ActionType.ALLOCATE_LUMINOSITY:
109
+ state.progress.luminosity_allocated = True
110
+
111
+ elif a == ActionType.SET_TRIGGER:
112
+ trig = data.get("trigger")
113
+ state.selected_trigger = trig
114
+ state.progress.trigger_set = True
115
+
116
+ elif a == ActionType.COLLECT_COLLISIONS:
117
+ state.progress.collisions_collected = True
118
+ state.progress.n_events_collected += int(
119
+ data.get("n_signal_candidates", 0)
120
+ ) + int(data.get("n_background_estimate", 0))
121
+ state.progress.n_signal_candidates += int(data.get("n_signal_candidates", 0))
122
+ state.progress.n_background_estimate += int(data.get("n_background_estimate", 0))
123
+ state.progress.best_channel = data.get("channel") or state.progress.best_channel
124
+ _be = data.get("beam_energy")
125
+ state.progress.best_beam_energy = (
126
+ (str(_be) if _be is not None else None)
127
+ or state.progress.best_beam_energy
128
+ )
129
+
130
+ elif a == ActionType.CALIBRATE_DETECTOR:
131
+ state.progress.detector_calibrated = True
132
+ state.detector.detector_calibrated = True
133
+ improvement = float(data.get("resolution_improvement", 0.0))
134
+ state.detector.detector_resolution_gev = max(
135
+ 0.05,
136
+ state.detector.detector_resolution_gev * (1.0 - improvement),
137
+ )
138
+
139
+ elif a == ActionType.RECONSTRUCT_TRACKS:
140
+ state.progress.tracks_reconstructed = True
141
+ state.detector.tracker_aligned = True
142
+
143
+ elif a == ActionType.SELECT_CHANNEL:
144
+ channel = data.get("channel")
145
+ if channel:
146
+ state.selected_channel = channel
147
+ state.progress.channel_selected = True
148
+
149
+ elif a == ActionType.BUILD_INVARIANT_MASS:
150
+ state.progress.invariant_mass_built = True
151
+
152
+ elif a == ActionType.SUBTRACT_BACKGROUND:
153
+ state.progress.background_subtracted = True
154
+
155
+ elif a == ActionType.FIT_RESONANCE:
156
+ state.progress.resonance_fitted = True
157
+ m = float(data.get("fit_mass_gev", 0.0))
158
+ unc = float(data.get("fit_mass_unc_gev", 0.0))
159
+ w = float(data.get("fit_width_gev", 0.0))
160
+ if m > 0:
161
+ state.candidate_masses_gev.append(m)
162
+ state.candidate_significances.append(0.0)
163
+ state.progress.best_fit_mass_gev = m
164
+ state.progress.best_fit_width_gev = w
165
+
166
+ elif a == ActionType.SCAN_BUMP:
167
+ state.progress.bump_scanned = True
168
+ cm = float(data.get("candidate_mass_gev", 0.0))
169
+ if cm > 0:
170
+ state.candidate_masses_gev.append(cm)
171
+ state.candidate_significances.append(0.0)
172
+
173
+ elif a == ActionType.MEASURE_ANGULAR:
174
+ state.progress.angular_measured = True
175
+
176
+ elif a == ActionType.ESTIMATE_SIGNIFICANCE:
177
+ state.progress.significance_estimated = True
178
+ sig = float(data.get("significance_sigma", 0.0))
179
+ state.progress.best_significance_sigma = max(
180
+ state.progress.best_significance_sigma or 0.0, sig
181
+ )
182
+ if state.candidate_significances:
183
+ state.candidate_significances[-1] = sig
184
+
185
+ elif a == ActionType.REQUEST_SYSTEMATICS:
186
+ state.progress.systematics_requested = True
187
+ state.detector.energy_scale_uncertainty *= 0.6
188
+ state.detector.luminosity_uncertainty *= 0.7
189
+
190
+ elif a == ActionType.REQUEST_THEORY_REVIEW:
191
+ state.progress.theory_review_requested = True
192
+
193
+ elif a == ActionType.SUBMIT_DISCOVERY_CLAIM:
194
+ state.progress.claim_submitted = True
195
+
196
+ state.step_count += 1
197
+ return TransitionResult(next_state=state, realised_cost=cost)
198
+
199
+
200
+ __all__ = [
201
+ "ACTION_COSTS",
202
+ "TransitionEngine",
203
+ "TransitionResult",
204
+ "compute_action_cost",
205
+ ]