# cernenv/server/rules/engine.py
# anugrah55's picture
# Update CERNenv Space
# 2b0bffa verified
"""RulesEngine for CERNenv.
Validates an incoming ``ExperimentAction`` against the current latent state
*before* it is executed. Rule violations are reported back as warnings on the
observation and feed into the per-step penalty in the reward function.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Optional
from models import (
ActionType,
DetectorChannel,
ExperimentAction,
TriggerType,
)
from server.simulator.latent_state import FullLatentState
class ViolationCode(str, Enum):
    """Machine-readable identifiers for rule violations.

    Because the enum mixes in ``str``, each member compares equal to its
    string value (e.g. ``ViolationCode.REDUNDANT == "redundant"``).
    """

    # missing pipeline step
    PREREQ_MISSING = "prerequisite_missing"
    # exhausted resources
    BUDGET_EXHAUSTED = "budget_exhausted"
    LUMI_EXHAUSTED = "luminosity_exhausted"
    TIME_EXHAUSTED = "time_exhausted"
    # repeated step
    REDUNDANT = "redundant"
    # malformed action parameters / claim payload
    INVALID_PARAMS = "invalid_parameters"
    INVALID_CLAIM = "invalid_claim"
    CHANNEL_MISMATCH = "channel_mismatch"
    # value outside the configured mass search window
    OUT_OF_WINDOW = "out_of_search_window"
@dataclass
class RuleResult:
    """Accumulated outcome of validating a single action.

    ``allowed`` holds the value supplied at construction until the first
    hard violation is recorded, at which point it becomes ``False``. Soft
    violations are tracked separately and never change ``allowed``. The
    message of every violation, hard or soft, is appended to ``messages``
    in the order it was reported.
    """

    allowed: bool
    # hard violations; recording one also sets ``allowed`` to False
    violations: List[ViolationCode] = field(default_factory=list)
    # explanation text for every violation, hard and soft alike
    messages: List[str] = field(default_factory=list)
    # soft violations; recorded but they do not block the action
    soft_violations: List[ViolationCode] = field(default_factory=list)

    def add(self, code: ViolationCode, msg: str, soft: bool = False) -> None:
        """Record violation *code* together with its explanation *msg*.

        Hard violations (the default) also mark the result as not allowed.
        """
        self.messages.append(msg)
        target = self.soft_violations if soft else self.violations
        target.append(code)
        if not soft:
            self.allowed = False
class RulesEngine:
    """Stateless validator (state is passed in).

    ``validate`` checks an ``ExperimentAction`` against the current
    ``FullLatentState`` and returns a ``RuleResult``. Hard violations
    (exhausted resources, missing prerequisites, malformed claims) set
    ``allowed`` to ``False``. Soft violations (bad or out-of-window
    parameters, redundant steps) are recorded but leave the action allowed.
    Malformed parameters are reported as violations rather than raised.
    """

    def __init__(
        self,
        mass_search_window_gev: tuple[float, float] = (50.0, 1000.0),
    ) -> None:
        # Inclusive (lo, hi) bounds in GeV, shared by the histogram-window
        # and claim-mass checks below.
        self.mass_search_window_gev = mass_search_window_gev

    # ── Public API ─────────────────────────────────────────────────────
    def validate(
        self,
        action: ExperimentAction,
        state: FullLatentState,
    ) -> RuleResult:
        """Validate *action* against *state* before it is executed.

        If any resource gate fails, the result is returned immediately.
        Otherwise the prerequisite, parameter, redundancy and claim checks
        all run, in that order, and accumulate into the same result.

        Args:
            action: The action the agent wants to take.
            state: Current latent state (resources and progress flags).

        Returns:
            A ``RuleResult`` listing every hard and soft violation found.
        """
        result = RuleResult(allowed=True)

        self._check_resources(action, state, result)
        if not result.allowed:
            return result

        self._check_prerequisites(action, state, result)
        self._check_parameters(action, result)
        self._check_redundancy(action, state, result)
        if action.action_type == ActionType.SUBMIT_DISCOVERY_CLAIM:
            self._check_claim(action, result)
        return result

    # ── Checks ─────────────────────────────────────────────────────────
    @staticmethod
    def _check_resources(
        action: ExperimentAction,
        state: FullLatentState,
        result: RuleResult,
    ) -> None:
        """Hard-gate on exhausted budget, time, or (DAQ only) luminosity."""
        resources = state.resources
        if resources.budget_exhausted:
            result.add(ViolationCode.BUDGET_EXHAUSTED, "Budget fully spent.")
        if resources.time_exhausted:
            result.add(ViolationCode.TIME_EXHAUSTED, "Time budget exhausted.")
        # luminosity exhaustion only blocks DAQ-style actions
        if (
            resources.luminosity_exhausted
            and action.action_type in {
                ActionType.ALLOCATE_LUMINOSITY,
                ActionType.COLLECT_COLLISIONS,
            }
        ):
            result.add(ViolationCode.LUMI_EXHAUSTED, "Integrated luminosity budget spent.")

    @staticmethod
    def _check_prerequisites(
        action: ExperimentAction,
        state: FullLatentState,
        result: RuleResult,
    ) -> None:
        """Hard-check that the pipeline steps *action* depends on are done."""
        a = action.action_type
        prog = state.progress
        if a == ActionType.COLLECT_COLLISIONS:
            if not prog.beam_configured:
                result.add(ViolationCode.PREREQ_MISSING, "Configure the beam first.")
            if not prog.luminosity_allocated:
                result.add(ViolationCode.PREREQ_MISSING, "Allocate luminosity first.")
            if not prog.trigger_set:
                result.add(ViolationCode.PREREQ_MISSING, "Set a trigger first.")
            if not state.selected_channel:
                result.add(ViolationCode.PREREQ_MISSING, "Select a decay channel first.")
        elif a == ActionType.BUILD_INVARIANT_MASS:
            if not prog.collisions_collected:
                result.add(ViolationCode.PREREQ_MISSING, "Collect collisions before building histograms.")
            if not prog.tracks_reconstructed:
                result.add(ViolationCode.PREREQ_MISSING, "Reconstruct tracks before building histograms.")
        elif a == ActionType.SUBTRACT_BACKGROUND:
            if not prog.invariant_mass_built:
                result.add(ViolationCode.PREREQ_MISSING, "Build invariant-mass histogram first.")
        elif a == ActionType.FIT_RESONANCE:
            if not prog.invariant_mass_built:
                result.add(ViolationCode.PREREQ_MISSING, "Build the histogram before fitting.")
        elif a == ActionType.MEASURE_ANGULAR:
            if not (prog.resonance_fitted or prog.bump_scanned):
                result.add(
                    ViolationCode.PREREQ_MISSING,
                    "Identify a peak (fit or bump scan) before angular analysis.",
                )
        elif a == ActionType.ESTIMATE_SIGNIFICANCE:
            if not prog.collisions_collected:
                result.add(ViolationCode.PREREQ_MISSING, "Collect data before significance estimation.")
        elif a == ActionType.SUBMIT_DISCOVERY_CLAIM:
            if not prog.resonance_fitted and not prog.bump_scanned:
                result.add(ViolationCode.PREREQ_MISSING, "No fitted resonance or bump scan; cannot claim a discovery.")
            if not prog.significance_estimated:
                result.add(ViolationCode.PREREQ_MISSING, "Estimate significance before submitting a claim.")

    def _check_parameters(self, action: ExperimentAction, result: RuleResult) -> None:
        """Soft-check channel/trigger names and the histogram mass window."""
        a = action.action_type
        params = action.parameters
        if a == ActionType.SELECT_CHANNEL:
            channel = params.get("channel")
            if channel:
                try:
                    DetectorChannel(channel)
                except ValueError:
                    result.add(ViolationCode.INVALID_PARAMS, f"Unknown channel '{channel}'.", soft=True)
        elif a == ActionType.SET_TRIGGER:
            trig = params.get("trigger")
            if trig:
                try:
                    TriggerType(trig)
                except ValueError:
                    result.add(ViolationCode.INVALID_PARAMS, f"Unknown trigger '{trig}'.", soft=True)
        elif a == ActionType.BUILD_INVARIANT_MASS:
            self._check_mass_window(params.get("mass_window_gev"), result)

    def _check_mass_window(self, window: object, result: RuleResult) -> None:
        """Soft-check a requested histogram window ``[lo, hi]`` in GeV."""
        if not window:
            return  # no explicit window requested
        bounds = self._parse_window(window)
        if bounds is None:
            # Report anything that is not exactly two numbers instead of
            # letting float()/len() raise out of validate().
            result.add(
                ViolationCode.INVALID_PARAMS,
                f"Mass window {window!r} must be a [lo, hi] pair of numbers.",
                soft=True,
            )
            return
        lo, hi = bounds
        if hi <= lo:
            result.add(
                ViolationCode.INVALID_PARAMS,
                f"Mass window [{lo}, {hi}] is non-positive.",
                soft=True,
            )
        # Only a window fully disjoint from the search range is flagged.
        if lo > self.mass_search_window_gev[1] or hi < self.mass_search_window_gev[0]:
            result.add(
                ViolationCode.OUT_OF_WINDOW,
                f"Histogram window [{lo}, {hi}] is outside the task search window "
                f"{self.mass_search_window_gev}.",
                soft=True,
            )

    @staticmethod
    def _parse_window(window: object) -> Optional[tuple[float, float]]:
        """Return *window* as ``(lo, hi)`` floats, or ``None`` if malformed."""
        if isinstance(window, (str, bytes)):
            return None  # a 2-character string would otherwise unpack
        try:
            lo, hi = window  # type: ignore[misc]
            return float(lo), float(hi)
        except (TypeError, ValueError, OverflowError):
            return None

    @staticmethod
    def _check_redundancy(
        action: ExperimentAction,
        state: FullLatentState,
        result: RuleResult,
    ) -> None:
        """Soft-flag one-shot setup steps that have already been performed."""
        a = action.action_type
        prog = state.progress
        if a == ActionType.CONFIGURE_BEAM and prog.beam_configured:
            result.add(ViolationCode.REDUNDANT, "Beam already configured; reconfiguring wastes budget.", soft=True)
        if a == ActionType.SELECT_CHANNEL and prog.channel_selected:
            result.add(ViolationCode.REDUNDANT, "Channel already selected.", soft=True)
        if a == ActionType.RECONSTRUCT_TRACKS and prog.tracks_reconstructed:
            result.add(ViolationCode.REDUNDANT, "Tracks already reconstructed.", soft=True)
        if a == ActionType.CALIBRATE_DETECTOR and prog.detector_calibrated:
            result.add(ViolationCode.REDUNDANT, "Detector already calibrated.", soft=True)

    def _check_claim(self, action: ExperimentAction, result: RuleResult) -> None:
        """Sanity-check the ``claim`` payload of a discovery submission.

        A missing, non-numeric mass or a non-mapping payload is a hard
        violation. A mass outside the search window or a missing
        significance is soft.
        """
        claim = action.parameters.get("claim") or {}
        if not hasattr(claim, "get"):
            # Without this guard, claim.get(...) below raises AttributeError.
            result.add(ViolationCode.INVALID_CLAIM, "Claim payload must be a mapping.")
            return
        mass = claim.get("mass_estimate_gev")
        if mass is None:
            result.add(ViolationCode.INVALID_CLAIM, "Claim missing mass estimate.")
        else:
            try:
                m = float(mass)
            except (TypeError, ValueError, OverflowError):
                result.add(ViolationCode.INVALID_CLAIM, "Claim mass is not numeric.")
            else:
                lo, hi = self.mass_search_window_gev
                if not (lo <= m <= hi):
                    result.add(
                        ViolationCode.INVALID_CLAIM,
                        f"Claim mass {m} outside search window [{lo}, {hi}].",
                        soft=True,
                    )
        if claim.get("significance_sigma") is None:
            result.add(ViolationCode.INVALID_CLAIM, "Claim missing significance.", soft=True)
# Public API of this module: the validator and its result/violation types.
__all__ = [
    "RuleResult",
    "RulesEngine",
    "ViolationCode",
]