# forgeenv/training/grpo_drift.py
"""GRPO trainer for the Drift Generator.
Uses R-Zero's composite Challenger reward: max(0, uncertainty - repetition).
Each prompt is sampled `group_size` times; for every breakage we run K
independent Repair Agent rollouts to estimate p_hat (success rate).
Heavy and brittle on a single GPU — keep group_size small for hackathon
budgets. Provides a `--dry_run` mode that just exercises the reward function
without any LLM calls.
"""
from __future__ import annotations

import argparse
import json
import random
from collections.abc import Callable
from pathlib import Path

from forgeenv.env.actions import BreakageAction, ForgeAction, RepairAction
from forgeenv.env.forge_environment import ForgeEnvironment
from forgeenv.roles.drift_generator import BaselineDriftGenerator
from forgeenv.roles.prompts import render_repair_agent_prompt
from forgeenv.roles.repair_agent import extract_diff
from forgeenv.training.reward_functions import compute_drift_gen_reward
from forgeenv.training.rollout import GenerateFn, baseline_oracle_repair_generate


def evaluate_drift_batch(
    env_factory: Callable[[], ForgeEnvironment],
    breakages: list[dict],
    repair_generate: GenerateFn,
    n_repair_attempts_per_breakage: int = 4,
    seed: int = 0,
) -> list[float]:
"""For each breakage spec, run K Repair-Agent attempts and compute
R-Zero's composite Challenger reward. Returns one reward per breakage."""
    # Canonical text form of each breakage; compute_drift_gen_reward compares
    # these across the batch for the repetition penalty.
    breakage_texts = [
        f"{b.get('primitive_type', '')}::{json.dumps(b.get('params', {}), sort_keys=True)}"
        for b in breakages
    ]
rewards: list[float] = []
    for idx, breakage_spec in enumerate(breakages):
        successes: list[bool] = []
        for k in range(n_repair_attempts_per_breakage):
            env = env_factory()
            env.reset(seed=seed + idx * 100 + k, difficulty="easy")
            try:
                obs2 = env.step(
                    ForgeAction(
                        breakage=BreakageAction(
                            primitive_type=breakage_spec.get("primitive_type", ""),
                            params=breakage_spec.get("params", {}) or {},
                        )
                    )
                )
            except Exception:
                # A breakage the environment cannot apply counts as a failed
                # attempt rather than crashing the whole batch.
                successes.append(False)
                continue
            user = render_repair_agent_prompt(
                broken_script=obs2.script_content,
                error_trace=obs2.error_trace or "",
                library_versions=obs2.library_versions,
                target_category=obs2.target_category,
            )
            raw = repair_generate("", user)
            diff = extract_diff(raw or "")
            try:
                obs3 = env.step(ForgeAction(repair=RepairAction(unified_diff=diff)))
            except Exception:
                # A diff the environment cannot apply is a failed attempt.
                successes.append(False)
                continue
            successes.append(
                bool(obs3.held_out_breakdown.get("executed_cleanly", 0.0) > 0.5)
            )
        # Composite Challenger reward: max(0, uncertainty - repetition).
        reward = compute_drift_gen_reward(
            breakage_text=breakage_texts[idx],
            repair_successes=successes,
            batch_breakages=breakage_texts,
        )
        rewards.append(reward)
    return rewards
def run_drift_grpo_dry_run(
    output_dir: str, total_episodes: int = 100, group_size: int = 4, seed: int = 0
) -> None:
    """Pure-CPU exercise of the drift-side reward loop. Writes per-episode rewards."""
rng = random.Random(seed)
drift_gen = BaselineDriftGenerator(seed=seed)
rewards_log: list[dict] = []
for ep in range(total_episodes):
env = ForgeEnvironment(seed=seed + ep)
env.reset(difficulty="easy")
target_category = env.state["target_category"]
script = env._original_script # noqa: SLF001 — read-only convenience
# Sample group_size candidate breakages
candidates = [
drift_gen.propose(target_category=target_category, script=script)
for _ in range(group_size)
]
        # Use the oracle as the repair policy so we get a meaningful
        # uncertainty signal: an "unbreakable" breakage gives p_hat = 1,
        # an always-failing one gives p_hat = 0.
rewards = evaluate_drift_batch(
env_factory=lambda: ForgeEnvironment(seed=rng.randint(0, 1_000_000)),
breakages=candidates,
repair_generate=baseline_oracle_repair_generate(env),
n_repair_attempts_per_breakage=2,
seed=seed + ep,
)
rewards_log.append(
{"episode": ep, "rewards": rewards, "candidates": candidates}
)
if ep % max(1, total_episodes // 10) == 0:
mean_r = sum(rewards) / max(1, len(rewards))
print(f"[drift dry-run] ep={ep} mean_reward={mean_r:.3f}")
Path(output_dir).mkdir(parents=True, exist_ok=True)
(Path(output_dir) / "drift_dry_run.json").write_text(
json.dumps(rewards_log, indent=2)
)
print(f"[drift dry-run] wrote {len(rewards_log)} episodes to {output_dir}")
def _parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--output_dir", required=True)
    parser.add_argument("--total_episodes", type=int, default=100)
    parser.add_argument("--group_size", type=int, default=4)
    parser.add_argument("--seed", type=int, default=0)
    # `store_true` with default=True made --dry_run impossible to disable;
    # BooleanOptionalAction (Python 3.9+) keeps dry-run as the default while
    # adding --no-dry_run for when full training lands.
    parser.add_argument(
        "--dry_run", action=argparse.BooleanOptionalAction, default=True
    )
    return parser.parse_args()

if __name__ == "__main__":
args = _parse_args()
if args.dry_run:
run_drift_grpo_dry_run(
output_dir=args.output_dir,
total_episodes=args.total_episodes,
group_size=args.group_size,
seed=args.seed,
)
else:
raise NotImplementedError(
"Full LLM Drift GRPO requires both roles loaded — use the Colab notebook"
)