Spaces:
Sleeping
Sleeping
File size: 4,748 Bytes
5f78183 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 | """Rollout collector for LLM-driven CERNenv episodes.
Runs an LLM agent in-process against ``CERNCollisionEnvironment`` and
records full per-step trajectories: prompt, completion, parsed action,
reward, observation snapshot, and final episode summary.
"""
from __future__ import annotations
import json
import logging
from dataclasses import asdict, dataclass, field
from typing import Any, Callable, Dict, List, Optional
from models import ActionType, CollisionObservation, ExperimentAction
from server.environment import CERNCollisionEnvironment
from .llm_agent import (
LLMAgentConfig,
build_chat,
parse_action,
safe_default_action,
)
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

# Takes the chat as a list of string-to-string dicts and returns the prompt
# as a single string.
PromptFn = Callable[[List[Dict[str, str]]], str]
"""Callable: tokenizer-aware prompt formatter (e.g. apply_chat_template)."""

# Takes the formatted prompt string plus the agent config and returns the
# model's raw completion text.
GenerateFn = Callable[[str, LLMAgentConfig], str]
"""Callable: actually run the LLM and return the raw completion string."""
@dataclass
class StepRecord:
    """Everything recorded for a single agent/environment interaction.

    ``collect_episode`` creates one instance per executed step. Field order
    is part of the constructor's positional interface and must not change.
    """

    step: int  # step index of the observation the agent acted on
    prompt: str  # formatted prompt string handed to the generator
    completion: str  # raw completion text returned by the generator
    action: Dict[str, Any]  # the executed action, dumped to a plain dict
    parsed_ok: bool  # False when the completion could not be parsed
    reward: float  # reward reported by the observation after the step
    done: bool  # whether the observation after the step is terminal
    rule_violations: List[str]  # violations reported after the step
    # Compact snapshot of the observation the agent acted on; every
    # instance gets its own empty dict when none is supplied.
    observation_summary: Dict[str, Any] = field(default_factory=dict)
@dataclass
class EpisodeRecord:
    """Summary of one finished episode together with its step trajectory.

    Built by ``collect_episode`` once the rollout loop has ended. Every
    field is required; field order is part of the positional interface.
    """

    seed: int  # seed the episode was collected with
    scenario: Optional[str]  # scenario name read from the environment state
    difficulty: Optional[str]  # difficulty read from the environment state
    truth: Optional[Dict[str, Any]]  # hidden ground truth from the environment
    total_reward: float  # sum of per-step rewards added up by the collector
    cumulative_reward: float  # cumulative reward reported by the environment
    terminal_reward: Optional[float]  # terminal reward from the environment state
    discovered: Optional[bool]  # "discovered" flag from the environment state
    correct_mass: Optional[bool]  # "correct_mass" flag from the environment state
    correct_channel: Optional[bool]  # "correct_channel" flag from the state
    correct_spin: Optional[bool]  # "correct_spin" flag from the state
    steps: List[StepRecord]  # per-step records in execution order
def _summarise_obs(obs: CollisionObservation) -> Dict[str, Any]:
    """Reduce an observation to a small dict of headline values.

    Copies the step index, the selected channel and beam energy, the number
    of candidate masses, the cumulative significance, and the remaining
    budget and luminosity from the observation's resource usage.
    """
    summary: Dict[str, Any] = {
        "step_index": obs.step_index,
        "selected_channel": obs.selected_channel,
        "selected_beam_energy": obs.selected_beam_energy,
        "n_candidates": len(obs.candidate_masses_gev),
        "best_significance": obs.cumulative_significance,
    }
    usage = obs.resource_usage
    summary["budget_remaining_musd"] = usage.budget_remaining_musd
    summary["luminosity_remaining_fb"] = usage.luminosity_remaining_fb
    return summary
def collect_episode(
    *,
    env: CERNCollisionEnvironment,
    seed: int,
    scenario: Optional[str],
    difficulty: Optional[str],
    prompt_fn: PromptFn,
    generate_fn: GenerateFn,
    config: Optional[LLMAgentConfig] = None,
    max_steps: Optional[int] = None,
) -> EpisodeRecord:
    """Run one episode with an LLM agent and record the full trajectory.

    The environment is reset with ``seed``, ``scenario`` and ``difficulty``.
    On every step the current observation is turned into a chat, formatted
    by ``prompt_fn``, completed by ``generate_fn`` and parsed into an action.
    When parsing yields ``None`` the action from ``safe_default_action`` is
    executed instead and the step is recorded with ``parsed_ok=False``.

    Args:
        env: Environment to roll out in; it is reset by this call.
        seed: Seed forwarded to ``env.reset`` and stored on the record.
        scenario: Scenario forwarded to ``env.reset``.
        difficulty: Difficulty forwarded to ``env.reset``.
        prompt_fn: Formats the chat messages into a prompt string.
        generate_fn: Produces the raw completion for a prompt.
        config: Agent configuration. ``None`` (the default) creates a fresh
            ``LLMAgentConfig()`` for this call, so no instance is shared
            between calls.
        max_steps: Upper bound on recorded steps. ``None`` falls back to
            ``env.max_steps``; an explicit ``0`` is honoured and yields an
            episode with no steps.

    Returns:
        An ``EpisodeRecord`` holding the per-step records, the summed step
        rewards, and the summary values read from the environment after the
        loop has ended.
    """
    # Build the default lazily: a call expression in the signature would be
    # evaluated once at definition time and reused by every call.
    if config is None:
        config = LLMAgentConfig()

    obs = env.reset(seed=seed, scenario=scenario, difficulty=difficulty)
    steps: List[StepRecord] = []
    total_reward = 0.0
    # Compare against None rather than truthiness so max_steps=0 is not
    # mistaken for "not provided".
    cap = env.max_steps if max_steps is None else max_steps

    while not obs.done and len(steps) < cap:
        chat = build_chat(obs, config)
        prompt = prompt_fn(chat)
        completion = generate_fn(prompt, config)

        action = parse_action(completion)
        parsed_ok = action is not None
        if action is None:
            logger.debug(
                "Unparseable completion at step %s; using the default action",
                obs.step_index,
            )
            action = safe_default_action(obs)

        next_obs = env.step(action)
        # A missing (None) reward counts as zero.
        reward = float(next_obs.reward or 0.0)
        total_reward += reward

        steps.append(
            StepRecord(
                # Index and summary describe the observation the agent saw;
                # reward, done and violations come from the step's result.
                step=obs.step_index,
                prompt=prompt,
                completion=completion,
                action=action.model_dump(),
                parsed_ok=parsed_ok,
                reward=reward,
                done=next_obs.done,
                rule_violations=list(next_obs.rule_violations),
                observation_summary=_summarise_obs(obs),
            )
        )
        obs = next_obs

    return EpisodeRecord(
        seed=seed,
        # Scenario and difficulty are read back from the environment state
        # rather than echoed from the arguments.
        scenario=env.state.scenario_name,
        difficulty=env.state.difficulty,
        truth=env.hidden_truth(),
        total_reward=total_reward,
        cumulative_reward=float(env.state.cumulative_reward),
        terminal_reward=env.state.terminal_reward,
        discovered=env.state.discovered,
        correct_mass=env.state.correct_mass,
        correct_channel=env.state.correct_channel,
        correct_spin=env.state.correct_spin,
        steps=steps,
    )
def save_episodes_jsonl(episodes: List[EpisodeRecord], path: str) -> None:
    """Write episodes to ``path`` as JSON Lines, one episode per line.

    Any existing file at ``path`` is overwritten. Each episode is converted
    with ``dataclasses.asdict``; values the JSON encoder cannot handle are
    written as their ``str()`` form.

    Args:
        episodes: Dataclass records to serialise, written in order.
        path: Destination file path.
    """
    # Fix the encoding so the file does not depend on the platform default.
    with open(path, "w", encoding="utf-8") as f:
        for ep in episodes:
            f.write(json.dumps(asdict(ep), default=str) + "\n")
def load_episodes_jsonl(path: str) -> List[Dict[str, Any]]:
    """Read a JSON Lines file and return one decoded object per line.

    Lines that are empty after stripping whitespace are skipped. The decoded
    JSON values are returned as-is; they are not converted back into
    ``EpisodeRecord`` instances.

    Args:
        path: File to read; decoded as UTF-8.

    Returns:
        The decoded lines, in file order.
    """
    # Decode as UTF-8 explicitly: the platform default encoding can reject
    # or mis-decode non-ASCII text on non-UTF-8 locales.
    with open(path, encoding="utf-8") as f:
        return [json.loads(text) for text in map(str.strip, f) if text]
# Names exported by ``from <module> import *``; the underscore-prefixed
# helper and the callable type aliases are not listed.
__all__ = [
    "EpisodeRecord",
    "StepRecord",
    "collect_episode",
    "save_episodes_jsonl",
    "load_episodes_jsonl",
]
|