# cernenv-trainer / training/rollouts.py
# Last change by anugrahhu: "Update CERNenv Space" (commit 0a6c641, verified)
"""Rollout collector for LLM-driven CERNenv episodes.
Runs an LLM agent in-process against ``CERNCollisionEnvironment`` and
records full per-step trajectories: prompt, completion, parsed action,
reward, observation snapshot, and final episode summary.
"""
from __future__ import annotations
import json
import logging
from dataclasses import asdict, dataclass, field
from typing import Any, Callable, Dict, List, Optional
from models import ActionType, CollisionObservation, ExperimentAction
from server.environment import CERNCollisionEnvironment
from .llm_agent import (
LLMAgentConfig,
build_chat,
parse_action,
safe_default_action,
)
logger = logging.getLogger(__name__)

#: Tokenizer-aware prompt formatter: turns a list of chat messages
#: (string-to-string dicts) into a single prompt string, e.g. a wrapper
#: around ``tokenizer.apply_chat_template``.
PromptFn = Callable[[List[Dict[str, str]]], str]

#: Runs the LLM on an already-formatted prompt, using the agent config,
#: and returns the raw completion string.
GenerateFn = Callable[[str, LLMAgentConfig], str]
@dataclass
class StepRecord:
    """One agent/environment interaction inside an episode.

    As filled by ``collect_episode``: the prompt shown to the LLM and its raw
    completion, the action actually sent to the environment, and the reward /
    termination signal from the observation returned by that step.
    """

    # ``step_index`` of the observation the action was chosen from (pre-step).
    step: int
    # Prompt string produced by the caller's ``prompt_fn``.
    prompt: str
    # Raw completion text returned by ``generate_fn``.
    completion: str
    # ``model_dump()`` of the executed action (the fallback action if parsing failed).
    action: Dict[str, Any]
    # False when the completion could not be parsed and a safe default was used.
    parsed_ok: bool
    # Reward from the post-step observation; a missing/None reward is stored as 0.0.
    reward: float
    # ``done`` flag of the post-step observation.
    done: bool
    # Rule violations reported by the post-step observation.
    rule_violations: List[str]
    # Compact snapshot of the pre-step observation; empty dict when not provided.
    observation_summary: Dict[str, Any] = field(default_factory=dict)
@dataclass
class EpisodeRecord:
    """Complete trajectory of one episode plus its end-of-episode summary.

    As built by ``collect_episode``: the outcome fields are read from the
    environment's ``state`` after the step loop finishes. The ``Optional``
    outcome fields may be ``None`` -- presumably when the episode ended
    without a final verdict (NOTE(review): confirm against the environment).
    """

    # Seed passed to ``env.reset``.
    seed: int
    # Scenario name read back from ``env.state`` (the resolved scenario, not the argument).
    scenario: Optional[str]
    # Difficulty read back from ``env.state``.
    difficulty: Optional[str]
    # Result of ``env.hidden_truth()``; schema defined by the environment.
    truth: Optional[Dict[str, Any]]
    # Sum of the per-step rewards stored in ``steps``.
    total_reward: float
    # The environment's own ``state.cumulative_reward``.
    # NOTE(review): may differ from ``total_reward`` (e.g. terminal bonus handling) -- confirm.
    cumulative_reward: float
    # ``state.terminal_reward``; may be None.
    terminal_reward: Optional[float]
    # Outcome flags copied from ``env.state``; each may be None.
    discovered: Optional[bool]
    correct_mass: Optional[bool]
    correct_channel: Optional[bool]
    correct_spin: Optional[bool]
    # Per-step records in execution order.
    steps: List[StepRecord]
def _summarise_obs(obs: CollisionObservation) -> Dict[str, Any]:
    """Return a compact snapshot of ``obs`` for per-step logging.

    Scalar progress indicators are copied as-is; the candidate-mass list is
    reduced to its length, and the two remaining-resource figures are pulled
    out of ``obs.resource_usage``.

    NOTE(review): the ``best_significance`` key is filled from
    ``obs.cumulative_significance`` -- confirm the naming is intended.
    """
    usage = obs.resource_usage
    return dict(
        step_index=obs.step_index,
        selected_channel=obs.selected_channel,
        selected_beam_energy=obs.selected_beam_energy,
        n_candidates=len(obs.candidate_masses_gev),
        best_significance=obs.cumulative_significance,
        budget_remaining_musd=usage.budget_remaining_musd,
        luminosity_remaining_fb=usage.luminosity_remaining_fb,
    )
def collect_episode(
    *,
    env: CERNCollisionEnvironment,
    seed: int,
    scenario: Optional[str],
    difficulty: Optional[str],
    prompt_fn: PromptFn,
    generate_fn: GenerateFn,
    config: Optional[LLMAgentConfig] = None,
    max_steps: Optional[int] = None,
) -> EpisodeRecord:
    """Run one LLM-driven episode in-process and record its full trajectory.

    The environment is reset with ``seed``/``scenario``/``difficulty``. Then,
    until the observation reports ``done`` or the step cap is reached, each
    step builds a chat from the current observation, formats it with
    ``prompt_fn``, generates a completion with ``generate_fn``, parses it into
    an action (falling back to ``safe_default_action`` when parsing fails),
    and steps the environment.

    Args:
        env: Environment to drive; it is reset at the start of the call.
        seed: Seed forwarded to ``env.reset``.
        scenario: Scenario forwarded to ``env.reset`` (may be ``None``).
        difficulty: Difficulty forwarded to ``env.reset`` (may be ``None``).
        prompt_fn: Formats the chat messages into a prompt string.
        generate_fn: Produces the raw completion for a prompt and config.
        config: Agent configuration forwarded to ``build_chat`` and
            ``generate_fn``. ``None`` (the default) creates a fresh
            ``LLMAgentConfig()`` for this call.
        max_steps: Maximum number of steps to record. ``None`` uses
            ``env.max_steps``; an explicit ``0`` records no steps.

    Returns:
        An ``EpisodeRecord`` holding one ``StepRecord`` per executed step and
        the outcome values read from ``env.state`` / ``env.hidden_truth()``
        after the loop.
    """
    if config is None:
        # Build per call: an instance in the signature default would be created
        # once at import time and shared by every episode.
        config = LLMAgentConfig()

    obs = env.reset(seed=seed, scenario=scenario, difficulty=difficulty)
    # Only ``None`` means "use the env default" -- ``max_steps or ...`` would
    # turn an explicit 0 into the env cap. Read after reset, as before, in case
    # reset changes ``env.max_steps``.
    cap = env.max_steps if max_steps is None else max_steps

    steps: List[StepRecord] = []
    total_reward = 0.0
    while not obs.done and len(steps) < cap:
        chat = build_chat(obs, config)
        prompt = prompt_fn(chat)
        completion = generate_fn(prompt, config)
        action = parse_action(completion)
        parsed_ok = action is not None
        if action is None:
            logger.debug(
                "step %s: completion could not be parsed; using safe default action",
                obs.step_index,
            )
            action = safe_default_action(obs)

        next_obs = env.step(action)
        # A missing/None reward counts as 0.0 so totals stay numeric.
        reward = float(next_obs.reward or 0.0)
        total_reward += reward
        steps.append(
            StepRecord(
                step=obs.step_index,
                prompt=prompt,
                completion=completion,
                action=action.model_dump(),
                parsed_ok=parsed_ok,
                reward=reward,
                done=next_obs.done,
                rule_violations=list(next_obs.rule_violations),
                # Summary of the observation the action was chosen from.
                observation_summary=_summarise_obs(obs),
            )
        )
        obs = next_obs

    return EpisodeRecord(
        seed=seed,
        scenario=env.state.scenario_name,
        difficulty=env.state.difficulty,
        truth=env.hidden_truth(),
        total_reward=total_reward,
        cumulative_reward=float(env.state.cumulative_reward),
        terminal_reward=env.state.terminal_reward,
        discovered=env.state.discovered,
        correct_mass=env.state.correct_mass,
        correct_channel=env.state.correct_channel,
        correct_spin=env.state.correct_spin,
        steps=steps,
    )
def save_episodes_jsonl(episodes: List[EpisodeRecord], path: str) -> None:
    """Write ``episodes`` to ``path`` as JSON lines, one episode per line.

    The file is truncated first. Each record is converted with
    ``dataclasses.asdict``, so nested ``StepRecord`` lists become lists of
    dicts. Any value ``json`` cannot encode natively is written as its
    ``str()`` form.
    """
    # Explicit UTF-8: JSON text is UTF-8 by spec, and this keeps the writer
    # independent of the locale default and symmetric with the reader.
    with open(path, "w", encoding="utf-8") as f:
        for ep in episodes:
            # default=str is a lossy fallback for non-JSON types; such values
            # come back as strings when loaded.
            f.write(json.dumps(asdict(ep), default=str) + "\n")
def load_episodes_jsonl(path: str) -> List[Dict[str, Any]]:
    """Read a JSON-lines file of episodes back as plain dicts.

    Lines that are empty after stripping whitespace are skipped; every other
    line is decoded with ``json.loads`` (a malformed line raises
    ``json.JSONDecodeError``). Records are not converted back into
    ``EpisodeRecord`` objects.
    """
    eps: List[Dict[str, Any]] = []
    # JSON text is UTF-8 (RFC 8259); relying on the locale default encoding
    # would misdecode non-ASCII content on e.g. cp1252 systems.
    with open(path, encoding="utf-8") as f:
        for raw in f:
            line = raw.strip()
            if line:
                eps.append(json.loads(line))
    return eps
# Public API of this module: the names exported by ``from ... import *``.
# NOTE(review): the ``PromptFn`` / ``GenerateFn`` aliases used in
# ``collect_episode``'s signature are not listed -- consider exporting them.
__all__ = [
    "EpisodeRecord",
    "StepRecord",
    "collect_episode",
    "save_episodes_jsonl",
    "load_episodes_jsonl",
]