File size: 4,748 Bytes
5f78183
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
"""Rollout collector for LLM-driven CERNenv episodes.



Runs an LLM agent in-process against ``CERNCollisionEnvironment`` and

records full per-step trajectories: prompt, completion, parsed action,

reward, observation snapshot, and final episode summary.

"""

from __future__ import annotations

import json
import logging
from dataclasses import asdict, dataclass, field
from typing import Any, Callable, Dict, List, Optional

from models import ActionType, CollisionObservation, ExperimentAction
from server.environment import CERNCollisionEnvironment

from .llm_agent import (
    LLMAgentConfig,
    build_chat,
    parse_action,
    safe_default_action,
)


logger = logging.getLogger(__name__)


# Injection points for the model backend: the collector never talks to a
# tokenizer or an LLM directly, it only calls these two callables.
PromptFn = Callable[[List[Dict[str, str]]], str]
"""Callable: tokenizer-aware prompt formatter (e.g. apply_chat_template)."""

GenerateFn = Callable[[str, LLMAgentConfig], str]
"""Callable: actually run the LLM and return the raw completion string."""


@dataclass
class StepRecord:
    """Full record of one environment step taken by the LLM agent."""

    step: int  # obs.step_index of the observation the action was chosen from
    prompt: str  # formatted prompt string actually sent to the model
    completion: str  # raw model completion, before parsing
    parsed_ok: bool  # False when parsing failed and a safe default action was used
    action: Dict[str, Any]  # executed action as a plain dict (model_dump)
    reward: float  # reward returned by env.step for this action
    done: bool  # True when this step ended the episode
    rule_violations: List[str]  # violations reported by the env for this step
    observation_summary: Dict[str, Any] = field(default_factory=dict)  # compact obs snapshot


@dataclass
class EpisodeRecord:
    """One complete episode: per-step records plus the final env summary."""

    seed: int  # seed passed to env.reset
    scenario: Optional[str]  # scenario actually run (env.state.scenario_name)
    difficulty: Optional[str]  # difficulty actually run (env.state.difficulty)
    truth: Optional[Dict[str, Any]]  # hidden ground truth from env.hidden_truth()
    total_reward: float  # sum of per-step rewards accumulated by the collector
    cumulative_reward: float  # env's own running total (env.state.cumulative_reward)
    terminal_reward: Optional[float]  # end-of-episode bonus, if any
    discovered: Optional[bool]  # final discovery verdict from env.state
    correct_mass: Optional[bool]  # mass claim correctness from env.state
    correct_channel: Optional[bool]  # channel claim correctness from env.state
    correct_spin: Optional[bool]  # spin claim correctness from env.state
    steps: List[StepRecord]  # chronological per-step trajectory


def _summarise_obs(obs: CollisionObservation) -> Dict[str, Any]:
    return {
        "step_index": obs.step_index,
        "selected_channel": obs.selected_channel,
        "selected_beam_energy": obs.selected_beam_energy,
        "n_candidates": len(obs.candidate_masses_gev),
        "best_significance": obs.cumulative_significance,
        "budget_remaining_musd": obs.resource_usage.budget_remaining_musd,
        "luminosity_remaining_fb": obs.resource_usage.luminosity_remaining_fb,
    }


def collect_episode(
    *,
    env: CERNCollisionEnvironment,
    seed: int,
    scenario: Optional[str],
    difficulty: Optional[str],
    prompt_fn: PromptFn,
    generate_fn: GenerateFn,
    config: Optional[LLMAgentConfig] = None,
    max_steps: Optional[int] = None,
) -> EpisodeRecord:
    """Run one LLM-driven episode against *env* and record the trajectory.

    Args:
        env: environment instance; ``reset`` is called, discarding prior state.
        seed: RNG seed forwarded to ``env.reset``.
        scenario: scenario name forwarded to ``env.reset``.
        difficulty: difficulty label forwarded to ``env.reset``.
        prompt_fn: formats the chat message list into a single prompt string
            (e.g. a tokenizer's ``apply_chat_template``).
        generate_fn: runs the LLM on the prompt and returns the raw completion.
        config: agent configuration; a fresh ``LLMAgentConfig()`` is created
            when omitted.
        max_steps: optional cap on recorded steps; falls back to
            ``env.max_steps`` when falsy.

    Returns:
        An ``EpisodeRecord`` with per-step records and the episode summary
        read from ``env.state`` and ``env.hidden_truth()``.
    """
    # Build the default per call: a dataclass instance in the signature
    # default would be created once at import time and shared by every
    # caller (mutable-default-argument pitfall).
    if config is None:
        config = LLMAgentConfig()

    obs = env.reset(seed=seed, scenario=scenario, difficulty=difficulty)
    steps: List[StepRecord] = []
    total_reward = 0.0

    # `or` (not an explicit None check) intentionally also maps
    # max_steps=0 to the environment's own limit.
    cap = max_steps or env.max_steps
    while not obs.done and len(steps) < cap:
        chat = build_chat(obs, config)
        prompt = prompt_fn(chat)
        completion = generate_fn(prompt, config)

        action = parse_action(completion)
        parsed_ok = action is not None
        if action is None:
            # Unparseable completion: substitute a safe fallback action so
            # the episode continues; parsed_ok=False flags this step.
            action = safe_default_action(obs)

        next_obs = env.step(action)
        reward = float(next_obs.reward or 0.0)
        total_reward += reward

        steps.append(
            StepRecord(
                step=obs.step_index,
                prompt=prompt,
                completion=completion,
                action=action.model_dump(),
                parsed_ok=parsed_ok,
                reward=reward,
                done=next_obs.done,
                rule_violations=list(next_obs.rule_violations),
                # Summarise the observation the action was chosen FROM,
                # not the resulting one.
                observation_summary=_summarise_obs(obs),
            )
        )
        obs = next_obs

    return EpisodeRecord(
        seed=seed,
        # Report what the environment actually ran (reset may resolve a
        # None scenario/difficulty to a concrete choice).
        scenario=env.state.scenario_name,
        difficulty=env.state.difficulty,
        truth=env.hidden_truth(),
        total_reward=total_reward,
        cumulative_reward=float(env.state.cumulative_reward),
        terminal_reward=env.state.terminal_reward,
        discovered=env.state.discovered,
        correct_mass=env.state.correct_mass,
        correct_channel=env.state.correct_channel,
        correct_spin=env.state.correct_spin,
        steps=steps,
    )


def save_episodes_jsonl(episodes: List[EpisodeRecord], path: str) -> None:
    """Write episode records to *path*, one JSON object per line.

    ``default=str`` stringifies any value ``json`` cannot serialise natively
    (e.g. enums inside action dicts), so such values round-trip as strings.
    Encoding is pinned to UTF-8: prompts/completions may contain non-ASCII
    text, and the platform default encoding is not reliable.
    """
    with open(path, "w", encoding="utf-8") as f:
        for ep in episodes:
            f.write(json.dumps(asdict(ep), default=str) + "\n")


def load_episodes_jsonl(path: str) -> List[Dict[str, Any]]:
    """Read a JSONL file of episode records back as plain dicts.

    Blank lines are skipped; each remaining line must be a JSON object.
    """
    with open(path) as f:
        return [json.loads(raw) for raw in (line.strip() for line in f) if raw]


__all__ = [
    "EpisodeRecord",
    "StepRecord",
    "collect_episode",
    "save_episodes_jsonl",
    "load_episodes_jsonl",
]