"""Rollout function: connects an LLM to ForgeEnvironment for a full episode. This is the function the GRPO trainer calls to convert a prompt into a trajectory + reward. It runs both phases of an episode (drift + repair) by asking the policy twice with role-switched prompts. """ from __future__ import annotations from dataclasses import dataclass, field from typing import Any, Callable, Optional from forgeenv.env.actions import BreakageAction, ForgeAction, RepairAction from forgeenv.env.forge_environment import ForgeEnvironment from forgeenv.roles.drift_generator import ( BaselineDriftGenerator, parse_drift_output, ) from forgeenv.roles.prompts import ( DRIFT_GENERATOR_SYSTEM_PROMPT, REPAIR_AGENT_SYSTEM_PROMPT, render_drift_generator_prompt, render_repair_agent_prompt, ) from forgeenv.roles.repair_agent import BaselineRepairAgent, extract_diff # Generation function signature: takes a (system, user) prompt pair and # returns the assistant's completion. We keep this abstract so we can plug # in TRL's batched generator, vLLM, or our deterministic baseline. GenerateFn = Callable[[str, str], str] @dataclass class RolloutResult: task_id: str primitive_type: str drift_prompt: str drift_completion: str repair_prompt: str repair_completion: str visible_reward: float visible_breakdown: dict[str, float] held_out_breakdown: dict[str, float] success: bool error_trace: str = "" info: dict[str, Any] = field(default_factory=dict) def _baseline_drift_generate(env: ForgeEnvironment) -> GenerateFn: """Wrap our deterministic Drift Generator into a GenerateFn.""" gen = BaselineDriftGenerator(seed=0) def fn(system: str, user: str) -> str: target = "RenameApiCall" for line in user.splitlines(): if line.lower().startswith("target category:"): target = line.split(":", 1)[1].strip() break # Try to extract the script body so we can pick a primitive that # actually mutates it. script_block = "" if "```python" in user: script_block = user.split("```python", 1)[1].split("```", 1)[0] spec = gen.propose(target_category=target, script=script_block) import json return json.dumps(spec) return fn def _baseline_repair_generate() -> GenerateFn: """Wrap our deterministic Repair Agent into a GenerateFn. The baseline cheats by recovering the original script from the user prompt is impossible (we don't pass it). Instead, when called as a baseline it just returns an empty diff. Use BaselineDriftGenerator-paired tests (which read env.state) when you want the oracle path. """ def fn(system: str, user: str) -> str: return "" # baseline = no-op (intentional negative baseline) return fn def rollout_one_episode( env: ForgeEnvironment, drift_generate: Optional[GenerateFn] = None, repair_generate: Optional[GenerateFn] = None, difficulty: str = "easy", ) -> RolloutResult: """Run a single 2-step episode end-to-end and capture all signals.""" drift_generate = drift_generate or _baseline_drift_generate(env) repair_generate = repair_generate or _baseline_repair_generate() obs = env.reset(difficulty=difficulty) assert obs.current_phase == "drift_gen" # ---------- Phase 1: Drift Generator ---------- drift_prompt = render_drift_generator_prompt( script=obs.script_content, target_category=obs.target_category, library_versions=obs.library_versions, ) drift_raw = drift_generate(DRIFT_GENERATOR_SYSTEM_PROMPT, drift_prompt) spec = parse_drift_output(drift_raw) if not spec: spec = {"primitive_type": "RenameApiCall", "params": {}} breakage_action = ForgeAction( breakage=BreakageAction( primitive_type=spec.get("primitive_type", "RenameApiCall"), params=spec.get("params", {}) or {}, ) ) obs2 = env.step(breakage_action) # ---------- Phase 2: Repair Agent ---------- repair_prompt = render_repair_agent_prompt( broken_script=obs2.script_content, error_trace=obs2.error_trace or "", library_versions=obs2.library_versions, target_category=obs2.target_category, ) repair_raw = repair_generate(REPAIR_AGENT_SYSTEM_PROMPT, repair_prompt) diff = extract_diff(repair_raw) if repair_raw else "" repair_action = ForgeAction(repair=RepairAction(unified_diff=diff)) obs3 = env.step(repair_action) return RolloutResult( task_id=obs.task_id, primitive_type=spec.get("primitive_type", "RenameApiCall"), drift_prompt=drift_prompt, drift_completion=drift_raw, repair_prompt=repair_prompt, repair_completion=repair_raw, visible_reward=float(obs3.reward or 0.0), visible_breakdown=dict(obs3.reward_breakdown), held_out_breakdown=dict(obs3.held_out_breakdown), success=bool(obs3.held_out_breakdown.get("executed_cleanly", 0.0) > 0.5), error_trace=obs3.error_trace or "", info=dict(obs3.info), ) def baseline_oracle_repair_generate(env: ForgeEnvironment) -> GenerateFn: """An "oracle" repair generator that reads the original script from `env.state` and emits a perfect diff. Useful for sanity-checking the end-to-end loop and as the upper-bound baseline in plots. """ repair_agent = BaselineRepairAgent() def fn(system: str, user: str) -> str: # Pull the original script out of env state via the task sampler task_id = env.state.get("task_id") if task_id is None: return "" task = env.task_sampler.get_by_id(task_id) if task is None: return "" # The current script in env._broken_script is what the user sees. broken = env._broken_script # noqa: SLF001 — internal but oracle-only return repair_agent.repair( broken, breakage_spec=env._breakage_spec, # noqa: SLF001 original_script=task.script_content, ) return fn