"""LLM (Large Language Model) agent that picks the next CERNenv action.

The agent renders an observation as a short prompt, asks the LLM for a
JSON-formatted ``ExperimentAction``, validates the response, and falls back
to a safe default action if parsing fails. This is the unit shared by
evaluation and the GRPO (Group-Relative Policy Optimization) training loop.
"""
from __future__ import annotations

import json
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

from models import (
    ActionType,
    CollisionObservation,
    ExperimentAction,
    build_agent_observation_context,
    build_agent_system_prompt,
)
# Raw string values accepted for the "action_type" field of an LLM response;
# built once from the ActionType enum so validation stays in sync with it.
_VALID_ACTIONS = {a.value for a in ActionType}
@dataclass
class LLMAgentConfig:
    """Knobs for prompt formatting and decoding.

    ``dataclass`` was imported but never applied, so per-instance overrides
    (``LLMAgentConfig(temperature=0.1)``) raised TypeError; the decorator
    makes the class constructible field-by-field while keeping the defaults
    and plain attribute access callers already rely on.
    """

    # How many trailing pipeline records to show in the prompt.
    max_history_steps: int = 6
    # Sampling temperature for decoding.
    temperature: float = 0.7
    # Cap on generated tokens per action.
    max_new_tokens: int = 256
    # Nucleus-sampling cutoff.
    top_p: float = 0.95
def render_history(obs: CollisionObservation, max_steps: int) -> str:
    """Format the last ``max_steps`` pipeline records as one line each.

    Returns a placeholder hint when no steps have been executed yet.
    """
    records = obs.pipeline_history
    if not records:
        return " (none yet — pick a starting action)"

    def _fmt(rec) -> str:
        # Status tag mirrors the record's success flag; summary capped at 80 chars.
        status = "OK" if rec.success else "FAIL"
        return (
            f" step {rec.step_index:>2} "
            f"{rec.action_type.value:<24} "
            f"{status}: {rec.output_summary[:80]}"
        )

    return "\n".join(_fmt(rec) for rec in records[-max_steps:])
def render_resources(obs: CollisionObservation) -> str:
    """One-line summary of remaining/total budget, luminosity, and time."""
    usage = obs.resource_usage
    # Totals are reconstructed as remaining + used for each resource.
    budget_total = usage.budget_remaining_musd + usage.budget_used_musd
    lumi_total = usage.luminosity_remaining_fb + usage.luminosity_used_fb
    time_total = usage.time_remaining_days + usage.time_used_days
    segments = [
        f"budget {usage.budget_remaining_musd:.1f}/{budget_total:.1f} M$ left",
        f"luminosity {usage.luminosity_remaining_fb:.1f}/{lumi_total:.1f} fb^-1 left",
        f"time {usage.time_remaining_days:.0f}/{time_total:.0f} days left",
    ]
    return ", ".join(segments)
def render_user_prompt(
    obs: CollisionObservation,
    config: Optional[LLMAgentConfig] = None,
) -> str:
    """Render the user-turn prompt for one decision step.

    Sections: task statement, public state, resources, recent pipeline
    history, any rule violations from the last step, and the closing
    instruction to respond with a single JSON action.

    Args:
        obs: Current observation of the experiment pipeline.
        config: Prompt-formatting knobs. ``None`` (the default) creates a
            fresh ``LLMAgentConfig`` per call — the old
            ``config=LLMAgentConfig()`` default shared one mutable instance
            across every call (mutable-default-argument pitfall). Passing a
            config positionally still works as before.

    Returns:
        The complete prompt string.
    """
    if config is None:
        config = LLMAgentConfig()
    parts: List[str] = [
        "Task:",
        " " + obs.task.problem_statement.strip(),
        "",
        "Public state:",
        # Indent the multi-line context block by one space to match the layout.
        " " + build_agent_observation_context(obs).replace("\n", "\n "),
        "",
        "Resources:",
        " " + render_resources(obs),
        "",
        "Recent steps:",
        render_history(obs, max_steps=config.max_history_steps),
    ]
    if obs.rule_violations:
        parts.append("")
        parts.append("Last-step violations: " + ", ".join(obs.rule_violations))
    parts.append("")
    parts.append("Choose ONE next action and respond with a single JSON object.")
    return "\n".join(parts)
def build_chat(
    obs: CollisionObservation,
    config: Optional[LLMAgentConfig] = None,
) -> List[Dict[str, str]]:
    """Build the two-message chat (system + user) sent to the LLM.

    Args:
        obs: Current observation of the experiment pipeline.
        config: Prompt-formatting knobs. ``None`` (the default) creates a
            fresh ``LLMAgentConfig`` per call, replacing the old shared
            mutable default instance; explicit configs still pass through.

    Returns:
        A list of ``{"role": ..., "content": ...}`` message dicts.
    """
    # Resolve the default here so downstream callees always get a real config.
    if config is None:
        config = LLMAgentConfig()
    return [
        {"role": "system", "content": build_agent_system_prompt()},
        {"role": "user", "content": render_user_prompt(obs, config)},
    ]
# ── Robust JSON extraction ───────────────────────────────────────────────
# Greedy first-{ to last-} span; retained only as the relaxed fallback for
# almost-JSON that a strict parse rejects (e.g. trailing commas).
_JSON_RE = re.compile(r"\{[\s\S]*\}")
# Trailing commas before a closing brace/bracket are the most common LLM
# JSON mistake; this strips them for the relaxed second pass.
_TRAILING_COMMA_RE = re.compile(r",\s*([}\]])")


def extract_first_json(text: str) -> Optional[Dict[str, Any]]:
    """Return the first parseable JSON object found inside ``text``.

    Attempts a strict ``raw_decode`` starting at every ``{``, so a valid
    object is recovered even when the text contains several objects or stray
    braces — the old greedy regex spanned from the first ``{`` to the LAST
    ``}`` and returned ``None`` in those cases. Falls back to a relaxed pass
    that strips trailing commas from the greedy span. Returns ``None`` when
    no JSON object can be recovered.
    """
    if not text:
        return None
    decoder = json.JSONDecoder()
    # Strict pass: try to decode at each opening brace, left to right.
    start = text.find("{")
    while start != -1:
        try:
            obj, _end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            pass
        else:
            if isinstance(obj, dict):
                return obj
        start = text.find("{", start + 1)
    # Relaxed pass: clean trailing commas out of the greedy brace span.
    match = _JSON_RE.search(text)
    if not match:
        return None
    cleaned = _TRAILING_COMMA_RE.sub(r"\1", match.group(0))
    try:
        parsed = json.loads(cleaned)
    except json.JSONDecodeError:
        return None
    # The annotation promises a dict; reject any other top-level value.
    return parsed if isinstance(parsed, dict) else None
def parse_action(text: str) -> Optional[ExperimentAction]:
    """Parse raw LLM output into an ``ExperimentAction``.

    Any validation failure — no JSON found, unknown or non-string action
    type, bad field values — yields ``None`` so the caller can fall back to
    ``safe_default_action``.
    """
    payload = extract_first_json(text)
    if payload is None:
        return None
    action_type = payload.get("action_type")
    # Guard the membership test: an unhashable value (e.g. a list from a
    # malformed LLM response) would make the set lookup raise TypeError,
    # which the try below does not cover.
    if not isinstance(action_type, str) or action_type not in _VALID_ACTIONS:
        return None
    try:
        return ExperimentAction(
            action_type=ActionType(action_type),
            method=payload.get("method") or None,
            parameters=payload.get("parameters") or {},
            justification=payload.get("justification"),
            # Falsy confidence values (missing, null, 0) default to 0.5.
            confidence=float(payload.get("confidence", 0.5) or 0.5),
        )
    except Exception:
        # Model-side validation errors are treated as an unparseable action.
        return None
def safe_default_action(obs: CollisionObservation) -> ExperimentAction:
    """Picks the next sensible scripted step when the LLM output is invalid.

    Walks a fixed pipeline order and returns the first action type that has
    not yet succeeded; once every step has succeeded, submits a default
    discovery claim built from the observation.
    """
    # Mark which action types have succeeded at least once so far.
    done = {a.value: False for a in ActionType}
    for record in obs.pipeline_history:
        if record.success:
            done[record.action_type.value] = True

    fallback_channel = (
        obs.task.available_channels[0] if obs.task.available_channels else "diphoton"
    )
    # Scripted order: (action type, method or None, parameters or None).
    script = [
        (ActionType.CONFIGURE_BEAM, None, {"beam_energy": "13TeV"}),
        (ActionType.SELECT_CHANNEL, None, {"channel": fallback_channel}),
        (ActionType.SET_TRIGGER, None, {"trigger": "diphoton_hlt"}),
        (ActionType.ALLOCATE_LUMINOSITY, None, {"luminosity_fb": 50.0}),
        (ActionType.COLLECT_COLLISIONS, None, {"luminosity_fb": 50.0}),
        (ActionType.RECONSTRUCT_TRACKS, None, None),
        (
            ActionType.BUILD_INVARIANT_MASS,
            None,
            {"mass_window_gev": obs.task.mass_search_window_gev},
        ),
        (ActionType.FIT_RESONANCE, "ROOT_RooFit", None),
        (ActionType.ESTIMATE_SIGNIFICANCE, "Asimov_significance", None),
    ]
    for action_type, method, parameters in script:
        if done[action_type.value]:
            continue
        kwargs: Dict[str, Any] = {"justification": "default fallback"}
        if method is not None:
            kwargs["method"] = method
        if parameters is not None:
            kwargs["parameters"] = parameters
        return ExperimentAction(action_type=action_type, **kwargs)

    # Every scripted step succeeded: submit a default claim using the most
    # recent candidate mass (125.0 GeV if none was recorded).
    mass_estimate = obs.candidate_masses_gev[-1] if obs.candidate_masses_gev else 125.0
    return ExperimentAction(
        action_type=ActionType.SUBMIT_DISCOVERY_CLAIM,
        parameters={
            "claim": {
                "mass_estimate_gev": mass_estimate,
                "mass_uncertainty_gev": 1.0,
                "significance_sigma": obs.cumulative_significance,
                "decay_channel": obs.selected_channel or "diphoton",
                "spin_hypothesis": 0,
                "parity": "+",
                "confidence": 0.7,
            }
        },
        justification="default fallback claim",
    )
# Explicit public API of this module; everything else is an internal helper.
__all__ = [
    "LLMAgentConfig",
    "build_chat",
    "extract_first_json",
    "parse_action",
    "render_history",
    "render_resources",
    "render_user_prompt",
    "safe_default_action",
]