# cernenv-trainer / models.py
# anugrah55's picture
# Update CERNenv Space
# 5f78183 verified
"""
Data models for CERNenv: an LHC (Large Hadron Collider) style particle
physics discovery POMDP (Partially Observable Markov Decision Process).
The agent is a Large Language Model (LLM) acting as a high-energy physicist.
Each step it picks one structured action (configure beams, allocate
luminosity, run a trigger, fit a spectrum, request systematics, submit a
discovery claim, etc.) and receives a noisy detector-style observation.
The latent particle and detector parameters are the hidden ground truth.
"""
from __future__ import annotations
from enum import Enum
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
from openenv.core.env_server.types import Action, Observation
# ── Action vocabulary ───────────────────────────────────────────────────────
class ActionType(str, Enum):
    """Closed vocabulary of pipeline steps the agent may choose from.

    Members mix in ``str``, so each one compares equal to its wire value
    (for example ``ActionType.SCAN_BUMP == "scan_bump"``).
    """

    # Beam & data acquisition (DAQ)
    CONFIGURE_BEAM = "configure_beam"
    ALLOCATE_LUMINOSITY = "allocate_luminosity"
    SET_TRIGGER = "set_trigger"
    COLLECT_COLLISIONS = "collect_collisions"

    # Reconstruction & calibration
    CALIBRATE_DETECTOR = "calibrate_detector"
    RECONSTRUCT_TRACKS = "reconstruct_tracks"
    SELECT_CHANNEL = "select_channel"

    # Analysis
    BUILD_INVARIANT_MASS = "build_invariant_mass"
    SUBTRACT_BACKGROUND = "subtract_background"
    FIT_RESONANCE = "fit_resonance"
    SCAN_BUMP = "scan_bump"
    MEASURE_ANGULAR = "measure_angular"
    ESTIMATE_SIGNIFICANCE = "estimate_significance"

    # Systematics & meta
    REQUEST_SYSTEMATICS = "request_systematics"
    REQUEST_THEORY_REVIEW = "request_theory_review"

    # Final
    SUBMIT_DISCOVERY_CLAIM = "submit_discovery_claim"
# Immutable groupings of ActionType members by pipeline stage. The four sets
# are disjoint and together contain every ActionType member.

# Beam setup and data taking.
DAQ_ACTIONS = frozenset(
    {
        ActionType.CONFIGURE_BEAM,
        ActionType.ALLOCATE_LUMINOSITY,
        ActionType.SET_TRIGGER,
        ActionType.COLLECT_COLLISIONS,
    }
)

# Detector calibration, object reconstruction and channel choice.
RECO_ACTIONS = frozenset(
    {
        ActionType.CALIBRATE_DETECTOR,
        ActionType.RECONSTRUCT_TRACKS,
        ActionType.SELECT_CHANNEL,
    }
)

# Spectrum building, fitting and significance estimation.
ANALYSIS_ACTIONS = frozenset(
    {
        ActionType.BUILD_INVARIANT_MASS,
        ActionType.SUBTRACT_BACKGROUND,
        ActionType.FIT_RESONANCE,
        ActionType.SCAN_BUMP,
        ActionType.MEASURE_ANGULAR,
        ActionType.ESTIMATE_SIGNIFICANCE,
    }
)

# Systematics / review requests plus the final claim submission.
META_ACTIONS = frozenset(
    {
        ActionType.REQUEST_SYSTEMATICS,
        ActionType.REQUEST_THEORY_REVIEW,
        ActionType.SUBMIT_DISCOVERY_CLAIM,
    }
)
# ── Detector channels & physics primitives ────────────────────────────────
class DetectorChannel(str, Enum):
    """Final-state decay channel in which the agent reconstructs events.

    The channel influences signal acceptance and background composition.
    Choosing a channel the true particle does not decay into gives a low
    signal yield regardless of the collected luminosity; that is by design.
    """

    DIPHOTON = "diphoton"  # γγ
    DILEPTON_EE = "dilepton_ee"  # e+ e-
    DILEPTON_MUMU = "dilepton_mumu"  # μ+ μ-
    DIJET = "dijet"  # jj
    FOUR_LEPTON = "four_lepton"  # 4ℓ
    BB = "bb"  # b b-bar
class TriggerType(str, Enum):
    """Hardware-level event selection options."""

    LOW_PT = "low_pt"  # broad acceptance, lots of background
    HIGH_PT = "high_pt"  # high-mass focus, lower QCD
    DIPHOTON_HLT = "diphoton_hlt"
    DILEPTON_HLT = "dilepton_hlt"
    JET_HLT = "jet_hlt"
class BeamEnergy(str, Enum):
    """Selectable LHC-style centre-of-mass energies, in TeV."""

    E_7 = "7TeV"
    E_8 = "8TeV"
    E_13 = "13TeV"
    E_14 = "14TeV"
# ── Tool / instrument registry (for prompts and tool-fit reward) ──────────
class ToolCategory(str, Enum):
    """Coarse category assigned to each entry of the tool registry."""

    DAQ = "daq"
    RECONSTRUCTION = "reconstruction"
    CALIBRATION = "calibration"
    ANALYSIS = "analysis"
    STATISTICS = "statistics"
    SYSTEMATICS = "systematics"
class ToolSpec(BaseModel):
    """Static metadata describing one named tool or instrument."""

    # Identifier of the tool; only ``name`` and ``category`` are required.
    name: str
    category: ToolCategory
    description: str = ""
    typical_runtime_hours: float = 0.5
    # In millions of USD (compute / beam time proxy).
    typical_cost_musd: float = 0.0
    requires_gpu: bool = False
    # Channel names associated with the tool; a fresh empty list per instance.
    # NOTE(review): presumably DetectorChannel values -- not validated here.
    channels: List[str] = Field(default_factory=list)
# Registry of known tools, keyed by each spec's own ``name``. Insertion order
# (DAQ, reconstruction, calibration, analysis, statistics, systematics) is
# preserved because the mapping is built from an ordered tuple of specs.
TOOL_REGISTRY: Dict[str, ToolSpec] = {
    spec.name: spec
    for spec in (
        # DAQ / trigger systems
        ToolSpec(
            name="ATLAS_HLT",
            category=ToolCategory.DAQ,
            description="ATLAS High-Level Trigger system for online event selection",
            typical_runtime_hours=0.0,
            channels=["diphoton", "dilepton_ee", "dilepton_mumu", "four_lepton", "dijet", "bb"],
        ),
        ToolSpec(
            name="CMS_HLT",
            category=ToolCategory.DAQ,
            description="CMS High-Level Trigger system",
            typical_runtime_hours=0.0,
            channels=["diphoton", "dilepton_ee", "dilepton_mumu", "four_lepton", "dijet", "bb"],
        ),
        # Reconstruction
        ToolSpec(
            name="GEANT4",
            category=ToolCategory.RECONSTRUCTION,
            description="Detector simulation toolkit for full event reconstruction",
            typical_runtime_hours=1.0,
            typical_cost_musd=0.05,
            requires_gpu=False,
        ),
        ToolSpec(
            name="Athena",
            category=ToolCategory.RECONSTRUCTION,
            description="ATLAS reconstruction framework",
            typical_runtime_hours=0.8,
        ),
        ToolSpec(
            name="CMSSW",
            category=ToolCategory.RECONSTRUCTION,
            description="CMS reconstruction software",
            typical_runtime_hours=0.8,
        ),
        # Calibration
        ToolSpec(
            name="ECAL_calibration",
            category=ToolCategory.CALIBRATION,
            description="Electromagnetic calorimeter energy-scale calibration",
            typical_runtime_hours=0.3,
        ),
        ToolSpec(
            name="Tracker_alignment",
            category=ToolCategory.CALIBRATION,
            description="Inner tracker alignment for momentum precision",
            typical_runtime_hours=0.4,
        ),
        # Analysis
        ToolSpec(
            name="ROOT_RooFit",
            category=ToolCategory.ANALYSIS,
            description="Maximum-likelihood spectrum fitting toolkit",
            typical_runtime_hours=0.2,
        ),
        ToolSpec(
            name="MadGraph",
            category=ToolCategory.ANALYSIS,
            description="Matrix-element generator for signal+background templates",
            typical_runtime_hours=1.5,
            typical_cost_musd=0.02,
        ),
        ToolSpec(
            name="Pythia8",
            category=ToolCategory.ANALYSIS,
            description="Parton-shower and hadronisation generator",
            typical_runtime_hours=0.5,
        ),
        # Statistics
        ToolSpec(
            name="BumpHunter",
            category=ToolCategory.STATISTICS,
            description="Sliding-window local-significance bump-hunting algorithm",
            typical_runtime_hours=0.1,
        ),
        ToolSpec(
            name="CLs_fit",
            category=ToolCategory.STATISTICS,
            description="Modified-frequentist CLs limits and significance",
            typical_runtime_hours=0.1,
        ),
        ToolSpec(
            name="Asimov_significance",
            category=ToolCategory.STATISTICS,
            description="Asymptotic significance from Asimov dataset",
            typical_runtime_hours=0.05,
        ),
        # Systematics
        ToolSpec(
            name="JES_systematics",
            category=ToolCategory.SYSTEMATICS,
            description="Jet energy-scale systematic study",
            typical_runtime_hours=0.4,
        ),
        ToolSpec(
            name="Luminosity_calibration",
            category=ToolCategory.SYSTEMATICS,
            description="Van der Meer scan luminosity calibration",
            typical_runtime_hours=0.3,
        ),
    )
}
# ── Action schema ──────────────────────────────────────────────────────────
class ExperimentAction(Action):
    """A single structured experimental step submitted by the agent.

    Only ``action_type`` is required; all other fields have defaults.
    """

    # Which pipeline step to execute (required).
    action_type: ActionType = Field(
        ...,
        description=(
            "Discrete LHC pipeline step. The environment enforces physics "
            "prerequisites: you cannot fit a spectrum before collecting data, "
            "or claim a discovery before estimating significance."
        ),
    )

    # Optional tool name used to carry out the step.
    method: Optional[str] = Field(
        default=None,
        description=(
            "Optional named instrument or framework (e.g. 'ROOT_RooFit', "
            "'BumpHunter', 'Pythia8'). Affects cost, runtime, and tool-fit reward."
        ),
    )

    # Free-form, action-specific settings; a fresh empty dict per instance.
    parameters: Dict[str, Any] = Field(
        default_factory=dict,
        description=(
            "Action-specific settings such as beam energy, integrated luminosity "
            "(fb^-1), trigger selection, decay channel, mass window, fit model."
        ),
    )

    justification: Optional[str] = Field(
        default=None,
        description="Short scientific rationale for picking this step now.",
    )

    # Validated to lie in the closed interval [0, 1].
    confidence: float = Field(
        default=0.5,
        ge=0.0,
        le=1.0,
        description="Agent confidence in the chosen step.",
    )
# ── Outputs ────────────────────────────────────────────────────────────────
class OutputType(str, Enum):
    """Kinds of artifact an environment step can produce.

    Most members parallel an ``ActionType`` step; ``FAILURE_REPORT`` has no
    action counterpart.
    """

    # Data acquisition
    BEAM_CONFIG = "beam_config"
    LUMINOSITY_LOG = "luminosity_log"
    TRIGGER_REPORT = "trigger_report"
    COLLISION_BATCH = "collision_batch"

    # Reconstruction & calibration
    CALIBRATION_REPORT = "calibration_report"
    RECONSTRUCTION = "reconstruction"
    CHANNEL_SELECTION = "channel_selection"

    # Analysis
    INVARIANT_MASS_HIST = "invariant_mass_hist"
    BACKGROUND_SUBTRACTION = "background_subtraction"
    FIT_RESULT = "fit_result"
    BUMP_SCAN = "bump_scan"
    ANGULAR_RESULT = "angular_result"
    SIGNIFICANCE = "significance"

    # Systematics, review and final claim
    SYSTEMATICS_REPORT = "systematics_report"
    THEORY_REVIEW = "theory_review"
    DISCOVERY_CLAIM = "discovery_claim"

    # Failure
    FAILURE_REPORT = "failure_report"
class IntermediateOutput(BaseModel):
    """One noisy detector or analysis artifact produced by a step."""

    output_type: OutputType
    step_index: int
    success: bool = True
    # Validated to lie in [0, 1]; defaults to the maximum.
    quality_score: float = Field(default=1.0, ge=0.0, le=1.0)
    summary: str = ""
    # Arbitrary payload; a fresh empty dict per instance.
    data: Dict[str, Any] = Field(default_factory=dict)
    # Validated to lie in [0, 1]; defaults to zero.
    uncertainty: float = Field(default=0.0, ge=0.0, le=1.0)
    warnings: List[str] = Field(default_factory=list)
    artifacts_available: List[str] = Field(default_factory=list)
# ── Observable state components ───────────────────────────────────────────
class ResourceUsage(BaseModel):
    """Resource counters exposed to the agent.

    The default ``*_remaining`` values equal the default limits declared on
    ``TaskSpec`` (100 MUSD, 300 fb, 365 days).
    """

    # Budget, per the field suffix in millions of USD.
    budget_used_musd: float = 0.0
    budget_remaining_musd: float = 100.0

    # Integrated luminosity.
    luminosity_used_fb: float = 0.0
    luminosity_remaining_fb: float = 300.0

    # Wall-clock time, in days.
    time_used_days: float = 0.0
    time_remaining_days: float = 365.0

    # Compute usage; no matching "remaining" counter is defined.
    compute_hours_used: float = 0.0
class PipelineStepRecord(BaseModel):
    """History entry describing one executed step and what it cost."""

    # What was requested.
    step_index: int
    action_type: ActionType
    method: Optional[str] = None
    parameters: Dict[str, Any] = Field(default_factory=dict)

    # What came back.
    output_summary: str = ""
    output_type: OutputType
    success: bool = True
    # Unlike IntermediateOutput.quality_score, no range constraint is applied.
    quality_score: float = 1.0

    # Resources charged for the step.
    cost_musd: float = 0.0
    luminosity_cost_fb: float = 0.0
    time_cost_days: float = 0.0
class PaperReference(BaseModel):
    """Bibliographic pointer to a paper; only ``title`` is required."""

    title: str
    # Optional locators; each defaults to None when unknown.
    citation: Optional[str] = None
    doi: Optional[str] = None
    arxiv_id: Optional[str] = None
    url: Optional[str] = None
class ExpectedFinding(BaseModel):
    """A finding attached to a task, with a category label and keywords."""

    finding: str
    # Free-form label; defaults to "claim".
    category: str = "claim"
    # A fresh empty list per instance.
    keywords: List[str] = Field(default_factory=list)
class TaskSpec(BaseModel):
    """The physics question posed to the agent for one episode.

    Every field has a default, so ``TaskSpec()`` is a valid generic task.
    """

    problem_statement: str = "Discover and characterise an unknown resonance."
    target_collider: str = "LHC"

    # Option lists default to every value of the corresponding enum/registry,
    # evaluated lazily at instantiation time.
    beam_energy_options: List[str] = Field(
        default_factory=lambda: [energy.value for energy in BeamEnergy],
    )
    available_channels: List[str] = Field(
        default_factory=lambda: [channel.value for channel in DetectorChannel],
    )
    available_triggers: List[str] = Field(
        default_factory=lambda: [trigger.value for trigger in TriggerType],
    )
    available_tools: List[str] = Field(
        default_factory=lambda: list(TOOL_REGISTRY),
    )

    # Two-element [low, high] window by default; the length is not validated.
    mass_search_window_gev: List[float] = Field(
        default_factory=lambda: [50.0, 1000.0],
    )

    # Resource limits for the episode.
    budget_limit_musd: float = 100.0
    luminosity_budget_fb: float = 300.0
    time_limit_days: float = 365.0

    # Optional supporting material; all default to empty lists.
    prior_observations: List[str] = Field(default_factory=list)
    success_criteria: List[str] = Field(default_factory=list)
    paper_references: List[PaperReference] = Field(default_factory=list)
    expected_findings: List[ExpectedFinding] = Field(default_factory=list)

    difficulty: str = "medium"
class DiscoveryClaim(BaseModel):
    """Structured final claim that is graded against the hidden truth.

    All measured quantities are optional and default to ``None``.
    """

    claim: str = ""

    # Resonance properties.
    mass_estimate_gev: Optional[float] = None
    mass_uncertainty_gev: Optional[float] = None
    width_estimate_gev: Optional[float] = None
    significance_sigma: Optional[float] = None
    decay_channel: Optional[str] = None
    # Intended values: 0, 1, 2 (not enforced by a validator).
    spin_hypothesis: Optional[int] = None
    # Intended values: "+", "-" (not enforced by a validator).
    parity: Optional[str] = None
    cross_section_fb: Optional[float] = None

    # Validated to lie in [0, 1].
    confidence: float = Field(default=0.5, ge=0.0, le=1.0)
    # NOTE(review): presumably step indices backing the claim -- confirm.
    evidence_steps: List[int] = Field(default_factory=list)
class CollisionObservation(Observation):
    """Complete agent-visible state returned after every step.

    The hidden particle truth and hidden detector systematics are not part
    of this model.
    """

    # Task definition and progress.
    task: TaskSpec = Field(default_factory=TaskSpec)
    step_index: int = 0
    pipeline_history: List[PipelineStepRecord] = Field(default_factory=list)

    # Options on offer; these default to empty lists, unlike the same-named
    # TaskSpec fields, which default to every known value.
    available_channels: List[str] = Field(default_factory=list)
    available_triggers: List[str] = Field(default_factory=list)
    available_tools: List[str] = Field(default_factory=list)

    resource_usage: ResourceUsage = Field(default_factory=ResourceUsage)

    # Step outputs: the latest one (if any) and the full list.
    latest_output: Optional[IntermediateOutput] = None
    all_outputs: List[IntermediateOutput] = Field(default_factory=list)

    # Candidate peaks; the two lists are declared independently and nothing
    # in this model forces them to have equal length.
    candidate_masses_gev: List[float] = Field(default_factory=list)
    candidate_significances: List[float] = Field(default_factory=list)

    # Current experiment configuration, None until chosen.
    selected_channel: Optional[str] = None
    selected_beam_energy: Optional[str] = None

    cumulative_significance: float = 0.0
    uncertainty_summary: Dict[str, float] = Field(default_factory=dict)
    rule_violations: List[str] = Field(default_factory=list)
    step_reward_breakdown: Dict[str, float] = Field(default_factory=dict)
# ── Agent-facing prompt helpers ───────────────────────────────────────────
# One short, agent-facing hint per ActionType member. Every member has an
# entry; build_agent_system_prompt indexes this mapping for each member.
AGENT_ACTION_GUIDANCE: Dict[ActionType, str] = {
    # Beam & data acquisition
    ActionType.CONFIGURE_BEAM: (
        "Pick the LHC center-of-mass energy. Higher energy reaches heavier "
        "resonances but costs more per fb^-1. Required before collecting data."
    ),
    ActionType.ALLOCATE_LUMINOSITY: (
        "Schedule a chunk of integrated luminosity (fb^-1). More luminosity "
        "means more events but uses budget and time. Required before collecting."
    ),
    ActionType.SET_TRIGGER: (
        "Choose a hardware/HLT trigger. Match the trigger to the channel of "
        "interest; mismatched triggers throw away signal."
    ),
    ActionType.COLLECT_COLLISIONS: (
        "Run the experiment. Returns a noisy raw event count plus background "
        "estimate, conditioned on beam, luminosity, trigger, and channel."
    ),
    # Reconstruction & calibration
    ActionType.CALIBRATE_DETECTOR: (
        "Apply ECAL/tracker calibration. Reduces systematic uncertainty; "
        "neglecting it inflates fit uncertainty later."
    ),
    ActionType.RECONSTRUCT_TRACKS: (
        "Reconstruct charged-particle tracks and physics objects. Required "
        "before any analysis-level step."
    ),
    ActionType.SELECT_CHANNEL: (
        "Pick the decay channel to study (γγ, ℓℓ, jj, 4ℓ, bb). Wrong channel "
        "= small signal acceptance regardless of luminosity."
    ),
    # Analysis
    ActionType.BUILD_INVARIANT_MASS: (
        "Construct the invariant-mass histogram in the chosen channel and "
        "mass window."
    ),
    ActionType.SUBTRACT_BACKGROUND: (
        "Fit a smooth background model and subtract it to expose any peak."
    ),
    ActionType.FIT_RESONANCE: (
        "Fit a Breit-Wigner / Crystal Ball line shape. Returns mass, width, "
        "and statistical uncertainty."
    ),
    ActionType.SCAN_BUMP: (
        "Run a sliding-window bump hunt over the mass window. Reports the "
        "most-significant candidate region."
    ),
    ActionType.MEASURE_ANGULAR: (
        "Measure decay angular distribution to constrain spin/parity. "
        "Useful only after a peak is identified."
    ),
    ActionType.ESTIMATE_SIGNIFICANCE: (
        "Compute the statistical significance of a candidate signal in σ. "
        "Required before claiming a discovery."
    ),
    # Systematics & meta
    ActionType.REQUEST_SYSTEMATICS: (
        "Run a systematics study (JES, luminosity, calibration). Improves "
        "uncertainty estimates and reduces overconfidence penalty."
    ),
    ActionType.REQUEST_THEORY_REVIEW: (
        "Ask a theorist sub-agent to review the evidence; small extra signal "
        "but not a substitute for missing data."
    ),
    # Final
    ActionType.SUBMIT_DISCOVERY_CLAIM: (
        "Submit a structured discovery claim. Graded on mass calibration, "
        "significance, channel, spin hypothesis, and overconfidence."
    ),
}
# Ground rules rendered as bullet points in the agent's system prompt,
# in the order listed here.
AGENT_ENVIRONMENT_RULES: List[str] = [
    # Evidence handling
    "Each successful action returns summarized evidence; do not repeat steps.",
    # Prerequisite ordering
    (
        "Hard prerequisites are enforced: data collection requires beam+luminosity+trigger; "
        "analysis requires reconstruction and a chosen channel."
    ),
    "A discovery claim requires a fitted resonance and an estimated significance.",
    # Tooling
    "Tools listed in available_tools are pre-filtered for this episode; prefer them.",
    # Scoring
    "Submitting an overconfident wrong claim is heavily penalised.",
]
def build_agent_system_prompt() -> str:
    """Assemble the static system prompt for the LLM agent.

    The prompt consists, in order, of a role preamble, one bullet per entry
    of ``AGENT_ENVIRONMENT_RULES``, one bullet per ``ActionType`` member with
    its ``AGENT_ACTION_GUIDANCE`` text, and the required JSON response format
    (including an example payload for ``submit_discovery_claim``).

    Returns:
        The prompt as a single newline-joined string.

    Raises:
        KeyError: If an ``ActionType`` member has no guidance entry.
    """
    preamble = [
        "You are an expert high-energy physicist running an analysis at the LHC.",
        "",
        "At each turn you observe the experiment state and pick one structured next step",
        "to maximise the probability of correctly characterising a hidden resonance.",
        "",
        "Environment rules:",
    ]
    rule_bullets = [f" - {rule}" for rule in AGENT_ENVIRONMENT_RULES]
    guidance_bullets = [
        f" - {action.value}: {AGENT_ACTION_GUIDANCE[action]}" for action in ActionType
    ]
    response_format = [
        "",
        "Respond with ONLY a single valid JSON object, no extra prose:",
        '{"action_type": "...", "method": null, "parameters": {}, "justification": "...", "confidence": 0.8}',
        "",
        "For submit_discovery_claim, structure parameters['claim'] as:",
        # A single prompt line, split across three literals for readability.
        '{"mass_estimate_gev": 125.0, "mass_uncertainty_gev": 0.5, "width_estimate_gev": 0.004,'
        ' "significance_sigma": 5.2, "decay_channel": "diphoton", "spin_hypothesis": 0,'
        ' "parity": "+", "cross_section_fb": 50.0, "confidence": 0.9}',
    ]
    sections = (
        preamble
        + rule_bullets
        + ["", "Action guidance:"]
        + guidance_bullets
        + response_format
    )
    return "\n".join(sections)
def build_agent_observation_context(
    obs: CollisionObservation,
    *,
    max_tools: int = 6,
    max_channels: int = 4,
) -> str:
    """Render a compact, newline-separated text summary of an observation.

    Lines are emitted in this order: mass window and difficulty, available
    channels, available tools, selected channel, beam energy, and up to three
    candidate peaks. Optional lines are skipped when there is nothing to show.

    Args:
        obs: Observation to summarise. Channel and tool lists fall back to
            the ones on ``obs.task`` when the observation's own lists are
            empty.
        max_tools: Maximum number of tool names to list. Values below zero
            are treated as zero.
        max_channels: Maximum number of channel names to list. Values below
            zero are treated as zero.

    Returns:
        The summary text. For a two-element mass window and equally long
        candidate lists the output is identical to the previous behaviour.
    """
    parts: List[str] = []

    # Guard against a malformed window instead of raising IndexError.
    window = obs.task.mass_search_window_gev or []
    if len(window) >= 2:
        window_text = f"[{window[0]:.0f}, {window[1]:.0f}] GeV"
    else:
        window_text = "unspecified"
    parts.append(
        f"Mass search window: {window_text}; difficulty={obs.task.difficulty}."
    )

    # dict.fromkeys de-duplicates while keeping first-seen order. Clamping the
    # cap avoids negative slices silently dropping items from the wrong end.
    chans = list(dict.fromkeys(obs.available_channels or obs.task.available_channels))
    shown_chans = chans[: max(max_channels, 0)]
    if shown_chans:
        parts.append("Available channels: " + ", ".join(shown_chans))

    tools = list(dict.fromkeys(obs.available_tools or obs.task.available_tools))
    shown_tools = tools[: max(max_tools, 0)]
    if shown_tools:
        parts.append("Available tools: " + ", ".join(shown_tools))

    if obs.selected_channel:
        parts.append(f"Selected channel: {obs.selected_channel}")
    if obs.selected_beam_energy:
        parts.append(f"Beam energy: {obs.selected_beam_energy}")

    if obs.candidate_masses_gev:
        sigmas = obs.candidate_significances or []
        peaks: List[str] = []
        for idx, mass in enumerate(obs.candidate_masses_gev[:3]):
            # Keep every candidate mass even when its significance is missing;
            # a plain zip() would silently drop it.
            sigma_text = f"{sigmas[idx]:.1f}" if idx < len(sigmas) else "?"
            peaks.append(f"{mass:.1f}/{sigma_text}")
        parts.append("Candidate peaks (GeV / σ): " + ", ".join(peaks))

    return "\n".join(parts)
# Public API of this module, listed in definition order.
__all__ = [
    # Action vocabulary
    "ActionType",
    "DAQ_ACTIONS",
    "RECO_ACTIONS",
    "ANALYSIS_ACTIONS",
    "META_ACTIONS",
    # Detector channels & physics primitives
    "DetectorChannel",
    "TriggerType",
    "BeamEnergy",
    # Tool registry
    "ToolCategory",
    "ToolSpec",
    "TOOL_REGISTRY",
    # Action schema
    "ExperimentAction",
    # Outputs
    "OutputType",
    "IntermediateOutput",
    # Observable state
    "ResourceUsage",
    "PipelineStepRecord",
    "PaperReference",
    "ExpectedFinding",
    "TaskSpec",
    "DiscoveryClaim",
    "CollisionObservation",
    # Prompt helpers
    "AGENT_ACTION_GUIDANCE",
    "AGENT_ENVIRONMENT_RULES",
    "build_agent_system_prompt",
    "build_agent_observation_context",
]