Spaces:
Paused
Paused
| """Rollout collector for LLM-driven CERNenv episodes. | |
| Runs an LLM agent in-process against ``CERNCollisionEnvironment`` and | |
| records full per-step trajectories: prompt, completion, parsed action, | |
| reward, observation snapshot, and final episode summary. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import logging | |
| from dataclasses import asdict, dataclass, field | |
| from typing import Any, Callable, Dict, List, Optional | |
| from models import ActionType, CollisionObservation, ExperimentAction | |
| from server.environment import CERNCollisionEnvironment | |
| from .llm_agent import ( | |
| LLMAgentConfig, | |
| build_chat, | |
| parse_action, | |
| safe_default_action, | |
| ) | |
# Tokenizer-aware prompt formatter (e.g. ``apply_chat_template``): takes the
# chat as a list of message dicts and returns the prompt string.
PromptFn = Callable[[List[Dict[str, str]]], str]

# Actually runs the LLM: takes the formatted prompt plus the agent config and
# returns the raw completion string.
GenerateFn = Callable[[str, LLMAgentConfig], str]

logger = logging.getLogger(__name__)
@dataclass
class StepRecord:
    """One agent step inside an LLM-driven episode.

    ``collect_episode`` builds instances with keyword arguments, and
    ``save_episodes_jsonl`` serialises them via ``dataclasses.asdict``. Both
    need this class to be a dataclass; without the decorator the keyword
    constructor raises ``TypeError``.
    """

    # Step index of the observation the agent acted on (obs.step_index).
    step: int
    # Formatted prompt sent to the model (output of prompt_fn).
    prompt: str
    # Raw model output returned by generate_fn.
    completion: str
    # The executed action as a plain dict (action.model_dump()).
    action: Dict[str, Any]
    # False when parse_action failed and safe_default_action was used instead.
    parsed_ok: bool
    # Reward from the post-step observation, with None coerced to 0.0.
    reward: float
    # Whether the post-step observation reported the episode as done.
    done: bool
    # Rule violations listed in the post-step observation.
    rule_violations: List[str]
    # Compact snapshot of the pre-step observation. default_factory gives
    # every record its own dict instead of one shared instance.
    observation_summary: Dict[str, Any] = field(default_factory=dict)
@dataclass
class EpisodeRecord:
    """Full trajectory and final outcome of one collected episode.

    ``collect_episode`` builds instances with keyword arguments, and
    ``save_episodes_jsonl`` serialises them with ``dataclasses.asdict``.
    Both require the ``@dataclass`` decorator.
    """

    # Seed passed to env.reset.
    seed: int
    # Scenario / difficulty as reported by env.state after the episode
    # (env.state.scenario_name / env.state.difficulty), not the raw arguments.
    scenario: Optional[str]
    difficulty: Optional[str]
    # Result of env.hidden_truth().
    truth: Optional[Dict[str, Any]]
    # Sum of the per-step rewards recorded by the collector.
    total_reward: float
    # The environment's own tally (env.state.cumulative_reward).
    cumulative_reward: float
    # env.state.terminal_reward; may be None.
    terminal_reward: Optional[float]
    # Outcome flags copied from env.state; each may be None.
    discovered: Optional[bool]
    correct_mass: Optional[bool]
    correct_channel: Optional[bool]
    correct_spin: Optional[bool]
    # Per-step records in the order they were taken.
    steps: List[StepRecord]
def _summarise_obs(obs: CollisionObservation) -> Dict[str, Any]:
    """Return a small, JSON-friendly snapshot of ``obs`` for step records.

    Key order is fixed, so serialised records stay stable. ``n_candidates``
    is the length of ``obs.candidate_masses_gev``. ``best_significance``
    carries ``obs.cumulative_significance``. The last two keys come from
    ``obs.resource_usage``.
    """
    summary: Dict[str, Any] = {}
    summary["step_index"] = obs.step_index
    summary["selected_channel"] = obs.selected_channel
    summary["selected_beam_energy"] = obs.selected_beam_energy
    summary["n_candidates"] = len(obs.candidate_masses_gev)
    summary["best_significance"] = obs.cumulative_significance
    usage = obs.resource_usage
    summary["budget_remaining_musd"] = usage.budget_remaining_musd
    summary["luminosity_remaining_fb"] = usage.luminosity_remaining_fb
    return summary
def collect_episode(
    *,
    env: CERNCollisionEnvironment,
    seed: int,
    scenario: Optional[str],
    difficulty: Optional[str],
    prompt_fn: PromptFn,
    generate_fn: GenerateFn,
    config: Optional[LLMAgentConfig] = None,
    max_steps: Optional[int] = None,
) -> EpisodeRecord:
    """Run one LLM-driven episode in-process and record every step.

    The environment is reset, then the loop runs until the observation
    reports ``done`` or the step cap is reached. Each step:

    1. Renders the current observation into a chat (``build_chat``).
    2. Formats the chat with ``prompt_fn`` and completes it with
       ``generate_fn``.
    3. Parses the completion with ``parse_action``.
    4. Steps the environment with the resulting action.

    Completions that fail to parse fall back to ``safe_default_action`` and
    are recorded with ``parsed_ok=False``.

    Args:
        env: Environment to drive. It is reset at the start.
        seed: Forwarded to ``env.reset`` and stored in the result.
        scenario: Forwarded to ``env.reset``.
        difficulty: Forwarded to ``env.reset``.
        prompt_fn: Turns the chat messages into a prompt string.
        generate_fn: Runs the model on a prompt and returns the raw
            completion.
        config: Agent configuration passed to ``build_chat`` and
            ``generate_fn``. ``None`` (the default) creates a fresh
            ``LLMAgentConfig()`` for this call.
        max_steps: Maximum number of steps to take. ``None`` uses
            ``env.max_steps``. An explicit ``0`` takes no steps.

    Returns:
        An ``EpisodeRecord`` with one ``StepRecord`` per step taken. The
        final outcome fields are read from ``env.state`` and
        ``env.hidden_truth()``.
    """
    if config is None:
        # Built per call: a default evaluated in the signature would be a
        # single instance shared (and possibly mutated) across all calls.
        config = LLMAgentConfig()

    obs = env.reset(seed=seed, scenario=scenario, difficulty=difficulty)
    steps: List[StepRecord] = []
    total_reward = 0.0
    # `is None` rather than `or`, so an explicit max_steps=0 is honoured.
    # env.max_steps is read after reset in case reset changes it.
    cap = env.max_steps if max_steps is None else max_steps

    while not obs.done and len(steps) < cap:
        chat = build_chat(obs, config)
        prompt = prompt_fn(chat)
        completion = generate_fn(prompt, config)
        action = parse_action(completion)
        parsed_ok = action is not None
        if action is None:
            # Keep the episode moving on unparseable output; parsed_ok
            # records the failure.
            action = safe_default_action(obs)
        next_obs = env.step(action)
        # A missing (None) reward counts as 0.0.
        reward = float(next_obs.reward or 0.0)
        total_reward += reward
        steps.append(
            StepRecord(
                step=obs.step_index,
                prompt=prompt,
                completion=completion,
                action=action.model_dump(),
                parsed_ok=parsed_ok,
                reward=reward,
                done=next_obs.done,
                rule_violations=list(next_obs.rule_violations),
                # Snapshot of the observation the agent acted on.
                observation_summary=_summarise_obs(obs),
            )
        )
        obs = next_obs

    state = env.state
    return EpisodeRecord(
        seed=seed,
        scenario=state.scenario_name,
        difficulty=state.difficulty,
        truth=env.hidden_truth(),
        total_reward=total_reward,
        cumulative_reward=float(state.cumulative_reward),
        terminal_reward=state.terminal_reward,
        discovered=state.discovered,
        correct_mass=state.correct_mass,
        correct_channel=state.correct_channel,
        correct_spin=state.correct_spin,
        steps=steps,
    )
def save_episodes_jsonl(episodes: List[EpisodeRecord], path: str) -> None:
    """Write ``episodes`` to ``path`` as JSON Lines, one record per line.

    The file is truncated first. Each record is converted with
    ``dataclasses.asdict``. Any value ``json`` cannot encode natively is
    written as its ``str()`` form.
    """
    with open(path, "w") as sink:
        sink.writelines(
            json.dumps(asdict(record), default=str) + "\n" for record in episodes
        )
def load_episodes_jsonl(path: str) -> List[Dict[str, Any]]:
    """Read a JSON Lines file such as one written by ``save_episodes_jsonl``.

    The file is decoded as UTF-8 explicitly, so the result does not depend
    on the platform's locale encoding (non-ASCII completions would otherwise
    be garbled or fail to decode on non-UTF-8 locales). Lines that are empty
    or whitespace-only are skipped. Every other line is parsed with
    ``json.loads``.

    Returns:
        The parsed value of each non-blank line, in file order. Files written
        by ``save_episodes_jsonl`` contain one JSON object per line, so these
        are plain dicts, not ``EpisodeRecord`` instances.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    episodes: List[Dict[str, Any]] = []
    with open(path, encoding="utf-8") as f:
        for raw in f:
            text = raw.strip()
            if text:
                episodes.append(json.loads(text))
    return episodes
# Public API of this module: the two record types, the rollout driver, and
# the JSONL persistence helpers.
__all__ = [
    # record types
    "EpisodeRecord",
    "StepRecord",
    # rollout driver
    "collect_episode",
    # JSONL (de)serialisation
    "save_episodes_jsonl",
    "load_episodes_jsonl",
]