| """Rollout function: connects an LLM to ForgeEnvironment for a full episode. |
| |
| This is the function the GRPO trainer calls to convert a prompt into a |
| trajectory + reward. It runs both phases of an episode (drift + repair) by |
| asking the policy twice with role-switched prompts. |
| """ |
| from __future__ import annotations |
|
|
| from dataclasses import dataclass, field |
| from typing import Any, Callable, Optional |
|
|
| from forgeenv.env.actions import BreakageAction, ForgeAction, RepairAction |
| from forgeenv.env.forge_environment import ForgeEnvironment |
| from forgeenv.roles.drift_generator import ( |
| BaselineDriftGenerator, |
| parse_drift_output, |
| ) |
| from forgeenv.roles.prompts import ( |
| DRIFT_GENERATOR_SYSTEM_PROMPT, |
| REPAIR_AGENT_SYSTEM_PROMPT, |
| render_drift_generator_prompt, |
| render_repair_agent_prompt, |
| ) |
| from forgeenv.roles.repair_agent import BaselineRepairAgent, extract_diff |
|
|
|
|
| |
| |
| |
# Text-generation callback used for both roles:
# (system_prompt, user_prompt) -> raw completion text.
GenerateFn = Callable[[str, str], str]
|
|
|
|
@dataclass()
class RolloutResult:
    """Signals captured from one drift + repair episode.

    As populated by ``rollout_one_episode``: the prompt/completion pair for
    each phase, the visible reward with its per-component breakdown, the
    held-out breakdown, a success flag, and the error trace and ``info`` dict
    of the final observation.
    """

    # Episode identity.
    task_id: str
    primitive_type: str  # breakage primitive applied in the drift phase
    # Drift phase: rendered user prompt and raw model completion.
    drift_prompt: str
    drift_completion: str
    # Repair phase: rendered user prompt and raw model completion.
    repair_prompt: str
    repair_completion: str
    # Reward signals.
    visible_reward: float
    visible_breakdown: dict[str, float]
    held_out_breakdown: dict[str, float]
    success: bool
    # Optional extras; empty unless the episode produced them.
    error_trace: str = field(default="")
    info: dict[str, Any] = field(default_factory=dict)
|
|
|
|
def _baseline_drift_generate(env: ForgeEnvironment) -> GenerateFn:
    """Wrap the deterministic ``BaselineDriftGenerator`` as a ``GenerateFn``.

    The returned callable ignores the system prompt and recovers its inputs
    from the rendered user prompt:

    * the target category, from the first line of the form
      ``Target category: <name>`` (case-insensitive, surrounding whitespace
      ignored), falling back to ``"RenameApiCall"`` when no such line exists
      or its value is empty;
    * the script, from the first fenced ```` ```python ```` block, or ``""``
      when there is none.

    Args:
        env: Unused; accepted so callers can pass the environment uniformly.

    Returns:
        A callable ``(system, user) -> str`` that returns the generator's
        proposed breakage spec serialized with ``json.dumps``.
    """
    import json  # resolved once per wrapper rather than on every call

    gen = BaselineDriftGenerator(seed=0)

    def fn(system: str, user: str) -> str:
        target = "RenameApiCall"
        for line in user.splitlines():
            # Strip first so an indented header line is still recognized.
            stripped = line.strip()
            if stripped.lower().startswith("target category:"):
                value = stripped.split(":", 1)[1].strip()
                # An empty value keeps the default instead of passing "".
                if value:
                    target = value
                break

        script_block = ""
        if "```python" in user:
            script_block = user.split("```python", 1)[1].split("```", 1)[0]

        spec = gen.propose(target_category=target, script=script_block)
        return json.dumps(spec)

    return fn
|
|
|
|
| def _baseline_repair_generate() -> GenerateFn: |
| """Wrap our deterministic Repair Agent into a GenerateFn. |
| |
| The baseline cheats by recovering the original script from the user |
| prompt is impossible (we don't pass it). Instead, when called as a |
| baseline it just returns an empty diff. Use BaselineDriftGenerator-paired |
| tests (which read env.state) when you want the oracle path. |
| """ |
|
|
| def fn(system: str, user: str) -> str: |
| return "" |
|
|
| return fn |
|
|
|
|
def rollout_one_episode(
    env: ForgeEnvironment,
    drift_generate: Optional[GenerateFn] = None,
    repair_generate: Optional[GenerateFn] = None,
    difficulty: str = "easy",
) -> RolloutResult:
    """Run a single two-step episode (drift, then repair) and capture all signals.

    Phase 1 renders a Drift Generator prompt from the reset observation, asks
    ``drift_generate`` for a breakage spec, and steps the environment with it.
    Phase 2 renders a Repair Agent prompt from the broken script and its error
    trace, asks ``repair_generate`` for a completion, extracts a unified diff
    from it, and steps the environment again.

    Args:
        env: Environment to run the episode in; it is reset first.
        drift_generate: ``(system, user) -> completion`` for the drift phase.
            Defaults to the deterministic baseline drift generator.
        repair_generate: ``(system, user) -> completion`` for the repair
            phase. Defaults to the no-op baseline (empty completion).
        difficulty: Forwarded to ``env.reset``.

    Returns:
        A ``RolloutResult`` holding both prompts and completions, the visible
        reward and breakdown, the held-out breakdown, and ``success``. The
        ``success`` flag is true when the held-out ``executed_cleanly`` score
        exceeds 0.5.

    Raises:
        RuntimeError: If ``env.reset`` does not leave the environment in the
            ``"drift_gen"`` phase.
    """
    if drift_generate is None:
        drift_generate = _baseline_drift_generate(env)
    if repair_generate is None:
        repair_generate = _baseline_repair_generate()

    obs = env.reset(difficulty=difficulty)
    # Explicit check instead of ``assert`` so it still fires under ``python -O``.
    if obs.current_phase != "drift_gen":
        raise RuntimeError(
            f"expected phase 'drift_gen' after reset, got {obs.current_phase!r}"
        )

    # ---- Phase 1: drift generation ------------------------------------
    drift_prompt = render_drift_generator_prompt(
        script=obs.script_content,
        target_category=obs.target_category,
        library_versions=obs.library_versions,
    )
    # Normalize a None completion so the parser and RolloutResult get a str.
    drift_raw = drift_generate(DRIFT_GENERATOR_SYSTEM_PROMPT, drift_prompt) or ""
    # Fall back to a default primitive when the completion is unparseable or
    # names no primitive, so a malformed completion still yields an episode.
    spec = parse_drift_output(drift_raw) or {}
    primitive_type = spec.get("primitive_type") or "RenameApiCall"
    params = spec.get("params") or {}

    obs2 = env.step(
        ForgeAction(
            breakage=BreakageAction(primitive_type=primitive_type, params=params)
        )
    )

    # ---- Phase 2: repair ----------------------------------------------
    repair_prompt = render_repair_agent_prompt(
        broken_script=obs2.script_content,
        error_trace=obs2.error_trace or "",
        library_versions=obs2.library_versions,
        target_category=obs2.target_category,
    )
    repair_raw = repair_generate(REPAIR_AGENT_SYSTEM_PROMPT, repair_prompt) or ""
    diff = extract_diff(repair_raw) if repair_raw else ""

    obs3 = env.step(ForgeAction(repair=RepairAction(unified_diff=diff)))

    # NOTE(review): guards against None breakdown/info fields on the final
    # observation; the observation model is not visible here — confirm.
    visible_breakdown = dict(obs3.reward_breakdown or {})
    held_out_breakdown = dict(obs3.held_out_breakdown or {})

    return RolloutResult(
        task_id=obs.task_id,
        primitive_type=primitive_type,
        drift_prompt=drift_prompt,
        drift_completion=drift_raw,
        repair_prompt=repair_prompt,
        repair_completion=repair_raw,
        visible_reward=float(obs3.reward or 0.0),
        visible_breakdown=visible_breakdown,
        held_out_breakdown=held_out_breakdown,
        success=bool(held_out_breakdown.get("executed_cleanly", 0.0) > 0.5),
        error_trace=obs3.error_trace or "",
        info=dict(obs3.info or {}),
    )
|
|
|
|
def baseline_oracle_repair_generate(env: ForgeEnvironment) -> GenerateFn:
    """Build an "oracle" repair generator that cheats by inspecting ``env``.

    The returned callable ignores both prompts. It works in three steps:

    1. Look up the current task via ``env.state["task_id"]`` and
       ``env.task_sampler.get_by_id``.
    2. Take the environment's current broken script and recorded breakage
       spec.
    3. Have ``BaselineRepairAgent`` repair the broken script against the
       task's original ``script_content``.

    Useful for sanity-checking the end-to-end loop and as the upper-bound
    baseline in plots.

    It returns ``""`` in three cases: there is no current task id, the task
    cannot be found, or no broken script has been recorded yet.

    NOTE: relies on the private attributes ``env._broken_script`` and
    ``env._breakage_spec``.

    Args:
        env: Environment whose episode state is read on every call.

    Returns:
        A callable ``(system, user) -> str`` returning the agent's repair
        (a diff), or ``""`` when no repair can be produced.
    """
    repair_agent = BaselineRepairAgent()

    def fn(system: str, user: str) -> str:
        task_id = env.state.get("task_id")
        if task_id is None:
            return ""
        task = env.task_sampler.get_by_id(task_id)
        if task is None:
            return ""

        broken = env._broken_script
        if broken is None:
            # No breakage applied yet, so there is nothing to repair.
            return ""
        return repair_agent.repair(
            broken,
            breakage_spec=env._breakage_spec,
            original_script=task.script_content,
        )

    return fn
|
|