"""Rollout collector for LLM-driven CERNenv episodes. Runs an LLM agent in-process against ``CERNCollisionEnvironment`` and records full per-step trajectories: prompt, completion, parsed action, reward, observation snapshot, and final episode summary. """ from __future__ import annotations import json import logging from dataclasses import asdict, dataclass, field from typing import Any, Callable, Dict, List, Optional from models import ActionType, CollisionObservation, ExperimentAction from server.environment import CERNCollisionEnvironment from .llm_agent import ( LLMAgentConfig, build_chat, parse_action, safe_default_action, ) logger = logging.getLogger(__name__) PromptFn = Callable[[List[Dict[str, str]]], str] """Callable: tokenizer-aware prompt formatter (e.g. apply_chat_template).""" GenerateFn = Callable[[str, LLMAgentConfig], str] """Callable: actually run the LLM and return the raw completion string.""" @dataclass class StepRecord: step: int prompt: str completion: str action: Dict[str, Any] parsed_ok: bool reward: float done: bool rule_violations: List[str] observation_summary: Dict[str, Any] = field(default_factory=dict) @dataclass class EpisodeRecord: seed: int scenario: Optional[str] difficulty: Optional[str] truth: Optional[Dict[str, Any]] total_reward: float cumulative_reward: float terminal_reward: Optional[float] discovered: Optional[bool] correct_mass: Optional[bool] correct_channel: Optional[bool] correct_spin: Optional[bool] steps: List[StepRecord] def _summarise_obs(obs: CollisionObservation) -> Dict[str, Any]: return { "step_index": obs.step_index, "selected_channel": obs.selected_channel, "selected_beam_energy": obs.selected_beam_energy, "n_candidates": len(obs.candidate_masses_gev), "best_significance": obs.cumulative_significance, "budget_remaining_musd": obs.resource_usage.budget_remaining_musd, "luminosity_remaining_fb": obs.resource_usage.luminosity_remaining_fb, } def collect_episode( *, env: CERNCollisionEnvironment, seed: int, scenario: Optional[str], difficulty: Optional[str], prompt_fn: PromptFn, generate_fn: GenerateFn, config: LLMAgentConfig = LLMAgentConfig(), max_steps: Optional[int] = None, ) -> EpisodeRecord: obs = env.reset(seed=seed, scenario=scenario, difficulty=difficulty) steps: List[StepRecord] = [] total_reward = 0.0 cap = max_steps or env.max_steps while not obs.done and len(steps) < cap: chat = build_chat(obs, config) prompt = prompt_fn(chat) completion = generate_fn(prompt, config) action = parse_action(completion) parsed_ok = action is not None if action is None: action = safe_default_action(obs) next_obs = env.step(action) reward = float(next_obs.reward or 0.0) total_reward += reward steps.append( StepRecord( step=obs.step_index, prompt=prompt, completion=completion, action=action.model_dump(), parsed_ok=parsed_ok, reward=reward, done=next_obs.done, rule_violations=list(next_obs.rule_violations), observation_summary=_summarise_obs(obs), ) ) obs = next_obs return EpisodeRecord( seed=seed, scenario=env.state.scenario_name, difficulty=env.state.difficulty, truth=env.hidden_truth(), total_reward=total_reward, cumulative_reward=float(env.state.cumulative_reward), terminal_reward=env.state.terminal_reward, discovered=env.state.discovered, correct_mass=env.state.correct_mass, correct_channel=env.state.correct_channel, correct_spin=env.state.correct_spin, steps=steps, ) def save_episodes_jsonl(episodes: List[EpisodeRecord], path: str) -> None: with open(path, "w") as f: for ep in episodes: f.write(json.dumps(asdict(ep), default=str) + "\n") def load_episodes_jsonl(path: str) -> List[Dict[str, Any]]: eps: List[Dict[str, Any]] = [] with open(path) as f: for line in f: line = line.strip() if line: eps.append(json.loads(line)) return eps __all__ = [ "EpisodeRecord", "StepRecord", "collect_episode", "save_episodes_jsonl", "load_episodes_jsonl", ]