"""LLM (Large Language Model) agent that picks the next CERNenv action.

The agent renders an observation as a short prompt, asks the LLM for a
JSON-formatted ``ExperimentAction``, validates the response, and falls back
to a safe default action if parsing fails. This is the unit shared by
evaluation and the GRPO (Group-Relative Policy Optimization) training loop.
"""

from __future__ import annotations

import json
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

from models import (
    ActionType,
    CollisionObservation,
    ExperimentAction,
    build_agent_observation_context,
    build_agent_system_prompt,
)

# Every value an ``action_type`` field may legally take.
_VALID_ACTIONS = {a.value for a in ActionType}


@dataclass
class LLMAgentConfig:
    """Knobs for prompt formatting and decoding."""

    # Number of most recent pipeline steps shown in the prompt.
    max_history_steps: int = 6
    # The three fields below are not read in this module; presumably they are
    # consumed by whatever code runs the generation call.
    temperature: float = 0.7
    max_new_tokens: int = 256
    top_p: float = 0.95


def render_history(obs: CollisionObservation, max_steps: int) -> str:
    """Render the last ``max_steps`` pipeline records, one line per step.

    Args:
        obs: Observation whose ``pipeline_history`` is rendered.
        max_steps: Maximum number of trailing records to show. A value of
            zero or less renders no records.

    Returns:
        A placeholder line when the history is empty, otherwise the rendered
        records joined by newlines (an empty string if ``max_steps <= 0``).
    """
    if not obs.pipeline_history:
        return " (none yet — pick a starting action)"
    # ``seq[-0:]`` is the whole sequence and a negative count would drop
    # leading records instead, so non-positive limits are handled explicitly.
    recent = obs.pipeline_history[-max_steps:] if max_steps > 0 else []
    lines: List[str] = []
    for rec in recent:
        success = "OK" if rec.success else "FAIL"
        lines.append(
            f" step {rec.step_index:>2} {rec.action_type.value:<24} {success}: {rec.output_summary[:80]}"
        )
    return "\n".join(lines)


def render_resources(obs: CollisionObservation) -> str:
    """Summarise remaining budget, luminosity and time as ``left/total``.

    Each total is computed as ``remaining + used`` from ``obs.resource_usage``.
    """
    r = obs.resource_usage
    return (
        f"budget {r.budget_remaining_musd:.1f}/{r.budget_remaining_musd + r.budget_used_musd:.1f} M$ left, "
        f"luminosity {r.luminosity_remaining_fb:.1f}/{r.luminosity_remaining_fb + r.luminosity_used_fb:.1f} fb^-1 left, "
        f"time {r.time_remaining_days:.0f}/{r.time_remaining_days + r.time_used_days:.0f} days left"
    )


def render_user_prompt(
    obs: CollisionObservation,
    config: Optional[LLMAgentConfig] = None,
) -> str:
    """Build the user-turn prompt text for one observation.

    Args:
        obs: Current observation.
        config: Prompt settings; a fresh ``LLMAgentConfig()`` is used when
            omitted, so no config instance is shared between calls.

    Returns:
        The prompt: task statement, public state, resources, recent steps,
        any rule violations, and the closing instruction.
    """
    if config is None:
        config = LLMAgentConfig()
    parts: List[str] = []
    parts.append("Task:")
    parts.append(" " + obs.task.problem_statement.strip())
    parts.append("")
    parts.append("Public state:")
    # Indent continuation lines of the context block by one space.
    parts.append(" " + build_agent_observation_context(obs).replace("\n", "\n "))
    parts.append("")
    parts.append("Resources:")
    parts.append(" " + render_resources(obs))
    parts.append("")
    parts.append("Recent steps:")
    parts.append(render_history(obs, max_steps=config.max_history_steps))
    if obs.rule_violations:
        parts.append("")
        parts.append("Last-step violations: " + ", ".join(obs.rule_violations))
    parts.append("")
    parts.append("Choose ONE next action and respond with a single JSON object.")
    return "\n".join(parts)


def build_chat(
    obs: CollisionObservation,
    config: Optional[LLMAgentConfig] = None,
) -> List[Dict[str, str]]:
    """Return the two-message chat (system + user) for one observation.

    Args:
        obs: Current observation.
        config: Prompt settings; defaults to a fresh ``LLMAgentConfig()``.
    """
    if config is None:
        config = LLMAgentConfig()
    return [
        {"role": "system", "content": build_agent_system_prompt()},
        {"role": "user", "content": render_user_prompt(obs, config)},
    ]


# ── Robust JSON extraction ───────────────────────────────────────────────

# Greedy first-brace-to-last-brace matcher. No longer used by
# ``extract_first_json`` (it over-captures when the text holds more than one
# object); retained so existing references to the name keep working.
_JSON_RE = re.compile(r"\{[\s\S]*\}")

# A comma followed only by whitespace before a closing brace/bracket.
_TRAILING_COMMA_RE = re.compile(r",\s*([}\]])")

_JSON_DECODER = json.JSONDecoder()


def _decode_object_at(text: str, start: int) -> Optional[Dict[str, Any]]:
    """Decode one JSON object beginning at ``text[start]``, or return None."""
    try:
        value, _ = _JSON_DECODER.raw_decode(text, start)
    except json.JSONDecodeError:
        return None
    return value if isinstance(value, dict) else None


def extract_first_json(text: str) -> Optional[Dict[str, Any]]:
    """Return the first parseable JSON object found inside ``text``.

    Every ``{`` is tried in order. At each position a strict decode is
    attempted first, then a relaxed decode with trailing commas removed.
    Text after the object (more prose, further objects) is ignored.

    Args:
        text: Raw model output; may be empty or ``None``.

    Returns:
        The decoded object as a dict, or ``None`` if nothing parses.
    """
    if not text:
        return None
    start = text.find("{")
    while start != -1:
        payload = _decode_object_at(text, start)
        if payload is None:
            # Relaxed pass: trim trailing commas. Done per position so an
            # outer object with a trailing comma wins over a strictly valid
            # object nested inside it. Re-scanning the tail is quadratic in
            # the worst case, which is acceptable for short model replies.
            relaxed = _TRAILING_COMMA_RE.sub(r"\1", text[start:])
            payload = _decode_object_at(relaxed, 0)
        if payload is not None:
            return payload
        start = text.find("{", start + 1)
    return None


def parse_action(text: str) -> Optional[ExperimentAction]:
    """Parse model output into an ``ExperimentAction``.

    Args:
        text: Raw model output expected to contain a JSON object.

    Returns:
        The action, or ``None`` when no JSON object is found, when
        ``action_type`` is not a valid ``ActionType`` value, or when building
        the action raises.
    """
    payload = extract_first_json(text)
    if payload is None:
        return None
    action_type = payload.get("action_type")
    try:
        known = action_type in _VALID_ACTIONS
    except TypeError:
        # Unhashable value (JSON list/object) cannot be a valid action name.
        return None
    if not known:
        return None
    try:
        return ExperimentAction(
            action_type=ActionType(action_type),
            method=payload.get("method") or None,
            parameters=payload.get("parameters") or {},
            justification=payload.get("justification"),
            # Missing, null and other falsy values (including 0) become 0.5.
            confidence=float(payload.get("confidence", 0.5) or 0.5),
        )
    except Exception:
        # Deliberately broad: any malformed field means "unparseable", and the
        # caller is expected to fall back to a default action.
        return None


def safe_default_action(obs: CollisionObservation) -> ExperimentAction:
    """Picks the next sensible scripted step when the LLM output is invalid.

    Walks a fixed sequence of action types and returns the first one that has
    no successful record in ``obs.pipeline_history``. Once every step in the
    sequence has succeeded, returns a discovery claim built from the
    observation.
    """
    # Values of all action types that have succeeded at least once.
    done = {rec.action_type.value for rec in obs.pipeline_history if rec.success}

    if ActionType.CONFIGURE_BEAM.value not in done:
        return ExperimentAction(
            action_type=ActionType.CONFIGURE_BEAM,
            parameters={"beam_energy": "13TeV"},
            justification="default fallback",
        )
    if ActionType.SELECT_CHANNEL.value not in done:
        channels = obs.task.available_channels
        return ExperimentAction(
            action_type=ActionType.SELECT_CHANNEL,
            parameters={"channel": channels[0] if channels else "diphoton"},
            justification="default fallback",
        )
    if ActionType.SET_TRIGGER.value not in done:
        return ExperimentAction(
            action_type=ActionType.SET_TRIGGER,
            parameters={"trigger": "diphoton_hlt"},
            justification="default fallback",
        )
    if ActionType.ALLOCATE_LUMINOSITY.value not in done:
        return ExperimentAction(
            action_type=ActionType.ALLOCATE_LUMINOSITY,
            parameters={"luminosity_fb": 50.0},
            justification="default fallback",
        )
    if ActionType.COLLECT_COLLISIONS.value not in done:
        return ExperimentAction(
            action_type=ActionType.COLLECT_COLLISIONS,
            parameters={"luminosity_fb": 50.0},
            justification="default fallback",
        )
    if ActionType.RECONSTRUCT_TRACKS.value not in done:
        return ExperimentAction(
            action_type=ActionType.RECONSTRUCT_TRACKS,
            justification="default fallback",
        )
    if ActionType.BUILD_INVARIANT_MASS.value not in done:
        return ExperimentAction(
            action_type=ActionType.BUILD_INVARIANT_MASS,
            parameters={"mass_window_gev": obs.task.mass_search_window_gev},
            justification="default fallback",
        )
    if ActionType.FIT_RESONANCE.value not in done:
        return ExperimentAction(
            action_type=ActionType.FIT_RESONANCE,
            method="ROOT_RooFit",
            justification="default fallback",
        )
    if ActionType.ESTIMATE_SIGNIFICANCE.value not in done:
        return ExperimentAction(
            action_type=ActionType.ESTIMATE_SIGNIFICANCE,
            method="Asimov_significance",
            justification="default fallback",
        )

    # Every scripted step has succeeded: submit a claim using the most recent
    # candidate mass, or 125.0 when none was recorded.
    mass = obs.candidate_masses_gev[-1] if obs.candidate_masses_gev else 125.0
    return ExperimentAction(
        action_type=ActionType.SUBMIT_DISCOVERY_CLAIM,
        parameters={
            "claim": {
                "mass_estimate_gev": mass,
                "mass_uncertainty_gev": 1.0,
                "significance_sigma": obs.cumulative_significance,
                "decay_channel": obs.selected_channel or "diphoton",
                "spin_hypothesis": 0,
                "parity": "+",
                "confidence": 0.7,
            }
        },
        justification="default fallback claim",
    )


__all__ = [
    "LLMAgentConfig",
    "build_chat",
    "extract_first_json",
    "parse_action",
    "render_history",
    "render_resources",
    "render_user_prompt",
    "safe_default_action",
]