cernenv / training /llm_agent.py
anugrah55's picture
Update CERNenv Space
2b0bffa verified
"""LLM (Large Language Model) agent that picks the next CERNenv action.
The agent renders an observation as a short prompt, asks the LLM for a
JSON-formatted ``ExperimentAction``, validates the response, and falls back
to a safe default action if parsing fails. This is the unit shared by
evaluation and the GRPO (Group-Relative Policy Optimization) training loop.
"""
from __future__ import annotations
import json
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from models import (
ActionType,
CollisionObservation,
ExperimentAction,
build_agent_observation_context,
build_agent_system_prompt,
)
# Raw enum values accepted in the ``action_type`` field of an LLM response;
# built once at import time so membership checks stay cheap.
_VALID_ACTIONS = {member.value for member in ActionType}
@dataclass
class LLMAgentConfig:
    """Settings that control prompt rendering and LLM sampling.

    Only ``max_history_steps`` is consumed inside this module (it caps the
    number of past steps rendered into the user prompt). The remaining
    fields are decoding knobs that are not read here -- presumably they are
    forwarded to the generation call by the caller (confirm at call sites).
    """

    # How many of the most recent pipeline steps appear in the prompt.
    max_history_steps: int = 6
    # Decoding: sampling temperature.
    temperature: float = 0.7
    # Decoding: cap on the number of generated tokens.
    max_new_tokens: int = 256
    # Decoding: nucleus-sampling probability mass.
    top_p: float = 0.95
def render_history(obs: CollisionObservation, max_steps: int) -> str:
    """Render the most recent pipeline steps as one text line per step.

    Args:
        obs: Observation whose ``pipeline_history`` records are rendered.
            Each record must expose ``step_index``, ``action_type.value``,
            ``success`` and ``output_summary``.
        max_steps: Maximum number of trailing records to include. A value
            of zero or less renders no records at all.

    Returns:
        A placeholder line when there is no history yet, an empty string
        when ``max_steps`` is not positive, otherwise the newline-joined
        step lines (each summary truncated to 80 characters).
    """
    if not obs.pipeline_history:
        return " (none yet — pick a starting action)"
    # ``seq[-0:]`` is the *whole* sequence and a negative limit would slice
    # from the front, so a non-positive limit must be handled explicitly.
    if max_steps <= 0:
        return ""
    lines: List[str] = []
    for rec in obs.pipeline_history[-max_steps:]:
        status = "OK" if rec.success else "FAIL"
        lines.append(
            f" step {rec.step_index:>2} {rec.action_type.value:<24} {status}: {rec.output_summary[:80]}"
        )
    return "\n".join(lines)
def render_resources(obs: CollisionObservation) -> str:
    """Summarise remaining vs. total budget, luminosity and time on one line.

    Each total is reconstructed as ``remaining + used`` from the fields of
    ``obs.resource_usage``.
    """
    usage = obs.resource_usage
    budget_total = usage.budget_remaining_musd + usage.budget_used_musd
    lumi_total = usage.luminosity_remaining_fb + usage.luminosity_used_fb
    days_total = usage.time_remaining_days + usage.time_used_days

    budget_part = f"budget {usage.budget_remaining_musd:.1f}/{budget_total:.1f} M$ left, "
    lumi_part = f"luminosity {usage.luminosity_remaining_fb:.1f}/{lumi_total:.1f} fb^-1 left, "
    time_part = f"time {usage.time_remaining_days:.0f}/{days_total:.0f} days left"
    return budget_part + lumi_part + time_part
def render_user_prompt(
    obs: CollisionObservation,
    config: Optional[LLMAgentConfig] = None,
) -> str:
    """Build the user-turn prompt text for a single observation.

    The prompt contains, in order: the task statement, the public state from
    ``build_agent_observation_context``, the resource summary, the recent
    step history, any rule violations reported on the observation, and the
    closing instruction asking for one JSON action.

    Args:
        obs: Observation to render.
        config: Prompt settings; only ``max_history_steps`` is read here.
            ``None`` (the default) means a fresh ``LLMAgentConfig()``.

    Returns:
        The newline-joined prompt string.
    """
    # A fresh default per call: a default instance created in the signature
    # would be built once at definition time and shared by every caller.
    if config is None:
        config = LLMAgentConfig()

    parts: List[str] = [
        "Task:",
        " " + obs.task.problem_statement.strip(),
        "",
        "Public state:",
        # Indent continuation lines so the block lines up under its heading.
        " " + build_agent_observation_context(obs).replace("\n", "\n "),
        "",
        "Resources:",
        " " + render_resources(obs),
        "",
        "Recent steps:",
        render_history(obs, max_steps=config.max_history_steps),
    ]
    if obs.rule_violations:
        parts.append("")
        parts.append("Last-step violations: " + ", ".join(obs.rule_violations))
    parts.append("")
    parts.append("Choose ONE next action and respond with a single JSON object.")
    return "\n".join(parts)
def build_chat(
    obs: CollisionObservation,
    config: Optional[LLMAgentConfig] = None,
) -> List[Dict[str, str]]:
    """Return the two-message chat (system + user) for one observation.

    Args:
        obs: Observation rendered into the user message.
        config: Prompt settings passed to ``render_user_prompt``.
            ``None`` (the default) means a fresh ``LLMAgentConfig()``.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts: the system prompt
        from ``build_agent_system_prompt`` followed by the user prompt.
    """
    # Resolve the default here rather than in the signature so no single
    # config instance is shared across calls.
    if config is None:
        config = LLMAgentConfig()
    system_message = {"role": "system", "content": build_agent_system_prompt()}
    user_message = {"role": "user", "content": render_user_prompt(obs, config)}
    return [system_message, user_message]
# ── Robust JSON extraction ───────────────────────────────────────────────
# Widest span that could hold a JSON object: first "{" through last "}".
_JSON_RE = re.compile(r"\{[\s\S]*\}")
# A comma directly before a closing brace/bracket (tolerated LLM slip).
_TRAILING_COMMA_RE = re.compile(r",\s*([}\]])")


def _decode_object_at(text: str, start: int) -> Optional[Dict[str, Any]]:
    """Decode the JSON object beginning at ``text[start]``; None if it is not one."""
    decoder = json.JSONDecoder()
    try:
        # raw_decode stops at the end of the first complete value, so any
        # chatter or extra braces after the object are simply ignored.
        value, _ = decoder.raw_decode(text, start)
    except json.JSONDecodeError:
        # Relaxed pass: drop trailing commas and retry from the same brace.
        relaxed = _TRAILING_COMMA_RE.sub(r"\1", text[start:])
        try:
            value, _ = decoder.raw_decode(relaxed)
        except json.JSONDecodeError:
            return None
    return value if isinstance(value, dict) else None


def extract_first_json(text: str) -> Optional[Dict[str, Any]]:
    """Return the first parseable JSON object found inside ``text``.

    Every ``{`` is tried in order of appearance. At each one a strict parse
    is attempted first, then a relaxed parse that strips trailing commas.
    Surrounding prose, stray braces, or additional objects after the first
    parseable one do not prevent extraction.

    Args:
        text: Raw model output; may be empty or ``None``.

    Returns:
        The decoded object as a dict, or ``None`` when ``text`` is empty or
        contains no parseable JSON object. If the outermost object is
        malformed, a well-formed nested object may be returned instead.
    """
    if not text:
        return None
    m = _JSON_RE.search(text)
    if not m:
        return None
    # No object can extend past the last "}", so scanning is limited to it.
    candidate = m.group(0)
    start = candidate.find("{")
    while start != -1:
        parsed = _decode_object_at(candidate, start)
        if parsed is not None:
            return parsed
        start = candidate.find("{", start + 1)
    return None
def parse_action(text: str) -> Optional[ExperimentAction]:
    """Parse raw LLM output into an ``ExperimentAction``.

    Args:
        text: Raw model output expected to contain a JSON object.

    Returns:
        The constructed action, or ``None`` when no JSON object can be
        extracted, ``action_type`` is not one of the known enum values, or
        building the action raises.
    """
    payload = extract_first_json(text)
    if payload is None:
        return None
    action_type = payload.get("action_type")
    try:
        known = action_type in _VALID_ACTIONS
    except TypeError:
        # The model emitted an unhashable value (list/dict) for action_type;
        # set membership raises on those, so treat it as invalid output.
        return None
    if not known:
        return None
    try:
        return ExperimentAction(
            action_type=ActionType(action_type),
            method=payload.get("method") or None,
            parameters=payload.get("parameters") or {},
            justification=payload.get("justification"),
            # NOTE: any falsy confidence (missing, null, "", and also an
            # explicit 0) falls back to 0.5.
            confidence=float(payload.get("confidence", 0.5) or 0.5),
        )
    except Exception:
        # Deliberate best-effort: any construction/validation failure means
        # "unusable output" so the caller can fall back to a default action.
        return None
def safe_default_action(obs: CollisionObservation) -> ExperimentAction:
    """Return the earliest scripted pipeline step that has not yet succeeded.

    The fixed script is: configure beam, select channel, set trigger,
    allocate luminosity, collect collisions, reconstruct tracks, build
    invariant mass, fit resonance, estimate significance. Once every step
    has a successful record in ``obs.pipeline_history``, a discovery claim
    built from the observation is returned instead. Used when the LLM
    output cannot be turned into a valid action.
    """
    # Action values that have at least one successful record so far.
    completed = {rec.action_type.value for rec in obs.pipeline_history if rec.success}

    def pending(step: ActionType) -> bool:
        return step.value not in completed

    note = "default fallback"

    if pending(ActionType.CONFIGURE_BEAM):
        return ExperimentAction(
            action_type=ActionType.CONFIGURE_BEAM,
            parameters={"beam_energy": "13TeV"},
            justification=note,
        )

    if pending(ActionType.SELECT_CHANNEL):
        channels = obs.task.available_channels
        # Prefer the task's first advertised channel when there is one.
        if channels:
            channel = channels[0]
        else:
            channel = "diphoton"
        return ExperimentAction(
            action_type=ActionType.SELECT_CHANNEL,
            parameters={"channel": channel},
            justification=note,
        )

    if pending(ActionType.SET_TRIGGER):
        return ExperimentAction(
            action_type=ActionType.SET_TRIGGER,
            parameters={"trigger": "diphoton_hlt"},
            justification=note,
        )

    if pending(ActionType.ALLOCATE_LUMINOSITY):
        return ExperimentAction(
            action_type=ActionType.ALLOCATE_LUMINOSITY,
            parameters={"luminosity_fb": 50.0},
            justification=note,
        )

    if pending(ActionType.COLLECT_COLLISIONS):
        return ExperimentAction(
            action_type=ActionType.COLLECT_COLLISIONS,
            parameters={"luminosity_fb": 50.0},
            justification=note,
        )

    if pending(ActionType.RECONSTRUCT_TRACKS):
        return ExperimentAction(
            action_type=ActionType.RECONSTRUCT_TRACKS,
            justification=note,
        )

    if pending(ActionType.BUILD_INVARIANT_MASS):
        return ExperimentAction(
            action_type=ActionType.BUILD_INVARIANT_MASS,
            parameters={"mass_window_gev": obs.task.mass_search_window_gev},
            justification=note,
        )

    if pending(ActionType.FIT_RESONANCE):
        return ExperimentAction(
            action_type=ActionType.FIT_RESONANCE,
            method="ROOT_RooFit",
            justification=note,
        )

    if pending(ActionType.ESTIMATE_SIGNIFICANCE):
        return ExperimentAction(
            action_type=ActionType.ESTIMATE_SIGNIFICANCE,
            method="Asimov_significance",
            justification=note,
        )

    # Every scripted step has succeeded: submit a claim using the latest
    # candidate mass, or 125.0 when no candidate has been recorded.
    if obs.candidate_masses_gev:
        mass = obs.candidate_masses_gev[-1]
    else:
        mass = 125.0
    claim = {
        "mass_estimate_gev": mass,
        "mass_uncertainty_gev": 1.0,
        "significance_sigma": obs.cumulative_significance,
        "decay_channel": obs.selected_channel or "diphoton",
        "spin_hypothesis": 0,
        "parity": "+",
        "confidence": 0.7,
    }
    return ExperimentAction(
        action_type=ActionType.SUBMIT_DISCOVERY_CLAIM,
        parameters={"claim": claim},
        justification="default fallback claim",
    )
# Public API of this module: the config dataclass, the prompt/chat builders
# and renderers, the response parsers, and the scripted fallback policy.
__all__ = [
    "LLMAgentConfig",
    "build_chat",
    "extract_first_json",
    "parse_action",
    "render_history",
    "render_resources",
    "render_user_prompt",
    "safe_default_action",
]