# akhiilll's picture
# forgeenv source snapshot for training job
# a15535e verified
"""Rollout function: connects an LLM to ForgeEnvironment for a full episode.
This is the function the GRPO trainer calls to convert a prompt into a
trajectory + reward. It runs both phases of an episode (drift + repair) by
asking the policy twice with role-switched prompts.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Callable, Optional
from forgeenv.env.actions import BreakageAction, ForgeAction, RepairAction
from forgeenv.env.forge_environment import ForgeEnvironment
from forgeenv.roles.drift_generator import (
BaselineDriftGenerator,
parse_drift_output,
)
from forgeenv.roles.prompts import (
DRIFT_GENERATOR_SYSTEM_PROMPT,
REPAIR_AGENT_SYSTEM_PROMPT,
render_drift_generator_prompt,
render_repair_agent_prompt,
)
from forgeenv.roles.repair_agent import BaselineRepairAgent, extract_diff
# Type alias for a generation function: it is called with a
# (system prompt, user prompt) pair and returns the assistant's completion
# as a string. The signature is kept abstract so that TRL's batched
# generator, vLLM, or our deterministic baseline can be plugged in.
GenerateFn = Callable[[str, str], str]
@dataclass
class RolloutResult:
    """Everything captured from one drift + repair episode.

    Instances are built by ``rollout_one_episode``; the comments below
    describe how that function fills each field.
    """

    # --- episode identity ---
    task_id: str  # task id reported by the observation returned from reset
    primitive_type: str  # primitive named in the parsed drift spec

    # --- prompts and raw completions for both phases ---
    drift_prompt: str
    drift_completion: str
    repair_prompt: str
    repair_completion: str

    # --- reward signals taken from the final observation ---
    visible_reward: float
    visible_breakdown: dict[str, float]
    held_out_breakdown: dict[str, float]
    success: bool  # held-out "executed_cleanly" component above 0.5

    # --- optional diagnostics ---
    error_trace: str = ""
    info: dict[str, Any] = field(default_factory=dict)
def _baseline_drift_generate(env: ForgeEnvironment) -> GenerateFn:
    """Adapt the deterministic ``BaselineDriftGenerator`` to the ``GenerateFn`` shape.

    ``env`` is accepted but not consulted here: the returned callable
    recovers everything it needs (target category and script body) from
    the text of the user prompt.

    Returns:
        A callable ``fn(system, user)`` that returns the proposed breakage
        spec serialised as a JSON string.
    """
    import json

    proposer = BaselineDriftGenerator(seed=0)

    def fn(system: str, user: str) -> str:
        # The first line beginning with "target category:" (matched
        # case-insensitively) decides the category; without such a line
        # we fall back to RenameApiCall.
        category = next(
            (
                row.split(":", 1)[1].strip()
                for row in user.splitlines()
                if row.lower().startswith("target category:")
            ),
            "RenameApiCall",
        )

        # Extract the script body from the first ```python fence so the
        # generator can pick a primitive that actually mutates it.
        # ``partition`` yields empty strings when the separator is absent,
        # so a prompt without a fence produces an empty script.
        _, _, after_fence = user.partition("```python")
        script_body, _, _ = after_fence.partition("```")

        proposal = proposer.propose(target_category=category, script=script_body)
        return json.dumps(proposal)

    return fn
def _baseline_repair_generate() -> GenerateFn:
    """Build the no-op Repair Agent baseline as a ``GenerateFn``.

    This baseline cannot "cheat" by recovering the original script from
    the user prompt, because the original script is never passed in.
    It therefore ignores both prompts and returns an empty completion,
    i.e. an empty diff. Use the BaselineDriftGenerator-paired tests
    (which read ``env.state``) when you want the oracle path instead.
    """

    def _empty_completion(system: str, user: str) -> str:
        # Intentional negative baseline: propose no repair at all.
        no_op = ""
        return no_op

    return _empty_completion
def rollout_one_episode(
    env: ForgeEnvironment,
    drift_generate: Optional[GenerateFn] = None,
    repair_generate: Optional[GenerateFn] = None,
    difficulty: str = "easy",
) -> RolloutResult:
    """Play one two-step episode (drift, then repair) and record every signal.

    Args:
        env: Environment to reset and step twice.
        drift_generate: Completion function for the drift phase; the
            deterministic baseline drift generator is used when omitted.
        repair_generate: Completion function for the repair phase; the
            no-op baseline is used when omitted.
        difficulty: Forwarded to ``env.reset``.

    Returns:
        A ``RolloutResult`` holding both prompts, both raw completions and
        the reward fields of the observation returned by the repair step.

    Raises:
        AssertionError: If the environment does not start in the
            ``"drift_gen"`` phase after reset.
    """
    if not drift_generate:
        drift_generate = _baseline_drift_generate(env)
    if not repair_generate:
        repair_generate = _baseline_repair_generate()

    start_obs = env.reset(difficulty=difficulty)
    assert start_obs.current_phase == "drift_gen"

    # ---------- Phase 1: Drift Generator ----------
    drift_prompt = render_drift_generator_prompt(
        script=start_obs.script_content,
        target_category=start_obs.target_category,
        library_versions=start_obs.library_versions,
    )
    drift_raw = drift_generate(DRIFT_GENERATOR_SYSTEM_PROMPT, drift_prompt)

    # An unparseable or empty drift completion degrades to a default spec.
    spec = parse_drift_output(drift_raw) or {
        "primitive_type": "RenameApiCall",
        "params": {},
    }
    chosen_primitive = spec.get("primitive_type", "RenameApiCall")
    breakage = BreakageAction(
        primitive_type=chosen_primitive,
        params=spec.get("params", {}) or {},
    )
    broken_obs = env.step(ForgeAction(breakage=breakage))

    # ---------- Phase 2: Repair Agent ----------
    repair_prompt = render_repair_agent_prompt(
        broken_script=broken_obs.script_content,
        error_trace=broken_obs.error_trace or "",
        library_versions=broken_obs.library_versions,
        target_category=broken_obs.target_category,
    )
    repair_raw = repair_generate(REPAIR_AGENT_SYSTEM_PROMPT, repair_prompt)

    # A falsy completion is submitted as an empty diff without parsing.
    if repair_raw:
        diff = extract_diff(repair_raw)
    else:
        diff = ""
    final_obs = env.step(ForgeAction(repair=RepairAction(unified_diff=diff)))

    return RolloutResult(
        task_id=start_obs.task_id,
        primitive_type=chosen_primitive,
        drift_prompt=drift_prompt,
        drift_completion=drift_raw,
        repair_prompt=repair_prompt,
        repair_completion=repair_raw,
        visible_reward=float(final_obs.reward or 0.0),
        visible_breakdown=dict(final_obs.reward_breakdown),
        held_out_breakdown=dict(final_obs.held_out_breakdown),
        # Success is defined by the held-out "executed_cleanly" component.
        success=bool(final_obs.held_out_breakdown.get("executed_cleanly", 0.0) > 0.5),
        error_trace=final_obs.error_trace or "",
        info=dict(final_obs.info),
    )
def baseline_oracle_repair_generate(env: ForgeEnvironment) -> GenerateFn:
    """Build an "oracle" repair generator that peeks at environment internals.

    The returned callable ignores its prompts. It looks up the current
    task id in ``env.state``, fetches that task from ``env.task_sampler``
    to obtain the original script, and hands the broken script, the
    breakage spec and the original script to ``BaselineRepairAgent.repair``.
    Intended for sanity-checking the end-to-end loop and as the
    upper-bound baseline in plots.

    Returns:
        A callable ``fn(system, user)`` that returns whatever the repair
        agent produces, or ``""`` when no task id or task is available.
    """
    oracle = BaselineRepairAgent()

    def fn(system: str, user: str) -> str:
        # Resolve the task behind the current episode; the sampler is only
        # queried when the environment state actually carries a task id.
        current_id = env.state.get("task_id")
        source_task = (
            None if current_id is None else env.task_sampler.get_by_id(current_id)
        )
        if source_task is None:
            return ""

        # env._broken_script is the script the repair prompt was built from.
        return oracle.repair(
            env._broken_script,  # noqa: SLF001 — internal but oracle-only
            breakage_spec=env._breakage_spec,  # noqa: SLF001
            original_script=source_task.script_content,
        )

    return fn