| """Rollout collector for LLM-driven CERNenv episodes.
|
|
|
| Runs an LLM agent in-process against ``CERNCollisionEnvironment`` and
|
| records full per-step trajectories: prompt, completion, parsed action,
|
| reward, observation snapshot, and final episode summary.
|
| """
|
|
|
| from __future__ import annotations
|
|
|
| import json
|
| import logging
|
| from dataclasses import asdict, dataclass, field
|
| from typing import Any, Callable, Dict, List, Optional
|
|
|
| from models import ActionType, CollisionObservation, ExperimentAction
|
| from server.environment import CERNCollisionEnvironment
|
|
|
| from .llm_agent import (
|
| LLMAgentConfig,
|
| build_chat,
|
| parse_action,
|
| safe_default_action,
|
| )
|
|
|
|
|
logger = logging.getLogger(__name__)


#: Tokenizer-aware prompt formatter (e.g. ``apply_chat_template``): maps a
#: list of chat-message dicts (str -> str) to a single prompt string.
PromptFn = Callable[[List[Dict[str, str]]], str]

#: Runs the LLM on a formatted prompt, using the given agent config, and
#: returns the raw completion string.
GenerateFn = Callable[[str, LLMAgentConfig], str]
|
|
|
|
|
@dataclass
class StepRecord:
    """One agent/environment interaction recorded during an episode.

    Fields are accepted positionally in the order declared below (standard
    dataclass behaviour), so their order is part of the public interface.
    """

    # Step index of the observation the agent acted on.
    step: int
    # Formatted prompt handed to the model, and the model's raw output.
    prompt: str
    completion: str
    # The action actually sent to the environment, dumped to a plain dict.
    action: Dict[str, Any]
    # False when the completion could not be parsed and a fallback action
    # was executed instead.
    parsed_ok: bool
    # Reward and done flag taken from the observation returned by the step.
    reward: float
    done: bool
    # Rule violations reported by the environment after the step.
    rule_violations: List[str]
    # Compact snapshot of the pre-step observation; a fresh empty dict per
    # instance when not supplied.
    observation_summary: Dict[str, Any] = field(default_factory=dict)
|
|
|
|
|
@dataclass
class EpisodeRecord:
    """Full record of one collected episode: setup, outcome and trajectory.

    Fields are accepted positionally in the order declared below (standard
    dataclass behaviour), so their order is part of the public interface.
    """

    # Seed passed to ``env.reset``.
    seed: int
    # Scenario / difficulty as reported by the environment state after the
    # episode, not necessarily the values passed to reset.
    scenario: Optional[str]
    difficulty: Optional[str]
    # Ground truth exposed by the environment for evaluation, if any.
    truth: Optional[Dict[str, Any]]
    # Sum of per-step rewards as seen by the collector ...
    total_reward: float
    # ... and the environment's own running total, kept for cross-checking.
    cumulative_reward: float
    # Terminal reward from the environment state; None presumably means the
    # episode did not produce one — confirm against the environment.
    terminal_reward: Optional[float]
    # Outcome flags copied from the environment state (None when unset).
    discovered: Optional[bool]
    correct_mass: Optional[bool]
    correct_channel: Optional[bool]
    correct_spin: Optional[bool]
    # Per-step trajectory in execution order.
    steps: List[StepRecord]
|
|
|
|
|
| def _summarise_obs(obs: CollisionObservation) -> Dict[str, Any]:
|
| return {
|
| "step_index": obs.step_index,
|
| "selected_channel": obs.selected_channel,
|
| "selected_beam_energy": obs.selected_beam_energy,
|
| "n_candidates": len(obs.candidate_masses_gev),
|
| "best_significance": obs.cumulative_significance,
|
| "budget_remaining_musd": obs.resource_usage.budget_remaining_musd,
|
| "luminosity_remaining_fb": obs.resource_usage.luminosity_remaining_fb,
|
| }
|
|
|
|
|
def collect_episode(
    *,
    env: CERNCollisionEnvironment,
    seed: int,
    scenario: Optional[str],
    difficulty: Optional[str],
    prompt_fn: PromptFn,
    generate_fn: GenerateFn,
    config: Optional[LLMAgentConfig] = None,
    max_steps: Optional[int] = None,
) -> EpisodeRecord:
    """Run one LLM-driven episode in-process and record every step.

    The environment is reset with ``seed``/``scenario``/``difficulty`` and
    stepped until the observation reports ``done`` or the step cap is hit.
    At each step the current observation is turned into chat messages with
    ``build_chat``, formatted by ``prompt_fn``, completed by ``generate_fn``
    and parsed with ``parse_action``. If parsing fails, the action from
    ``safe_default_action`` is executed instead and the step is recorded
    with ``parsed_ok=False``.

    Args:
        env: Environment to drive; it is reset at the start of the call.
        seed: Seed forwarded to ``env.reset``.
        scenario: Scenario forwarded to ``env.reset``.
        difficulty: Difficulty forwarded to ``env.reset``.
        prompt_fn: Formats the chat messages into a prompt string.
        generate_fn: Produces the raw completion for a prompt.
        config: Agent configuration passed to ``build_chat`` and
            ``generate_fn``. ``None`` (the default) uses a fresh
            ``LLMAgentConfig()`` for this call.
        max_steps: Maximum number of steps to take. ``None`` falls back to
            ``env.max_steps``; an explicit ``0`` takes no steps.

    Returns:
        An ``EpisodeRecord`` holding the per-step records plus the episode
        outcome read from ``env.state`` and ``env.hidden_truth()``.
    """
    # Build the default per call instead of once at import time, so separate
    # calls never share one (possibly mutated) config instance.
    if config is None:
        config = LLMAgentConfig()

    obs = env.reset(seed=seed, scenario=scenario, difficulty=difficulty)
    steps: List[StepRecord] = []
    total_reward = 0.0

    # Explicit None check: ``max_steps or env.max_steps`` would silently turn
    # an explicit 0 into the environment's full step budget.
    cap = env.max_steps if max_steps is None else max_steps

    while not obs.done and len(steps) < cap:
        chat = build_chat(obs, config)
        prompt = prompt_fn(chat)
        completion = generate_fn(prompt, config)

        action = parse_action(completion)
        parsed_ok = action is not None
        if action is None:
            logger.debug(
                "Step %s: could not parse an action from the completion; "
                "using the safe default action.",
                obs.step_index,
            )
            action = safe_default_action(obs)

        next_obs = env.step(action)
        # A falsy reward (e.g. None) on the returned observation counts as 0.
        reward = float(next_obs.reward or 0.0)
        total_reward += reward

        steps.append(
            StepRecord(
                step=obs.step_index,
                prompt=prompt,
                completion=completion,
                action=action.model_dump(),
                parsed_ok=parsed_ok,
                reward=reward,
                done=next_obs.done,
                rule_violations=list(next_obs.rule_violations),
                # Summarise the observation the agent acted on (pre-step),
                # i.e. the one the prompt was built from.
                observation_summary=_summarise_obs(obs),
            )
        )
        obs = next_obs

    return EpisodeRecord(
        seed=seed,
        scenario=env.state.scenario_name,
        difficulty=env.state.difficulty,
        truth=env.hidden_truth(),
        total_reward=total_reward,
        cumulative_reward=float(env.state.cumulative_reward),
        terminal_reward=env.state.terminal_reward,
        discovered=env.state.discovered,
        correct_mass=env.state.correct_mass,
        correct_channel=env.state.correct_channel,
        correct_spin=env.state.correct_spin,
        steps=steps,
    )
|
|
|
|
|
def save_episodes_jsonl(episodes: List[EpisodeRecord], path: str) -> None:
    """Write ``episodes`` to ``path`` as JSON Lines, one episode per line.

    Each record is converted with ``dataclasses.asdict``. Values the ``json``
    module cannot encode natively are stringified via ``default=str``, so
    they do not round-trip to their original types. An existing file at
    ``path`` is overwritten.

    The file is written as UTF-8 explicitly rather than in the platform's
    locale encoding, so it reads back identically on any machine.
    """
    with open(path, "w", encoding="utf-8") as f:
        for ep in episodes:
            f.write(json.dumps(asdict(ep), default=str) + "\n")
|
|
|
|
|
def load_episodes_jsonl(path: str) -> List[Dict[str, Any]]:
    """Read a JSON Lines file such as one written by ``save_episodes_jsonl``.

    Returns one plain ``dict`` per non-blank line; records are not converted
    back into ``EpisodeRecord`` objects. Blank and whitespace-only lines are
    skipped. The file is decoded as UTF-8 (the same encoding the writer
    uses) instead of the platform's locale encoding.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    with open(path, encoding="utf-8") as f:
        # json.loads tolerates surrounding whitespace, including the newline.
        return [json.loads(line) for line in f if line.strip()]
|
|
|
|
|
# Public API of this module: the record types, the episode collector and
# the JSONL persistence helpers.
__all__ = [
    "EpisodeRecord", "StepRecord",
    "collect_episode",
    "save_episodes_jsonl", "load_episodes_jsonl",
]
|
|
|