"""Rollout function: connects an LLM to ForgeEnvironment for a full episode.
This is the function the GRPO trainer calls to convert a prompt into a
trajectory + reward. It runs both phases of an episode (drift + repair) by
asking the policy twice with role-switched prompts.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Callable, Optional
from forgeenv.env.actions import BreakageAction, ForgeAction, RepairAction
from forgeenv.env.forge_environment import ForgeEnvironment
from forgeenv.roles.drift_generator import (
BaselineDriftGenerator,
parse_drift_output,
)
from forgeenv.roles.prompts import (
DRIFT_GENERATOR_SYSTEM_PROMPT,
REPAIR_AGENT_SYSTEM_PROMPT,
render_drift_generator_prompt,
render_repair_agent_prompt,
)
from forgeenv.roles.repair_agent import BaselineRepairAgent, extract_diff
# Generation function signature: takes a (system, user) prompt pair and
# returns the assistant's completion as a string. It is kept abstract so
# that TRL's batched generator, vLLM, or the deterministic baselines in
# this module can be plugged in interchangeably.
GenerateFn = Callable[[str, str], str]
@dataclass
class RolloutResult:
    """Signals captured from one drift + repair episode.

    Holds the prompt/completion pair of each phase, the visible reward and
    its per-component breakdown, the held-out breakdown, and the outcome.
    Field order defines the positional constructor, so keep it stable.
    """

    # Which task was run and which breakage primitive was applied to it.
    task_id: str
    primitive_type: str
    # Phase 1 (Drift Generator): rendered prompt and raw model output.
    drift_prompt: str
    drift_completion: str
    # Phase 2 (Repair Agent): rendered prompt and raw model output.
    repair_prompt: str
    repair_completion: str
    # Reward signals reported by the environment after the repair step.
    visible_reward: float
    visible_breakdown: dict[str, float]
    held_out_breakdown: dict[str, float]
    success: bool
    # Optional extras; `info` gets a fresh dict per instance.
    error_trace: str = field(default="")
    info: dict[str, Any] = field(default_factory=dict)
def _baseline_drift_generate(env: ForgeEnvironment) -> GenerateFn:
    """Wrap the deterministic BaselineDriftGenerator (seed 0) as a GenerateFn.

    The returned callable ignores ``system`` and recovers its inputs from the
    rendered user prompt:

    * the value of the first ``Target category:`` line (case-insensitive,
      leading whitespace ignored). It falls back to ``"RenameApiCall"`` when
      that line is absent or its value is blank.
    * the body of the first ```python fenced block, so that the generator
      can pick a primitive that actually mutates that script. This is an
      empty string when there is no such block.

    The proposed spec is serialized with ``json.dumps``. It therefore flows
    through the same parsing path as a model completion would.

    Args:
        env: Not used by this baseline.

    Returns:
        A ``(system, user) -> str`` callable.
    """
    # Imported once per wrapper construction instead of on every call.
    import json

    default_target = "RenameApiCall"
    gen = BaselineDriftGenerator(seed=0)

    def fn(system: str, user: str) -> str:
        target = default_target
        for line in user.splitlines():
            stripped = line.strip()
            if stripped.lower().startswith("target category:"):
                # A blank value keeps the default instead of passing "".
                target = stripped.split(":", 1)[1].strip() or default_target
                break
        # Try to extract the script body so we can pick a primitive that
        # actually mutates it.
        script_block = ""
        if "```python" in user:
            script_block = user.split("```python", 1)[1].split("```", 1)[0]
        spec = gen.propose(target_category=target, script=script_block)
        return json.dumps(spec)

    return fn
def _baseline_repair_generate() -> GenerateFn:
"""Wrap our deterministic Repair Agent into a GenerateFn.
The baseline cheats by recovering the original script from the user
prompt is impossible (we don't pass it). Instead, when called as a
baseline it just returns an empty diff. Use BaselineDriftGenerator-paired
tests (which read env.state) when you want the oracle path.
"""
def fn(system: str, user: str) -> str:
return "" # baseline = no-op (intentional negative baseline)
return fn
def rollout_one_episode(
    env: ForgeEnvironment,
    drift_generate: Optional[GenerateFn] = None,
    repair_generate: Optional[GenerateFn] = None,
    difficulty: str = "easy",
) -> RolloutResult:
    """Run a single two-step episode (drift, then repair) and capture all signals.

    Args:
        env: Environment to drive. It is reset at the start of the episode.
        drift_generate: Policy for the drift-generation phase. Defaults to
            the deterministic baseline drift generator.
        repair_generate: Policy for the repair phase. Defaults to the no-op
            baseline, which always submits an empty diff.
        difficulty: Forwarded to ``env.reset``.

    Returns:
        A RolloutResult containing both prompts and completions, the visible
        reward and its breakdown, and the held-out breakdown. It also carries
        a success flag, which is True when the held-out
        ``executed_cleanly`` score exceeds 0.5.

    Raises:
        RuntimeError: If ``env.reset`` does not leave the environment in the
            ``"drift_gen"`` phase.
    """
    default_primitive = "RenameApiCall"
    drift_generate = drift_generate or _baseline_drift_generate(env)
    repair_generate = repair_generate or _baseline_repair_generate()

    obs = env.reset(difficulty=difficulty)
    # An explicit check rather than `assert`, which `python -O` strips.
    if obs.current_phase != "drift_gen":
        raise RuntimeError(
            "rollout_one_episode: expected phase 'drift_gen' after reset, "
            f"got {obs.current_phase!r}"
        )

    # ---------- Phase 1: Drift Generator ----------
    drift_prompt = render_drift_generator_prompt(
        script=obs.script_content,
        target_category=obs.target_category,
        library_versions=obs.library_versions,
    )
    drift_raw = drift_generate(DRIFT_GENERATOR_SYSTEM_PROMPT, drift_prompt)
    # Unparseable output behaves like an empty spec: the default primitive
    # with no params.
    spec = parse_drift_output(drift_raw) or {}
    # Resolve the primitive once so that the submitted action and the
    # reported result always agree. Using `or` also covers explicit None/"".
    primitive_type = spec.get("primitive_type") or default_primitive
    params = spec.get("params") or {}
    obs2 = env.step(
        ForgeAction(
            breakage=BreakageAction(primitive_type=primitive_type, params=params)
        )
    )

    # ---------- Phase 2: Repair Agent ----------
    repair_prompt = render_repair_agent_prompt(
        broken_script=obs2.script_content,
        error_trace=obs2.error_trace or "",
        library_versions=obs2.library_versions,
        target_category=obs2.target_category,
    )
    repair_raw = repair_generate(REPAIR_AGENT_SYSTEM_PROMPT, repair_prompt)
    # An empty completion maps directly to an empty diff.
    diff = extract_diff(repair_raw) if repair_raw else ""
    obs3 = env.step(ForgeAction(repair=RepairAction(unified_diff=diff)))

    # Copy the breakdowns into fresh dicts for the result. A None breakdown
    # is treated as empty.
    held_out = dict(obs3.held_out_breakdown or {})
    return RolloutResult(
        task_id=obs.task_id,
        primitive_type=primitive_type,
        drift_prompt=drift_prompt,
        drift_completion=drift_raw,
        repair_prompt=repair_prompt,
        repair_completion=repair_raw,
        visible_reward=float(obs3.reward or 0.0),
        visible_breakdown=dict(obs3.reward_breakdown or {}),
        held_out_breakdown=held_out,
        success=bool((held_out.get("executed_cleanly") or 0.0) > 0.5),
        error_trace=obs3.error_trace or "",
        info=dict(obs3.info or {}),
    )
def baseline_oracle_repair_generate(env: ForgeEnvironment) -> GenerateFn:
    """Build an "oracle" repair GenerateFn that reads environment state.

    The returned callable ignores both prompts. Instead, it does the
    following:

    * looks up the current task via ``env.state.get("task_id")`` and
      ``env.task_sampler.get_by_id``;
    * hands the environment's broken script, its breakage spec, and the
      task's original script to ``BaselineRepairAgent.repair``.

    It returns ``""`` when no task id is set or the task cannot be found.
    The repair is intended to be a perfect diff, which makes this useful for
    sanity-checking the end-to-end loop and as the upper-bound baseline in
    plots.
    """
    agent = BaselineRepairAgent()

    def oracle(system: str, user: str) -> str:
        current_task_id = env.state.get("task_id")
        source_task = (
            None
            if current_task_id is None
            else env.task_sampler.get_by_id(current_task_id)
        )
        if source_task is None:
            return ""
        # Private env attributes are read deliberately: only the oracle may
        # see the broken script and breakage spec outside the prompt.
        return agent.repair(
            env._broken_script,  # noqa: SLF001
            breakage_spec=env._breakage_spec,  # noqa: SLF001
            original_script=source_task.script_content,
        )

    return oracle
|