File size: 7,974 Bytes
5f78183
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
"""LLM (Large Language Model) agent that picks the next CERNenv action.

The agent renders an observation as a short prompt, asks the LLM for a
JSON-formatted ``ExperimentAction``, validates the response, and falls back
to a safe default action if parsing fails. This is the unit shared by
evaluation and the GRPO (Group-Relative Policy Optimization) training loop.
"""

from __future__ import annotations

import json
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

from models import (
    ActionType,
    CollisionObservation,
    ExperimentAction,
    build_agent_observation_context,
    build_agent_system_prompt,
)


# Set of the ``.value`` of every ActionType member; parse_action checks the
# payload's "action_type" against it before building an ExperimentAction.
_VALID_ACTIONS = {a.value for a in ActionType}


@dataclass
class LLMAgentConfig:
    """Knobs for prompt formatting and decoding.

    In the code visible here only ``max_history_steps`` is read (by
    ``render_user_prompt``). The remaining fields look like sampling settings
    for whatever code invokes the LLM -- NOTE(review): confirm against the
    caller, they are not consumed in this module.
    """

    # Maximum number of most recent pipeline steps rendered in the prompt.
    max_history_steps: int = 6
    # Sampling temperature (presumably passed to the LLM decoder).
    temperature: float = 0.7
    # Upper bound on generated tokens (presumably passed to the LLM decoder).
    max_new_tokens: int = 256
    # Nucleus-sampling cutoff (presumably passed to the LLM decoder).
    top_p: float = 0.95


def render_history(obs: CollisionObservation, max_steps: int) -> str:
    """Format the most recent pipeline steps, one indented line per step.

    Args:
        obs: Observation whose ``pipeline_history`` records are rendered. Each
            record is read for ``step_index``, ``action_type.value``,
            ``success`` and ``output_summary``.
        max_steps: Maximum number of trailing records to include. Zero or a
            negative value yields no step lines.

    Returns:
        A placeholder line when the history is empty; otherwise the
        newline-joined step lines (an empty string when ``max_steps <= 0``).
        Each summary is truncated to its first 80 characters.
    """
    if not obs.pipeline_history:
        return "  (none yet — pick a starting action)"

    # Guard the slice: ``seq[-0:]`` is the *whole* sequence and a negative
    # limit would turn into ``seq[k:]``, so a non-positive limit must be
    # handled explicitly to mean "show nothing".
    recent = obs.pipeline_history[-max_steps:] if max_steps > 0 else []

    lines: List[str] = []
    for rec in recent:
        status = "OK" if rec.success else "FAIL"
        lines.append(
            f"  step {rec.step_index:>2}  {rec.action_type.value:<24} {status}: {rec.output_summary[:80]}"
        )
    return "\n".join(lines)


def render_resources(obs: CollisionObservation) -> str:
    """Summarise remaining vs. total budget, luminosity and time on one line.

    Every total is derived as ``remaining + used`` from the fields of
    ``obs.resource_usage``; the three segments are joined with ", ".
    """
    usage = obs.resource_usage

    budget_total = usage.budget_remaining_musd + usage.budget_used_musd
    lumi_total = usage.luminosity_remaining_fb + usage.luminosity_used_fb
    days_total = usage.time_remaining_days + usage.time_used_days

    segments = (
        f"budget {usage.budget_remaining_musd:.1f}/{budget_total:.1f} M$ left",
        f"luminosity {usage.luminosity_remaining_fb:.1f}/{lumi_total:.1f} fb^-1 left",
        f"time {usage.time_remaining_days:.0f}/{days_total:.0f} days left",
    )
    return ", ".join(segments)


def render_user_prompt(
    obs: CollisionObservation,
    config: LLMAgentConfig = LLMAgentConfig(),
) -> str:
    """Build the user-turn prompt text for one observation.

    The prompt is a newline-joined sequence of sections: the task statement,
    the public state (from ``build_agent_observation_context``, with
    continuation lines indented), the resource summary, the recent step
    history (limited by ``config.max_history_steps``), any rule violations on
    the observation, and a closing instruction to answer with one JSON object.
    """
    sections: List[str] = [
        "Task:",
        "  " + obs.task.problem_statement.strip(),
        "",
        "Public state:",
        "  " + build_agent_observation_context(obs).replace("\n", "\n  "),
        "",
        "Resources:",
        "  " + render_resources(obs),
        "",
        "Recent steps:",
        render_history(obs, max_steps=config.max_history_steps),
    ]

    # The violations section is only emitted when there is something to report.
    if obs.rule_violations:
        sections += ["", "Last-step violations: " + ", ".join(obs.rule_violations)]

    sections += ["", "Choose ONE next action and respond with a single JSON object."]
    return "\n".join(sections)


def build_chat(
    obs: CollisionObservation,
    config: LLMAgentConfig = LLMAgentConfig(),
) -> List[Dict[str, str]]:
    """Return a two-message chat: the system prompt, then the user prompt.

    The system content comes from ``build_agent_system_prompt()`` and the
    user content from ``render_user_prompt(obs, config)``.
    """
    system_message = {"role": "system", "content": build_agent_system_prompt()}
    user_message = {"role": "user", "content": render_user_prompt(obs, config)}
    return [system_message, user_message]


# ── Robust JSON extraction ───────────────────────────────────────────────


# Greedy "first '{' through last '}'" match. Used only as a cheap pre-check
# and to bound the scan below; it is NOT a reliable delimiter for one object.
_JSON_RE = re.compile(r"\{[\s\S]*\}")

# A comma directly before a closing brace/bracket, i.e. a trailing comma.
_TRAILING_COMMA_RE = re.compile(r",\s*([}\]])")

_JSON_DECODER = json.JSONDecoder()


def extract_first_json(text: str) -> Optional[Dict[str, Any]]:
    """Return the first parseable JSON object found inside ``text``.

    Every ``{`` is tried in order as the start of an object. For each start a
    strict decode is attempted first, then a relaxed decode with trailing
    commas removed. Decoding stops at the end of the object, so surrounding
    prose, code fences, additional objects, or stray ``}`` characters after
    the object do not prevent extraction.

    Args:
        text: Raw model output; may be empty or ``None``.

    Returns:
        The decoded object as a dict, or ``None`` when ``text`` is empty or
        contains no decodable object.
    """
    if not text:
        return None
    span = _JSON_RE.search(text)
    if not span:
        # No '{' followed by a '}' anywhere: nothing could possibly decode.
        return None

    start, limit = span.start(), span.end()
    while start != -1:
        tail = text[start:]
        # Strict first; the relaxed variant strips trailing commas. Like any
        # regex-based cleanup it also touches such commas inside string
        # values, which is why it is only the fallback.
        for candidate in (tail, _TRAILING_COMMA_RE.sub(r"\1", tail)):
            try:
                # raw_decode parses one value and ignores whatever follows it.
                obj, _end = _JSON_DECODER.raw_decode(candidate)
            except json.JSONDecodeError:
                continue
            return obj
        # A '{' after the last '}' can never close, so stop the scan there.
        start = text.find("{", start + 1, limit)
    return None


def parse_action(text: str) -> Optional[ExperimentAction]:
    """Parse raw LLM output into an ``ExperimentAction``.

    Args:
        text: Raw model output expected to contain one JSON object.

    Returns:
        The constructed action, or ``None`` when no JSON object can be
        extracted, when ``"action_type"`` is missing or not one of the known
        action values, or when building the action raises (for example a
        non-numeric ``"confidence"``).
    """
    payload = extract_first_json(text)
    if not isinstance(payload, dict):
        return None

    action_type = payload.get("action_type")
    try:
        is_known = action_type in _VALID_ACTIONS
    except TypeError:
        # JSON lists/objects are unhashable, so set membership raises instead
        # of answering False; treat that as an unknown action, not a crash.
        is_known = False
    if not is_known:
        return None

    try:
        return ExperimentAction(
            action_type=ActionType(action_type),
            # Empty/falsy method and parameters are normalised to None / {}.
            method=payload.get("method") or None,
            parameters=payload.get("parameters") or {},
            justification=payload.get("justification"),
            # Missing, null or zero confidence all fall back to 0.5.
            confidence=float(payload.get("confidence", 0.5) or 0.5),
        )
    except Exception:
        # Deliberate best-effort boundary: any construction or validation
        # failure means "unparseable", and the caller falls back to a default.
        return None


def safe_default_action(obs: CollisionObservation) -> ExperimentAction:
    """Pick the next scripted step when the LLM output is invalid.

    Walks a fixed, ordered pipeline and returns the first step whose action
    type has no successful record in ``obs.pipeline_history``. Once every
    scripted step has succeeded, returns a discovery claim built from the
    observation (last candidate mass, or 125.0 when there is none).
    """
    completed = {
        rec.action_type.value for rec in obs.pipeline_history if rec.success
    }

    # Ordered fallback pipeline. The extra constructor kwargs are produced by
    # callables so values read from ``obs`` are only evaluated for the step
    # that is actually selected.
    scripted_steps = (
        (
            ActionType.CONFIGURE_BEAM,
            lambda: {"parameters": {"beam_energy": "13TeV"}},
        ),
        (
            ActionType.SELECT_CHANNEL,
            lambda: {
                "parameters": {
                    "channel": obs.task.available_channels[0]
                    if obs.task.available_channels
                    else "diphoton"
                }
            },
        ),
        (
            ActionType.SET_TRIGGER,
            lambda: {"parameters": {"trigger": "diphoton_hlt"}},
        ),
        (
            ActionType.ALLOCATE_LUMINOSITY,
            lambda: {"parameters": {"luminosity_fb": 50.0}},
        ),
        (
            ActionType.COLLECT_COLLISIONS,
            lambda: {"parameters": {"luminosity_fb": 50.0}},
        ),
        # No extra kwargs for track reconstruction.
        (ActionType.RECONSTRUCT_TRACKS, dict),
        (
            ActionType.BUILD_INVARIANT_MASS,
            lambda: {"parameters": {"mass_window_gev": obs.task.mass_search_window_gev}},
        ),
        (
            ActionType.FIT_RESONANCE,
            lambda: {"method": "ROOT_RooFit"},
        ),
        (
            ActionType.ESTIMATE_SIGNIFICANCE,
            lambda: {"method": "Asimov_significance"},
        ),
    )

    for step_type, extra_kwargs in scripted_steps:
        if step_type.value not in completed:
            return ExperimentAction(
                action_type=step_type,
                justification="default fallback",
                **extra_kwargs(),
            )

    # Every scripted step has succeeded: submit a claim.
    if obs.candidate_masses_gev:
        mass_estimate = obs.candidate_masses_gev[-1]
    else:
        mass_estimate = 125.0

    claim = {
        "mass_estimate_gev": mass_estimate,
        "mass_uncertainty_gev": 1.0,
        "significance_sigma": obs.cumulative_significance,
        "decay_channel": obs.selected_channel or "diphoton",
        "spin_hypothesis": 0,
        "parity": "+",
        "confidence": 0.7,
    }
    return ExperimentAction(
        action_type=ActionType.SUBMIT_DISCOVERY_CLAIM,
        parameters={"claim": claim},
        justification="default fallback claim",
    )


# Names exported by a star-import of this module. Underscore-prefixed
# helpers (_VALID_ACTIONS, _JSON_RE) are intentionally left out.
__all__ = [
    "LLMAgentConfig",
    "build_chat",
    "extract_first_json",
    "parse_action",
    "render_history",
    "render_resources",
    "render_user_prompt",
    "safe_default_action",
]