"""RulesEngine for CERNenv. Validates an incoming ``ExperimentAction`` against the current latent state *before* it is executed. Rule violations are reported back as warnings on the observation and feed into the per-step penalty in the reward function. """ from __future__ import annotations from dataclasses import dataclass, field from enum import Enum from typing import List, Optional from models import ( ActionType, DetectorChannel, ExperimentAction, TriggerType, ) from server.simulator.latent_state import FullLatentState class ViolationCode(str, Enum): PREREQ_MISSING = "prerequisite_missing" BUDGET_EXHAUSTED = "budget_exhausted" LUMI_EXHAUSTED = "luminosity_exhausted" TIME_EXHAUSTED = "time_exhausted" REDUNDANT = "redundant" INVALID_PARAMS = "invalid_parameters" INVALID_CLAIM = "invalid_claim" CHANNEL_MISMATCH = "channel_mismatch" OUT_OF_WINDOW = "out_of_search_window" @dataclass class RuleResult: allowed: bool violations: List[ViolationCode] = field(default_factory=list) messages: List[str] = field(default_factory=list) soft_violations: List[ViolationCode] = field(default_factory=list) def add(self, code: ViolationCode, msg: str, soft: bool = False) -> None: self.messages.append(msg) if soft: self.soft_violations.append(code) else: self.violations.append(code) self.allowed = False class RulesEngine: """Stateless validator (state is passed in).""" def __init__( self, mass_search_window_gev: tuple[float, float] = (50.0, 1000.0), ) -> None: self.mass_search_window_gev = mass_search_window_gev # ── Public API ───────────────────────────────────────────────────── def validate( self, action: ExperimentAction, state: FullLatentState, ) -> RuleResult: result = RuleResult(allowed=True) # ── resource gating (hard) ──────────────────────────────── if state.resources.budget_exhausted: result.add(ViolationCode.BUDGET_EXHAUSTED, "Budget fully spent.") if state.resources.time_exhausted: result.add(ViolationCode.TIME_EXHAUSTED, "Time budget exhausted.") # luminosity exhaustion only blocks DAQ-style actions if ( state.resources.luminosity_exhausted and action.action_type in { ActionType.ALLOCATE_LUMINOSITY, ActionType.COLLECT_COLLISIONS, } ): result.add(ViolationCode.LUMI_EXHAUSTED, "Integrated luminosity budget spent.") if not result.allowed: return result a = action.action_type prog = state.progress # ── prerequisites ────────────────────────────────────────── if a == ActionType.COLLECT_COLLISIONS: if not prog.beam_configured: result.add(ViolationCode.PREREQ_MISSING, "Configure the beam first.") if not prog.luminosity_allocated: result.add(ViolationCode.PREREQ_MISSING, "Allocate luminosity first.") if not prog.trigger_set: result.add(ViolationCode.PREREQ_MISSING, "Set a trigger first.") if not state.selected_channel: result.add(ViolationCode.PREREQ_MISSING, "Select a decay channel first.") elif a == ActionType.BUILD_INVARIANT_MASS: if not prog.collisions_collected: result.add(ViolationCode.PREREQ_MISSING, "Collect collisions before building histograms.") if not prog.tracks_reconstructed: result.add(ViolationCode.PREREQ_MISSING, "Reconstruct tracks before building histograms.") elif a == ActionType.SUBTRACT_BACKGROUND: if not prog.invariant_mass_built: result.add(ViolationCode.PREREQ_MISSING, "Build invariant-mass histogram first.") elif a == ActionType.FIT_RESONANCE: if not prog.invariant_mass_built: result.add(ViolationCode.PREREQ_MISSING, "Build the histogram before fitting.") elif a == ActionType.MEASURE_ANGULAR: if not (prog.resonance_fitted or prog.bump_scanned): result.add( ViolationCode.PREREQ_MISSING, "Identify a peak (fit or bump scan) before angular analysis.", ) elif a == ActionType.ESTIMATE_SIGNIFICANCE: if not prog.collisions_collected: result.add(ViolationCode.PREREQ_MISSING, "Collect data before significance estimation.") elif a == ActionType.SUBMIT_DISCOVERY_CLAIM: if not prog.resonance_fitted and not prog.bump_scanned: result.add(ViolationCode.PREREQ_MISSING, "No fitted resonance or bump scan; cannot claim a discovery.") if not prog.significance_estimated: result.add(ViolationCode.PREREQ_MISSING, "Estimate significance before submitting a claim.") # ── parameter & search-window validation (soft) ──────────── if a == ActionType.SELECT_CHANNEL: channel = action.parameters.get("channel") if channel: try: DetectorChannel(channel) except ValueError: result.add(ViolationCode.INVALID_PARAMS, f"Unknown channel '{channel}'.", soft=True) if a == ActionType.SET_TRIGGER: trig = action.parameters.get("trigger") if trig: try: TriggerType(trig) except ValueError: result.add(ViolationCode.INVALID_PARAMS, f"Unknown trigger '{trig}'.", soft=True) if a == ActionType.BUILD_INVARIANT_MASS: window = action.parameters.get("mass_window_gev") if window and len(window) == 2: lo, hi = float(window[0]), float(window[1]) if hi <= lo: result.add( ViolationCode.INVALID_PARAMS, f"Mass window [{lo}, {hi}] is non-positive.", soft=True, ) if lo > self.mass_search_window_gev[1] or hi < self.mass_search_window_gev[0]: result.add( ViolationCode.OUT_OF_WINDOW, f"Histogram window [{lo}, {hi}] is outside the task search window " f"{self.mass_search_window_gev}.", soft=True, ) # ── redundancy (soft) ───────────────────────────────────── if a == ActionType.CONFIGURE_BEAM and prog.beam_configured: result.add(ViolationCode.REDUNDANT, "Beam already configured; reconfiguring wastes budget.", soft=True) if a == ActionType.SELECT_CHANNEL and prog.channel_selected: result.add(ViolationCode.REDUNDANT, "Channel already selected.", soft=True) if a == ActionType.RECONSTRUCT_TRACKS and prog.tracks_reconstructed: result.add(ViolationCode.REDUNDANT, "Tracks already reconstructed.", soft=True) if a == ActionType.CALIBRATE_DETECTOR and prog.detector_calibrated: result.add(ViolationCode.REDUNDANT, "Detector already calibrated.", soft=True) # ── claim sanity ────────────────────────────────────────── if a == ActionType.SUBMIT_DISCOVERY_CLAIM: claim = action.parameters.get("claim") or {} mass = claim.get("mass_estimate_gev") if mass is None: result.add(ViolationCode.INVALID_CLAIM, "Claim missing mass estimate.") else: try: m = float(mass) except Exception: result.add(ViolationCode.INVALID_CLAIM, "Claim mass is not numeric.") else: lo, hi = self.mass_search_window_gev if not (lo <= m <= hi): result.add( ViolationCode.INVALID_CLAIM, f"Claim mass {m} outside search window [{lo}, {hi}].", soft=True, ) if claim.get("significance_sigma") is None: result.add(ViolationCode.INVALID_CLAIM, "Claim missing significance.", soft=True) return result __all__ = ["RuleResult", "RulesEngine", "ViolationCode"]