Spaces:
Paused
Paused
| """RulesEngine for CERNenv. | |
| Validates an incoming ``ExperimentAction`` against the current latent state | |
| *before* it is executed. Rule violations are reported back as warnings on the | |
| observation and feed into the per-step penalty in the reward function. | |
| """ | |
from __future__ import annotations

import math
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Optional

from models import (
    ActionType,
    DetectorChannel,
    ExperimentAction,
    TriggerType,
)
from server.simulator.latent_state import FullLatentState
class ViolationCode(str, Enum):
    """Machine-readable reasons attached to a ``RuleResult``.

    Members subclass ``str``, so each one compares equal to its string value
    (e.g. ``ViolationCode.REDUNDANT == "redundant"``).
    """

    # An upstream workflow step has not been performed yet.
    PREREQ_MISSING = "prerequisite_missing"

    # A resource budget has been used up.
    BUDGET_EXHAUSTED = "budget_exhausted"
    LUMI_EXHAUSTED = "luminosity_exhausted"
    TIME_EXHAUSTED = "time_exhausted"

    # The action repeats a step that is already done.
    REDUNDANT = "redundant"

    # Malformed action parameters / discovery claims.
    INVALID_PARAMS = "invalid_parameters"
    INVALID_CLAIM = "invalid_claim"

    # NOTE(review): not emitted by RulesEngine in this module; presumably used
    # by channel-consistency checks elsewhere — TODO confirm.
    CHANNEL_MISMATCH = "channel_mismatch"

    # A mass value or histogram window lies outside the task search window.
    OUT_OF_WINDOW = "out_of_search_window"
@dataclass
class RuleResult:
    """Outcome of validating one action.

    ``allowed`` starts as given and is flipped to False by the first hard
    violation; soft violations are recorded without blocking the action.
    Every added violation (hard or soft) also appends its message to
    ``messages`` in insertion order.
    """

    allowed: bool
    # Hard violations: any entry here means the action is blocked.
    violations: List[ViolationCode] = field(default_factory=list)
    # Human-readable messages for both hard and soft violations.
    messages: List[str] = field(default_factory=list)
    # Soft violations: reported, but do not change ``allowed``.
    soft_violations: List[ViolationCode] = field(default_factory=list)

    def add(self, code: ViolationCode, msg: str, soft: bool = False) -> None:
        """Record a violation.

        Args:
            code: Violation code to record.
            msg: Human-readable explanation, appended to ``messages``.
            soft: If True, record in ``soft_violations`` and leave ``allowed``
                untouched; otherwise record in ``violations`` and block.
        """
        self.messages.append(msg)
        if soft:
            self.soft_violations.append(code)
        else:
            self.violations.append(code)
            self.allowed = False
class RulesEngine:
    """Stateless validator (state is passed in).

    The engine only stores the task's mass search window; everything else it
    needs is read from the ``FullLatentState`` handed to :meth:`validate`.
    """

    def __init__(
        self,
        mass_search_window_gev: tuple[float, float] = (50.0, 1000.0),
    ) -> None:
        """Store the ``(lo, hi)`` mass search window in GeV.

        The window is used to softly flag histogram windows and claimed
        masses that fall outside the task's search range.
        """
        self.mass_search_window_gev = mass_search_window_gev

    # ── Public API ─────────────────────────────────────────────────────

    def validate(
        self,
        action: ExperimentAction,
        state: FullLatentState,
    ) -> RuleResult:
        """Check ``action`` against ``state`` before it is executed.

        Resource-exhaustion checks run first; if any of them fires, the
        result is returned immediately. Otherwise prerequisite (hard),
        parameter / search-window (soft), redundancy (soft) and
        discovery-claim checks run in that order and all findings are
        accumulated.

        Args:
            action: The action to validate.
            state: Current latent state; only read, never modified here.

        Returns:
            A ``RuleResult`` whose ``allowed`` is False iff a hard violation
            was recorded.
        """
        result = RuleResult(allowed=True)

        self._check_resources(action, state, result)
        if not result.allowed:
            return result

        self._check_prerequisites(action, state, result)
        self._check_parameters(action, result)
        self._check_redundancy(action, state, result)
        self._check_claim(action, result)
        return result

    # ── Individual checks (each appends findings to ``result``) ──────

    def _check_resources(
        self, action: ExperimentAction, state: FullLatentState, result: RuleResult
    ) -> None:
        """Hard-block when budget/time is spent, or luminosity for DAQ actions."""
        resources = state.resources
        if resources.budget_exhausted:
            result.add(ViolationCode.BUDGET_EXHAUSTED, "Budget fully spent.")
        if resources.time_exhausted:
            result.add(ViolationCode.TIME_EXHAUSTED, "Time budget exhausted.")
        # Luminosity exhaustion only blocks DAQ-style actions.
        if resources.luminosity_exhausted and action.action_type in {
            ActionType.ALLOCATE_LUMINOSITY,
            ActionType.COLLECT_COLLISIONS,
        }:
            result.add(ViolationCode.LUMI_EXHAUSTED, "Integrated luminosity budget spent.")

    def _check_prerequisites(
        self, action: ExperimentAction, state: FullLatentState, result: RuleResult
    ) -> None:
        """Hard-block actions whose upstream workflow steps have not run yet."""
        a = action.action_type
        prog = state.progress
        if a == ActionType.COLLECT_COLLISIONS:
            if not prog.beam_configured:
                result.add(ViolationCode.PREREQ_MISSING, "Configure the beam first.")
            if not prog.luminosity_allocated:
                result.add(ViolationCode.PREREQ_MISSING, "Allocate luminosity first.")
            if not prog.trigger_set:
                result.add(ViolationCode.PREREQ_MISSING, "Set a trigger first.")
            if not state.selected_channel:
                result.add(ViolationCode.PREREQ_MISSING, "Select a decay channel first.")
        elif a == ActionType.BUILD_INVARIANT_MASS:
            if not prog.collisions_collected:
                result.add(ViolationCode.PREREQ_MISSING, "Collect collisions before building histograms.")
            if not prog.tracks_reconstructed:
                result.add(ViolationCode.PREREQ_MISSING, "Reconstruct tracks before building histograms.")
        elif a == ActionType.SUBTRACT_BACKGROUND:
            if not prog.invariant_mass_built:
                result.add(ViolationCode.PREREQ_MISSING, "Build invariant-mass histogram first.")
        elif a == ActionType.FIT_RESONANCE:
            if not prog.invariant_mass_built:
                result.add(ViolationCode.PREREQ_MISSING, "Build the histogram before fitting.")
        elif a == ActionType.MEASURE_ANGULAR:
            if not (prog.resonance_fitted or prog.bump_scanned):
                result.add(
                    ViolationCode.PREREQ_MISSING,
                    "Identify a peak (fit or bump scan) before angular analysis.",
                )
        elif a == ActionType.ESTIMATE_SIGNIFICANCE:
            if not prog.collisions_collected:
                result.add(ViolationCode.PREREQ_MISSING, "Collect data before significance estimation.")
        elif a == ActionType.SUBMIT_DISCOVERY_CLAIM:
            if not prog.resonance_fitted and not prog.bump_scanned:
                result.add(ViolationCode.PREREQ_MISSING, "No fitted resonance or bump scan; cannot claim a discovery.")
            if not prog.significance_estimated:
                result.add(ViolationCode.PREREQ_MISSING, "Estimate significance before submitting a claim.")

    def _check_parameters(self, action: ExperimentAction, result: RuleResult) -> None:
        """Softly flag unknown channel/trigger names and bad histogram windows."""
        a = action.action_type
        params = action.parameters
        if a == ActionType.SELECT_CHANNEL:
            channel = params.get("channel")
            if channel:
                try:
                    DetectorChannel(channel)
                except ValueError:
                    result.add(ViolationCode.INVALID_PARAMS, f"Unknown channel '{channel}'.", soft=True)
        elif a == ActionType.SET_TRIGGER:
            trig = params.get("trigger")
            if trig:
                try:
                    TriggerType(trig)
                except ValueError:
                    result.add(ViolationCode.INVALID_PARAMS, f"Unknown trigger '{trig}'.", soft=True)
        elif a == ActionType.BUILD_INVARIANT_MASS:
            window = params.get("mass_window_gev")
            if window:
                self._check_mass_window(window, result)

    def _check_mass_window(self, window: object, result: RuleResult) -> None:
        """Softly validate a ``(lo, hi)`` histogram window given in GeV.

        A window that is not a pair of numbers (wrong length, not iterable,
        non-numeric entries) or that contains NaN is reported as
        ``INVALID_PARAMS`` rather than raising, so malformed agent input
        cannot crash the validator.
        """
        try:
            lo_raw, hi_raw = window  # type: ignore[misc]
            lo, hi = float(lo_raw), float(hi_raw)
        except (TypeError, ValueError, OverflowError):
            result.add(
                ViolationCode.INVALID_PARAMS,
                f"Mass window {window!r} is not a numeric [lo, hi] pair.",
                soft=True,
            )
            return
        # NaN compares False with everything, so it would slip past the
        # ordering and search-window checks below.
        if math.isnan(lo) or math.isnan(hi):
            result.add(
                ViolationCode.INVALID_PARAMS,
                f"Mass window [{lo}, {hi}] contains NaN.",
                soft=True,
            )
            return
        if hi <= lo:
            result.add(
                ViolationCode.INVALID_PARAMS,
                f"Mass window [{lo}, {hi}] is non-positive.",
                soft=True,
            )
        search_lo, search_hi = self.mass_search_window_gev
        # Only windows lying entirely outside the search range are flagged.
        if lo > search_hi or hi < search_lo:
            result.add(
                ViolationCode.OUT_OF_WINDOW,
                f"Histogram window [{lo}, {hi}] is outside the task search window "
                f"{self.mass_search_window_gev}.",
                soft=True,
            )

    def _check_redundancy(
        self, action: ExperimentAction, state: FullLatentState, result: RuleResult
    ) -> None:
        """Softly flag set-up steps that have already been completed."""
        a = action.action_type
        prog = state.progress
        if a == ActionType.CONFIGURE_BEAM and prog.beam_configured:
            result.add(ViolationCode.REDUNDANT, "Beam already configured; reconfiguring wastes budget.", soft=True)
        if a == ActionType.SELECT_CHANNEL and prog.channel_selected:
            result.add(ViolationCode.REDUNDANT, "Channel already selected.", soft=True)
        if a == ActionType.RECONSTRUCT_TRACKS and prog.tracks_reconstructed:
            result.add(ViolationCode.REDUNDANT, "Tracks already reconstructed.", soft=True)
        if a == ActionType.CALIBRATE_DETECTOR and prog.detector_calibrated:
            result.add(ViolationCode.REDUNDANT, "Detector already calibrated.", soft=True)

    def _check_claim(self, action: ExperimentAction, result: RuleResult) -> None:
        """Sanity-check the ``claim`` payload of a discovery submission.

        A missing or non-numeric mass is a hard violation; a mass outside the
        search window or a missing significance is soft.
        """
        if action.action_type != ActionType.SUBMIT_DISCOVERY_CLAIM:
            return
        claim = action.parameters.get("claim") or {}
        mass = claim.get("mass_estimate_gev")
        if mass is None:
            result.add(ViolationCode.INVALID_CLAIM, "Claim missing mass estimate.")
        else:
            try:
                m = float(mass)
            # These are the exceptions float() raises for bad input
            # (e.g. lists, non-numeric strings, ints too large for a float).
            except (TypeError, ValueError, OverflowError):
                result.add(ViolationCode.INVALID_CLAIM, "Claim mass is not numeric.")
            else:
                lo, hi = self.mass_search_window_gev
                if not (lo <= m <= hi):
                    result.add(
                        ViolationCode.INVALID_CLAIM,
                        f"Claim mass {m} outside search window [{lo}, {hi}].",
                        soft=True,
                    )
        if claim.get("significance_sigma") is None:
            result.add(ViolationCode.INVALID_CLAIM, "Claim missing significance.", soft=True)
# Public API of this module.
__all__ = [
    "RuleResult",
    "RulesEngine",
    "ViolationCode",
]