| """RulesEngine for CERNenv.
|
|
|
| Validates an incoming ``ExperimentAction`` against the current latent state
|
| *before* it is executed. Rule violations are reported back as warnings on the
|
| observation and feed into the per-step penalty in the reward function.
|
| """
|
|
|
from __future__ import annotations

import math
from collections.abc import Mapping, Sequence
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Optional

from models import (
    ActionType,
    DetectorChannel,
    ExperimentAction,
    TriggerType,
)

from server.simulator.latent_state import FullLatentState
|
|
|
|
|
class ViolationCode(str, Enum):
    """Machine-readable reason codes recorded on a ``RuleResult``.

    Members subclass ``str``, so each compares equal to its string value,
    e.g. ``ViolationCode.REDUNDANT == "redundant"``.
    """

    # A pipeline step this action depends on has not been performed yet.
    PREREQ_MISSING = "prerequisite_missing"

    # Resource budgets that, once spent, block further actions.
    BUDGET_EXHAUSTED = "budget_exhausted"
    LUMI_EXHAUSTED = "luminosity_exhausted"
    TIME_EXHAUSTED = "time_exhausted"

    # The action repeats a step that has already been done.
    REDUNDANT = "redundant"

    # An action parameter is unknown or malformed.
    INVALID_PARAMS = "invalid_parameters"

    # The discovery-claim payload is missing fields or has bad values.
    INVALID_CLAIM = "invalid_claim"

    # NOTE(review): not raised by RulesEngine.validate in this module.
    CHANNEL_MISMATCH = "channel_mismatch"

    # A histogram mass window lies entirely outside the task search window.
    OUT_OF_WINDOW = "out_of_search_window"
|
|
|
|
|
@dataclass
class RuleResult:
    """Outcome of validating a single action.

    Hard violations block the action: recording one sets ``allowed`` to
    ``False``. Soft violations are recorded but leave ``allowed`` untouched.
    Every violation, hard or soft, appends one entry to ``messages``.
    """

    allowed: bool
    # Codes of hard (blocking) violations, in the order they were added.
    violations: List[ViolationCode] = field(default_factory=list)
    # One human-readable message per violation (hard and soft), in order.
    messages: List[str] = field(default_factory=list)
    # Codes of soft (non-blocking) violations, in the order they were added.
    soft_violations: List[ViolationCode] = field(default_factory=list)

    def add(self, code: ViolationCode, msg: str, soft: bool = False) -> None:
        """Record a violation.

        Args:
            code: Reason code to store.
            msg: Human-readable explanation, always appended to ``messages``.
            soft: When true, the code goes to ``soft_violations`` and the
                action stays allowed; otherwise it goes to ``violations``
                and ``allowed`` becomes ``False``.
        """
        self.messages.append(msg)
        bucket = self.soft_violations if soft else self.violations
        bucket.append(code)
        if not soft:
            self.allowed = False
|
|
|
|
|
class RulesEngine:
    """Stateless validator (state is passed in).

    ``validate`` inspects a proposed action together with the current latent
    state and collects every problem it finds into one ``RuleResult``.
    Checks run in this order:

    1. Resource exhaustion (hard). If any of these fire, the result is
       returned immediately and no further checks run.
    2. Prerequisites, i.e. pipeline steps that must already be done (hard).
    3. Parameter sanity for channel, trigger and histogram window (soft).
    4. Redundant repeats of one-off steps (soft).
    5. Discovery-claim contents (hard or soft, per field).
    """

    def __init__(
        self,
        mass_search_window_gev: tuple[float, float] = (50.0, 1000.0),
    ) -> None:
        """Create a validator.

        Args:
            mass_search_window_gev: ``(low, high)`` mass range in GeV for the
                task. Histogram windows that do not overlap it, and claimed
                masses outside it (bounds inclusive), are soft violations.
        """
        self.mass_search_window_gev = mass_search_window_gev

    def validate(
        self,
        action: ExperimentAction,
        state: FullLatentState,
    ) -> RuleResult:
        """Validate ``action`` against ``state`` before it is executed.

        Args:
            action: Proposed action. Its ``action_type`` and ``parameters``
                mapping are inspected.
            state: Current latent state. ``resources``, ``progress`` and
                ``selected_channel`` are read.

        Returns:
            A ``RuleResult`` whose ``allowed`` is ``False`` iff at least one
            hard violation was recorded. Malformed values of the parameters
            checked here are reported as violations instead of raising.
        """
        result = RuleResult(allowed=True)

        self._check_resources(action, state, result)
        if not result.allowed:
            # An exhausted budget overrides everything else; report only that.
            return result

        self._check_prerequisites(action, state, result)
        self._check_parameters(action, result)
        self._check_redundancy(action, state, result)
        if action.action_type == ActionType.SUBMIT_DISCOVERY_CLAIM:
            self._check_claim(action.parameters.get("claim"), result)

        return result

    @staticmethod
    def _check_resources(
        action: ExperimentAction,
        state: FullLatentState,
        result: RuleResult,
    ) -> None:
        """Record hard violations for exhausted budget, time or luminosity."""
        resources = state.resources
        if resources.budget_exhausted:
            result.add(ViolationCode.BUDGET_EXHAUSTED, "Budget fully spent.")
        if resources.time_exhausted:
            result.add(ViolationCode.TIME_EXHAUSTED, "Time budget exhausted.")
        # An exhausted luminosity budget only blocks these two action types.
        if resources.luminosity_exhausted and action.action_type in {
            ActionType.ALLOCATE_LUMINOSITY,
            ActionType.COLLECT_COLLISIONS,
        }:
            result.add(
                ViolationCode.LUMI_EXHAUSTED,
                "Integrated luminosity budget spent.",
            )

    @staticmethod
    def _check_prerequisites(
        action: ExperimentAction,
        state: FullLatentState,
        result: RuleResult,
    ) -> None:
        """Record a hard violation for each prerequisite step not yet done."""
        a = action.action_type
        prog = state.progress
        missing = ViolationCode.PREREQ_MISSING

        if a == ActionType.COLLECT_COLLISIONS:
            if not prog.beam_configured:
                result.add(missing, "Configure the beam first.")
            if not prog.luminosity_allocated:
                result.add(missing, "Allocate luminosity first.")
            if not prog.trigger_set:
                result.add(missing, "Set a trigger first.")
            # NOTE(review): this reads state.selected_channel while the
            # redundancy check reads prog.channel_selected; confirm the two
            # are always kept in sync.
            if not state.selected_channel:
                result.add(missing, "Select a decay channel first.")

        elif a == ActionType.BUILD_INVARIANT_MASS:
            if not prog.collisions_collected:
                result.add(missing, "Collect collisions before building histograms.")
            if not prog.tracks_reconstructed:
                result.add(missing, "Reconstruct tracks before building histograms.")

        elif a == ActionType.SUBTRACT_BACKGROUND:
            if not prog.invariant_mass_built:
                result.add(missing, "Build invariant-mass histogram first.")

        elif a == ActionType.FIT_RESONANCE:
            if not prog.invariant_mass_built:
                result.add(missing, "Build the histogram before fitting.")

        elif a == ActionType.MEASURE_ANGULAR:
            if not (prog.resonance_fitted or prog.bump_scanned):
                result.add(
                    missing,
                    "Identify a peak (fit or bump scan) before angular analysis.",
                )

        elif a == ActionType.ESTIMATE_SIGNIFICANCE:
            if not prog.collisions_collected:
                result.add(missing, "Collect data before significance estimation.")

        elif a == ActionType.SUBMIT_DISCOVERY_CLAIM:
            if not prog.resonance_fitted and not prog.bump_scanned:
                result.add(
                    missing,
                    "No fitted resonance or bump scan; cannot claim a discovery.",
                )
            if not prog.significance_estimated:
                result.add(missing, "Estimate significance before submitting a claim.")

    def _check_parameters(self, action: ExperimentAction, result: RuleResult) -> None:
        """Soft-check channel, trigger and histogram-window parameters."""
        a = action.action_type
        params = action.parameters

        if a == ActionType.SELECT_CHANNEL:
            channel = params.get("channel")
            if channel:
                try:
                    DetectorChannel(channel)
                except ValueError:
                    result.add(
                        ViolationCode.INVALID_PARAMS,
                        f"Unknown channel '{channel}'.",
                        soft=True,
                    )

        elif a == ActionType.SET_TRIGGER:
            trig = params.get("trigger")
            if trig:
                try:
                    TriggerType(trig)
                except ValueError:
                    result.add(
                        ViolationCode.INVALID_PARAMS,
                        f"Unknown trigger '{trig}'.",
                        soft=True,
                    )

        elif a == ActionType.BUILD_INVARIANT_MASS:
            self._check_mass_window(params.get("mass_window_gev"), result)

    def _check_mass_window(self, window: object, result: RuleResult) -> None:
        """Soft-check a ``[low, high]`` histogram mass window in GeV.

        An absent or empty window is not checked. A window that is not a
        two-element sequence, or whose bounds are not finite numbers, is
        reported as ``INVALID_PARAMS`` instead of raising.
        """
        if not window:
            return

        # Strings are Sequences too, but never a valid numeric pair.
        if (
            isinstance(window, (str, bytes))
            or not isinstance(window, Sequence)
            or len(window) != 2
        ):
            result.add(
                ViolationCode.INVALID_PARAMS,
                f"Mass window {window!r} is not a [low, high] pair.",
                soft=True,
            )
            return

        try:
            lo, hi = float(window[0]), float(window[1])
        except (TypeError, ValueError, OverflowError):
            result.add(
                ViolationCode.INVALID_PARAMS,
                f"Mass window {window!r} is not numeric.",
                soft=True,
            )
            return

        # NaN compares false with everything and would slip past both
        # checks below.
        if not (math.isfinite(lo) and math.isfinite(hi)):
            result.add(
                ViolationCode.INVALID_PARAMS,
                f"Mass window [{lo}, {hi}] is not finite.",
                soft=True,
            )
            return

        if hi <= lo:
            result.add(
                ViolationCode.INVALID_PARAMS,
                f"Mass window [{lo}, {hi}] has non-positive width.",
                soft=True,
            )

        search_lo, search_hi = self.mass_search_window_gev
        # Only windows that miss the search range entirely are flagged;
        # partial overlap is accepted.
        if lo > search_hi or hi < search_lo:
            result.add(
                ViolationCode.OUT_OF_WINDOW,
                f"Histogram window [{lo}, {hi}] is outside the task search window "
                f"{self.mass_search_window_gev}.",
                soft=True,
            )

    @staticmethod
    def _check_redundancy(
        action: ExperimentAction,
        state: FullLatentState,
        result: RuleResult,
    ) -> None:
        """Soft-flag repeats of steps that are already marked done."""
        a = action.action_type
        prog = state.progress

        if a == ActionType.CONFIGURE_BEAM and prog.beam_configured:
            result.add(
                ViolationCode.REDUNDANT,
                "Beam already configured; reconfiguring wastes budget.",
                soft=True,
            )
        elif a == ActionType.SELECT_CHANNEL and prog.channel_selected:
            result.add(ViolationCode.REDUNDANT, "Channel already selected.", soft=True)
        elif a == ActionType.RECONSTRUCT_TRACKS and prog.tracks_reconstructed:
            result.add(ViolationCode.REDUNDANT, "Tracks already reconstructed.", soft=True)
        elif a == ActionType.CALIBRATE_DETECTOR and prog.detector_calibrated:
            result.add(ViolationCode.REDUNDANT, "Detector already calibrated.", soft=True)

    def _check_claim(self, claim: object, result: RuleResult) -> None:
        """Validate the ``claim`` payload of a discovery submission.

        A payload that is not a mapping, or a mass that is missing,
        non-numeric or non-finite, is a hard violation. A mass outside the
        search window or a missing significance is soft.
        """
        if not claim:
            # Absent or empty payload: validate it as {} so that each missing
            # field is reported individually.
            claim = {}
        if not isinstance(claim, Mapping):
            result.add(
                ViolationCode.INVALID_CLAIM,
                f"Claim must be a mapping of fields, got {type(claim).__name__}.",
            )
            return

        self._check_claim_mass(claim.get("mass_estimate_gev"), result)

        if claim.get("significance_sigma") is None:
            result.add(ViolationCode.INVALID_CLAIM, "Claim missing significance.", soft=True)

    def _check_claim_mass(self, mass: object, result: RuleResult) -> None:
        """Check the claimed mass: present, numeric, finite, inside the window."""
        if mass is None:
            result.add(ViolationCode.INVALID_CLAIM, "Claim missing mass estimate.")
            return

        try:
            m = float(mass)
        except (TypeError, ValueError, OverflowError):
            result.add(ViolationCode.INVALID_CLAIM, "Claim mass is not numeric.")
            return

        if not math.isfinite(m):
            result.add(ViolationCode.INVALID_CLAIM, f"Claim mass {m} is not finite.")
            return

        lo, hi = self.mass_search_window_gev
        if not (lo <= m <= hi):
            result.add(
                ViolationCode.INVALID_CLAIM,
                f"Claim mass {m} outside search window [{lo}, {hi}].",
                soft=True,
            )
|
|
|
|
|
# Public API exported by ``from <module> import *``.
__all__ = [
    "RuleResult",
    "RulesEngine",
    "ViolationCode",
]
|
|
|