File size: 9,120 Bytes
2b0bffa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
"""RulesEngine for CERNenv.

Validates an incoming ``ExperimentAction`` against the current latent state
*before* it is executed. Rule violations are reported back as warnings on the
observation and feed into the per-step penalty in the reward function. Hard
violations mark the action as not allowed; soft violations are recorded as
warnings only.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from enum import Enum
from typing import List, Optional

from models import (
    ActionType,
    DetectorChannel,
    ExperimentAction,
    TriggerType,
)

from server.simulator.latent_state import FullLatentState


class ViolationCode(str, Enum):
    """Machine-readable codes for rule violations reported by ``RulesEngine``.

    Members mix in ``str``, so each member compares equal to its string value
    (e.g. ``ViolationCode.REDUNDANT == "redundant"``).
    """

    # A pipeline step required before this action has not been completed.
    PREREQ_MISSING = "prerequisite_missing"
    # The monetary budget is spent; blocks every action.
    BUDGET_EXHAUSTED = "budget_exhausted"
    # The integrated-luminosity budget is spent; only raised for the
    # luminosity-allocation / collision-collection actions.
    LUMI_EXHAUSTED = "luminosity_exhausted"
    # The time budget is spent; blocks every action.
    TIME_EXHAUSTED = "time_exhausted"
    # Repeating a step that has already been completed (soft).
    REDUNDANT = "redundant"
    # Unknown channel/trigger or a bad histogram mass window (soft).
    INVALID_PARAMS = "invalid_parameters"
    # Malformed discovery claim; hard or soft depending on the missing field.
    INVALID_CLAIM = "invalid_claim"
    # NOTE(review): not emitted by RulesEngine in this module — presumably
    # used elsewhere; confirm before removing.
    CHANNEL_MISMATCH = "channel_mismatch"
    # Requested histogram window does not overlap the task search window (soft).
    OUT_OF_WINDOW = "out_of_search_window"


@dataclass
class RuleResult:
    """Accumulated outcome of validating a single action.

    A hard violation sets ``allowed`` to False; soft violations are kept in
    ``soft_violations`` and never touch ``allowed``. ``messages`` stores the
    human-readable text of every violation, hard or soft, in insertion order.
    """

    allowed: bool
    violations: List[ViolationCode] = field(default_factory=list)  # hard codes
    messages: List[str] = field(default_factory=list)
    soft_violations: List[ViolationCode] = field(default_factory=list)

    def add(self, code: ViolationCode, msg: str, soft: bool = False) -> None:
        """Record violation ``code`` with explanation ``msg``.

        Hard violations (the default) additionally mark the result as not
        allowed.
        """
        self.messages.append(msg)
        if not soft:
            self.allowed = False
        bucket = self.soft_violations if soft else self.violations
        bucket.append(code)


class RulesEngine:
    """Stateless validator for ``ExperimentAction`` objects.

    An instance stores only its configuration (the task's mass search
    window); the experiment's latent state is passed to :meth:`validate` on
    every call.
    """

    def __init__(
        self,
        mass_search_window_gev: tuple[float, float] = (50.0, 1000.0),
    ) -> None:
        """Configure the validator.

        Args:
            mass_search_window_gev: ``(low, high)`` bounds of the task's mass
                search window in GeV. Histogram windows that do not overlap
                it, and claimed masses outside it (inclusive bounds), are
                reported as soft violations.
        """
        self.mass_search_window_gev = mass_search_window_gev

    # ── Public API ─────────────────────────────────────────────────────

    def validate(
        self,
        action: ExperimentAction,
        state: FullLatentState,
    ) -> RuleResult:
        """Check ``action`` against ``state`` before it is executed.

        Resource exhaustion is checked first; if it produces any hard
        violation the result is returned immediately. Otherwise
        prerequisites, parameters, redundancy and (for discovery claims)
        claim sanity are checked in that order and all findings accumulated.

        Args:
            action: The action the agent wants to take.
            state: Current latent state (resources and progress flags).

        Returns:
            A :class:`RuleResult` whose ``allowed`` is False if any hard
            violation was recorded. Soft violations only add warnings.
        """
        result = RuleResult(allowed=True)

        # Exhausted resources short-circuit everything else.
        self._check_resources(action, state, result)
        if not result.allowed:
            return result

        a = action.action_type
        self._check_prerequisites(a, state, result)
        self._check_parameters(action, result)
        self._check_redundancy(a, state, result)
        if a == ActionType.SUBMIT_DISCOVERY_CLAIM:
            self._check_claim(action.parameters, result)
        return result

    # ── Internal checks ────────────────────────────────────────────────

    @staticmethod
    def _check_resources(
        action: ExperimentAction,
        state: FullLatentState,
        result: RuleResult,
    ) -> None:
        """Hard-block on spent budget/time, and DAQ actions on spent luminosity."""
        resources = state.resources
        if resources.budget_exhausted:
            result.add(ViolationCode.BUDGET_EXHAUSTED, "Budget fully spent.")
        if resources.time_exhausted:
            result.add(ViolationCode.TIME_EXHAUSTED, "Time budget exhausted.")
        # luminosity exhaustion only blocks DAQ-style actions
        if (
            resources.luminosity_exhausted
            and action.action_type in {
                ActionType.ALLOCATE_LUMINOSITY,
                ActionType.COLLECT_COLLISIONS,
            }
        ):
            result.add(ViolationCode.LUMI_EXHAUSTED, "Integrated luminosity budget spent.")

    @staticmethod
    def _check_prerequisites(
        a: ActionType,
        state: FullLatentState,
        result: RuleResult,
    ) -> None:
        """Hard-flag actions whose required earlier pipeline steps are missing."""
        prog = state.progress

        if a == ActionType.COLLECT_COLLISIONS:
            if not prog.beam_configured:
                result.add(ViolationCode.PREREQ_MISSING, "Configure the beam first.")
            if not prog.luminosity_allocated:
                result.add(ViolationCode.PREREQ_MISSING, "Allocate luminosity first.")
            if not prog.trigger_set:
                result.add(ViolationCode.PREREQ_MISSING, "Set a trigger first.")
            if not state.selected_channel:
                result.add(ViolationCode.PREREQ_MISSING, "Select a decay channel first.")

        elif a == ActionType.BUILD_INVARIANT_MASS:
            if not prog.collisions_collected:
                result.add(ViolationCode.PREREQ_MISSING, "Collect collisions before building histograms.")
            if not prog.tracks_reconstructed:
                result.add(ViolationCode.PREREQ_MISSING, "Reconstruct tracks before building histograms.")

        elif a == ActionType.SUBTRACT_BACKGROUND:
            if not prog.invariant_mass_built:
                result.add(ViolationCode.PREREQ_MISSING, "Build invariant-mass histogram first.")

        elif a == ActionType.FIT_RESONANCE:
            if not prog.invariant_mass_built:
                result.add(ViolationCode.PREREQ_MISSING, "Build the histogram before fitting.")

        elif a == ActionType.MEASURE_ANGULAR:
            if not (prog.resonance_fitted or prog.bump_scanned):
                result.add(
                    ViolationCode.PREREQ_MISSING,
                    "Identify a peak (fit or bump scan) before angular analysis.",
                )

        elif a == ActionType.ESTIMATE_SIGNIFICANCE:
            if not prog.collisions_collected:
                result.add(ViolationCode.PREREQ_MISSING, "Collect data before significance estimation.")

        elif a == ActionType.SUBMIT_DISCOVERY_CLAIM:
            if not prog.resonance_fitted and not prog.bump_scanned:
                result.add(ViolationCode.PREREQ_MISSING, "No fitted resonance or bump scan; cannot claim a discovery.")
            if not prog.significance_estimated:
                result.add(ViolationCode.PREREQ_MISSING, "Estimate significance before submitting a claim.")

    def _check_parameters(self, action: ExperimentAction, result: RuleResult) -> None:
        """Soft-flag unknown channel/trigger names and bad histogram windows."""
        a = action.action_type

        if a == ActionType.SELECT_CHANNEL:
            channel = action.parameters.get("channel")
            if channel:
                try:
                    DetectorChannel(channel)
                except ValueError:
                    result.add(ViolationCode.INVALID_PARAMS, f"Unknown channel '{channel}'.", soft=True)

        if a == ActionType.SET_TRIGGER:
            trig = action.parameters.get("trigger")
            if trig:
                try:
                    TriggerType(trig)
                except ValueError:
                    result.add(ViolationCode.INVALID_PARAMS, f"Unknown trigger '{trig}'.", soft=True)

        if a == ActionType.BUILD_INVARIANT_MASS:
            window = action.parameters.get("mass_window_gev")
            if window:
                self._check_mass_window(window, result)

    def _check_mass_window(self, window, result: RuleResult) -> None:
        """Soft-flag a malformed, inverted or out-of-search-window mass window.

        ``window`` is agent-supplied, so it may be any shape; anything that is
        not exactly two numeric bounds is reported instead of raising.
        """
        try:
            # Unpacking raises ValueError on the wrong number of elements;
            # float() raises TypeError/ValueError/OverflowError on bad bounds.
            lo, hi = (float(bound) for bound in window)
        except (TypeError, ValueError, OverflowError):
            result.add(
                ViolationCode.INVALID_PARAMS,
                f"Mass window {window!r} is not a [low, high] pair of numbers.",
                soft=True,
            )
            return

        if hi <= lo:
            result.add(
                ViolationCode.INVALID_PARAMS,
                f"Mass window [{lo}, {hi}] is non-positive.",
                soft=True,
            )
        search_lo, search_hi = self.mass_search_window_gev
        if lo > search_hi or hi < search_lo:
            result.add(
                ViolationCode.OUT_OF_WINDOW,
                f"Histogram window [{lo}, {hi}] is outside the task search window "
                f"{self.mass_search_window_gev}.",
                soft=True,
            )

    @staticmethod
    def _check_redundancy(
        a: ActionType,
        state: FullLatentState,
        result: RuleResult,
    ) -> None:
        """Soft-flag repeats of one-shot setup steps that already ran."""
        prog = state.progress
        if a == ActionType.CONFIGURE_BEAM and prog.beam_configured:
            result.add(ViolationCode.REDUNDANT, "Beam already configured; reconfiguring wastes budget.", soft=True)
        if a == ActionType.SELECT_CHANNEL and prog.channel_selected:
            result.add(ViolationCode.REDUNDANT, "Channel already selected.", soft=True)
        if a == ActionType.RECONSTRUCT_TRACKS and prog.tracks_reconstructed:
            result.add(ViolationCode.REDUNDANT, "Tracks already reconstructed.", soft=True)
        if a == ActionType.CALIBRATE_DETECTOR and prog.detector_calibrated:
            result.add(ViolationCode.REDUNDANT, "Detector already calibrated.", soft=True)

    def _check_claim(self, parameters, result: RuleResult) -> None:
        """Sanity-check the ``claim`` payload of a discovery submission.

        A missing, non-numeric or unreadable mass is a hard violation; a mass
        outside the search window or a missing significance is soft.
        """
        claim = parameters.get("claim") or {}
        # Agent-supplied payload: anything without a .get (str, list, number)
        # would otherwise crash the validator with AttributeError.
        if not callable(getattr(claim, "get", None)):
            result.add(
                ViolationCode.INVALID_CLAIM,
                f"Claim must be a mapping of claim fields, got {type(claim).__name__}.",
            )
            return

        mass = claim.get("mass_estimate_gev")
        if mass is None:
            result.add(ViolationCode.INVALID_CLAIM, "Claim missing mass estimate.")
        else:
            try:
                m = float(mass)
            except (TypeError, ValueError, OverflowError):
                result.add(ViolationCode.INVALID_CLAIM, "Claim mass is not numeric.")
            else:
                lo, hi = self.mass_search_window_gev
                if not (lo <= m <= hi):
                    result.add(
                        ViolationCode.INVALID_CLAIM,
                        f"Claim mass {m} outside search window [{lo}, {hi}].",
                        soft=True,
                    )
        if claim.get("significance_sigma") is None:
            result.add(ViolationCode.INVALID_CLAIM, "Claim missing significance.", soft=True)


# Public API of this module (what ``from ... import *`` exports).
__all__ = ["RuleResult", "RulesEngine", "ViolationCode"]