| """LLM (Large Language Model) agent that picks the next CERNenv action.
|
|
|
| The agent renders an observation as a short prompt, asks the LLM for a
|
| JSON-formatted ``ExperimentAction``, validates the response, and falls back
|
| to a safe default action if parsing fails. This is the unit shared by
|
| evaluation and the GRPO (Group-Relative Policy Optimization) training loop.
|
| """
|
|
|
| from __future__ import annotations
|
|
|
| import json
|
| import re
|
| from dataclasses import dataclass
|
| from typing import Any, Dict, List, Optional
|
|
|
| from models import (
|
| ActionType,
|
| CollisionObservation,
|
| ExperimentAction,
|
| build_agent_observation_context,
|
| build_agent_system_prompt,
|
| )
|
|
|
|
|
# Raw values of every ``ActionType`` member; ``parse_action`` accepts an
# LLM-proposed ``action_type`` only when it appears in this set.
_VALID_ACTIONS = {member.value for member in ActionType}
|
|
|
|
|
@dataclass
class LLMAgentConfig:
    """Settings that control prompt construction and LLM decoding.

    Within this module only ``max_history_steps`` is read (by
    ``render_user_prompt``). The remaining fields are decoding knobs that
    are presumably forwarded to the generation call by the caller.
    """

    # How many of the most recent pipeline steps are shown in the prompt.
    max_history_steps: int = 6

    # Sampling temperature used when decoding.
    temperature: float = 0.7

    # Upper bound on the number of generated tokens.
    max_new_tokens: int = 256

    # Nucleus-sampling (top-p) threshold used when decoding.
    top_p: float = 0.95
|
|
|
|
|
def render_history(obs: CollisionObservation, max_steps: int) -> str:
    """Render the most recent pipeline steps, one text line per step.

    Args:
        obs: Observation whose ``pipeline_history`` records expose
            ``step_index``, ``action_type``, ``success`` and
            ``output_summary``.
        max_steps: Maximum number of trailing records to include. Zero or a
            negative value renders no records at all.

    Returns:
        A placeholder line when the history is empty; otherwise the selected
        records joined by newlines (an empty string when ``max_steps`` is
        not positive). Each summary is truncated to 80 characters.
    """
    if not obs.pipeline_history:
        return " (none yet — pick a starting action)"

    # ``seq[-0:]`` is the *whole* sequence and a negative limit would slice
    # from the front, so a non-positive limit has to be handled explicitly.
    recent = obs.pipeline_history[-max_steps:] if max_steps > 0 else []

    return "\n".join(
        f" step {rec.step_index:>2} {rec.action_type.value:<24} "
        f"{'OK' if rec.success else 'FAIL'}: {rec.output_summary[:80]}"
        for rec in recent
    )
|
|
|
|
|
def render_resources(obs: CollisionObservation) -> str:
    """Summarise remaining budget, luminosity and time on a single line.

    Each quantity is shown as ``remaining/total`` where the total is the sum
    of the remaining and already-used amounts read from
    ``obs.resource_usage``.
    """
    usage = obs.resource_usage

    budget_total = usage.budget_remaining_musd + usage.budget_used_musd
    luminosity_total = usage.luminosity_remaining_fb + usage.luminosity_used_fb
    days_total = usage.time_remaining_days + usage.time_used_days

    segments = (
        f"budget {usage.budget_remaining_musd:.1f}/{budget_total:.1f} M$ left",
        f"luminosity {usage.luminosity_remaining_fb:.1f}/{luminosity_total:.1f} fb^-1 left",
        f"time {usage.time_remaining_days:.0f}/{days_total:.0f} days left",
    )
    return ", ".join(segments)
|
|
|
|
|
def render_user_prompt(
    obs: CollisionObservation,
    config: LLMAgentConfig = LLMAgentConfig(),
) -> str:
    """Build the user-turn prompt text for one observation.

    The prompt lists, in order: the task statement, the public state from
    ``build_agent_observation_context``, the resource summary, the recent
    step history (limited by ``config.max_history_steps``), any rule
    violations reported on ``obs``, and the closing instruction to answer
    with a single JSON object.
    """
    sections: List[str] = [
        "Task:",
        " " + obs.task.problem_statement.strip(),
        "",
        "Public state:",
        # Re-indent continuation lines so the whole context block lines up.
        " " + build_agent_observation_context(obs).replace("\n", "\n "),
        "",
        "Resources:",
        " " + render_resources(obs),
        "",
        "Recent steps:",
        render_history(obs, max_steps=config.max_history_steps),
    ]

    if obs.rule_violations:
        sections += ["", "Last-step violations: " + ", ".join(obs.rule_violations)]

    sections += ["", "Choose ONE next action and respond with a single JSON object."]
    return "\n".join(sections)
|
|
|
|
|
def build_chat(
    obs: CollisionObservation,
    config: LLMAgentConfig = LLMAgentConfig(),
) -> List[Dict[str, str]]:
    """Return the two-message chat (system, then user) for one observation.

    The system message carries ``build_agent_system_prompt()`` and the user
    message carries ``render_user_prompt(obs, config)``.
    """
    system_message = {"role": "system", "content": build_agent_system_prompt()}
    user_message = {"role": "user", "content": render_user_prompt(obs, config)}
    return [system_message, user_message]
|
|
|
|
|
|
|
|
|
|
|
# Legacy greedy pattern (first "{" through last "}"). No longer used by
# ``extract_first_json``; kept so any external importer keeps working.
_JSON_RE = re.compile(r"\{[\s\S]*\}")

# A comma directly before a closing brace/bracket, which strict JSON rejects.
_TRAILING_COMMA_RE = re.compile(r",\s*([}\]])")

_JSON_DECODER = json.JSONDecoder()


def _decode_leading_object(fragment: str) -> Optional[Dict[str, Any]]:
    """Decode the JSON object at the start of ``fragment``, ignoring any tail."""
    try:
        value, _end = _JSON_DECODER.raw_decode(fragment)
    except json.JSONDecodeError:
        return None
    return value if isinstance(value, dict) else None


def extract_first_json(text: str) -> Optional[Dict[str, Any]]:
    """Return the first parseable JSON object found inside ``text``.

    Every ``{`` is tried in order as the start of an object. Decoding stops
    at the end of that object, so prose, code fences, stray braces or further
    JSON objects after it do not break parsing. If strict decoding fails at a
    position, it is retried with trailing commas (``,}`` / ``,]``) removed.

    Args:
        text: Raw model output; may be empty or ``None``.

    Returns:
        The decoded object as a dict, or ``None`` when no position yields a
        parseable object. If an outer object is irreparably malformed, a
        well-formed object nested inside it may be returned instead.
    """
    if not text:
        return None

    start = text.find("{")
    while start != -1:
        tail = text[start:]
        # Strict parse first so commas inside string values are never touched
        # unless the repair is actually needed.
        found = _decode_leading_object(tail)
        if found is None:
            found = _decode_leading_object(_TRAILING_COMMA_RE.sub(r"\1", tail))
        if found is not None:
            return found
        start = text.find("{", start + 1)
    return None
|
|
|
|
|
def parse_action(text: str) -> Optional[ExperimentAction]:
    """Parse raw LLM output into an ``ExperimentAction``.

    Args:
        text: Raw model output expected to contain a JSON object with an
            ``action_type`` key and optional ``method``, ``parameters``,
            ``justification`` and ``confidence`` keys.

    Returns:
        The constructed action, or ``None`` when no JSON object is found,
        ``action_type`` is not a known ``ActionType`` value, or building the
        action fails for any reason. This function never raises on bad
        model output.
    """
    payload = extract_first_json(text)
    if payload is None:
        return None

    action_type = payload.get("action_type")
    try:
        is_known = action_type in _VALID_ACTIONS
    except TypeError:
        # The model emitted an unhashable value (list/dict) as action_type.
        return None
    if not is_known:
        return None

    try:
        raw_confidence = payload.get("confidence", 0.5)
        is_number = isinstance(raw_confidence, (int, float)) and not isinstance(
            raw_confidence, bool
        )
        # An explicit numeric 0 is a legitimate confidence and must be kept;
        # other falsy values (None, "", false, ...) fall back to 0.5.
        confidence = float(raw_confidence) if is_number else float(raw_confidence or 0.5)

        return ExperimentAction(
            action_type=ActionType(action_type),
            method=payload.get("method") or None,
            parameters=payload.get("parameters") or {},
            justification=payload.get("justification"),
            confidence=confidence,
        )
    except Exception:
        # Deliberate best-effort boundary: any conversion or validation error
        # means "unparseable", and the caller falls back to a default action.
        return None
|
|
|
|
|
def safe_default_action(obs: CollisionObservation) -> ExperimentAction:
    """Return the next scripted pipeline step when the LLM output is unusable.

    The fixed ladder below is walked in order and the first action type that
    has no successful record in ``obs.pipeline_history`` is returned. Once
    every step has succeeded, a discovery claim is submitted using the latest
    candidate mass (125.0 when none exist) and the observation's cumulative
    significance.
    """
    succeeded = {rec.action_type.value for rec in obs.pipeline_history if rec.success}

    # (action type, factory for the extra keyword arguments). Factories keep
    # observation lookups lazy: they run only for the step actually chosen.
    ladder = (
        (
            ActionType.CONFIGURE_BEAM,
            lambda: {"parameters": {"beam_energy": "13TeV"}},
        ),
        (
            ActionType.SELECT_CHANNEL,
            lambda: {
                "parameters": {
                    "channel": obs.task.available_channels[0]
                    if obs.task.available_channels
                    else "diphoton"
                }
            },
        ),
        (
            ActionType.SET_TRIGGER,
            lambda: {"parameters": {"trigger": "diphoton_hlt"}},
        ),
        (
            ActionType.ALLOCATE_LUMINOSITY,
            lambda: {"parameters": {"luminosity_fb": 50.0}},
        ),
        (
            ActionType.COLLECT_COLLISIONS,
            lambda: {"parameters": {"luminosity_fb": 50.0}},
        ),
        (
            ActionType.RECONSTRUCT_TRACKS,
            lambda: {},
        ),
        (
            ActionType.BUILD_INVARIANT_MASS,
            lambda: {"parameters": {"mass_window_gev": obs.task.mass_search_window_gev}},
        ),
        (
            ActionType.FIT_RESONANCE,
            lambda: {"method": "ROOT_RooFit"},
        ),
        (
            ActionType.ESTIMATE_SIGNIFICANCE,
            lambda: {"method": "Asimov_significance"},
        ),
    )

    for step, extra_kwargs in ladder:
        if step.value not in succeeded:
            return ExperimentAction(
                action_type=step,
                justification="default fallback",
                **extra_kwargs(),
            )

    # Whole pipeline has run: submit a claim built from the observation.
    if obs.candidate_masses_gev:
        mass = obs.candidate_masses_gev[-1]
    else:
        mass = 125.0

    claim = {
        "mass_estimate_gev": mass,
        "mass_uncertainty_gev": 1.0,
        "significance_sigma": obs.cumulative_significance,
        "decay_channel": obs.selected_channel or "diphoton",
        "spin_hypothesis": 0,
        "parity": "+",
        "confidence": 0.7,
    }
    return ExperimentAction(
        action_type=ActionType.SUBMIT_DISCOVERY_CLAIM,
        parameters={"claim": claim},
        justification="default fallback claim",
    )
|
|
|
|
|
# Public API of this module: the names exported by a star-import.
__all__ = [
    "LLMAgentConfig",
    "build_chat",
    "extract_first_json",
    "parse_action",
    "render_history",
    "render_resources",
    "render_user_prompt",
    "safe_default_action",
]
|
|
|