"""GRPO trainer for the Drift Generator. Uses R-Zero's composite Challenger reward: max(0, uncertainty - repetition). Each prompt is sampled `group_size` times; for every breakage we run K independent Repair Agent rollouts to estimate p_hat (success rate). Heavy and brittle on a single GPU — keep group_size small for hackathon budgets. Provides a `--dry_run` mode that just exercises the reward function without any LLM calls. """ from __future__ import annotations import argparse import json import os import random from pathlib import Path from typing import Optional from forgeenv.env.actions import BreakageAction, ForgeAction, RepairAction from forgeenv.env.forge_environment import ForgeEnvironment from forgeenv.roles.drift_generator import ( BaselineDriftGenerator, parse_drift_output, ) from forgeenv.roles.prompts import ( DRIFT_GENERATOR_SYSTEM_PROMPT, render_drift_generator_prompt, ) from forgeenv.training.rollout import ( GenerateFn, rollout_one_episode, baseline_oracle_repair_generate, _baseline_repair_generate, ) from forgeenv.training.reward_functions import ( compute_drift_gen_reward, compute_uncertainty_reward, compute_repetition_penalty, ) def evaluate_drift_batch( env_factory, breakages: list[dict], repair_generate: GenerateFn, n_repair_attempts_per_breakage: int = 4, seed: int = 0, ) -> list[float]: """For each breakage spec, run K Repair-Agent attempts and compute R-Zero's composite Challenger reward. Returns one reward per breakage.""" breakage_texts = [ f"{b.get('primitive_type','')}::{json.dumps(b.get('params', {}), sort_keys=True)}" for b in breakages ] rewards: list[float] = [] for idx, breakage_spec in enumerate(breakages): successes: list[bool] = [] for k in range(n_repair_attempts_per_breakage): env = env_factory() env.reset(seed=seed + idx * 100 + k, difficulty="easy") try: obs2 = env.step( ForgeAction( breakage=BreakageAction( primitive_type=breakage_spec.get("primitive_type", ""), params=breakage_spec.get("params", {}) or {}, ) ) ) except Exception: successes.append(False) continue from forgeenv.roles.repair_agent import extract_diff from forgeenv.roles.prompts import render_repair_agent_prompt user = render_repair_agent_prompt( broken_script=obs2.script_content, error_trace=obs2.error_trace or "", library_versions=obs2.library_versions, target_category=obs2.target_category, ) raw = repair_generate("", user) diff = extract_diff(raw or "") obs3 = env.step(ForgeAction(repair=RepairAction(unified_diff=diff))) successes.append( bool(obs3.held_out_breakdown.get("executed_cleanly", 0.0) > 0.5) ) reward = compute_drift_gen_reward( breakage_text=breakage_texts[idx], repair_successes=successes, batch_breakages=breakage_texts, ) rewards.append(reward) return rewards def run_drift_grpo_dry_run( output_dir: str, total_episodes: int = 100, group_size: int = 4, seed: int = 0 ) -> None: """Pure-CPU exercise of the drift-side reward loop. 


def run_drift_grpo_dry_run(
    output_dir: str, total_episodes: int = 100, group_size: int = 4, seed: int = 0
) -> None:
    """Pure-CPU exercise of the drift-side reward loop. Writes per-step rewards."""
    rng = random.Random(seed)
    drift_gen = BaselineDriftGenerator(seed=seed)
    rewards_log: list[dict] = []
    for ep in range(total_episodes):
        env = ForgeEnvironment(seed=seed + ep)
        env.reset(difficulty="easy")
        target_category = env.state["target_category"]
        script = env._original_script  # noqa: SLF001  (read-only convenience)

        # Sample group_size candidate breakages.
        candidates = [
            drift_gen.propose(target_category=target_category, script=script)
            for _ in range(group_size)
        ]

        # Use the oracle as the repairer so we get a meaningful uncertainty signal:
        # an "unbreakable" breakage gives p_hat = 1, an "always-fails" one gives 0.
        rewards = evaluate_drift_batch(
            env_factory=lambda: ForgeEnvironment(seed=rng.randint(0, 1_000_000)),
            breakages=candidates,
            repair_generate=baseline_oracle_repair_generate(env),
            n_repair_attempts_per_breakage=2,
            seed=seed + ep,
        )
        rewards_log.append(
            {"episode": ep, "rewards": rewards, "candidates": candidates}
        )

        if ep % max(1, total_episodes // 10) == 0:
            mean_r = sum(rewards) / max(1, len(rewards))
            print(f"[drift dry-run] ep={ep} mean_reward={mean_r:.3f}")

    Path(output_dir).mkdir(parents=True, exist_ok=True)
    (Path(output_dir) / "drift_dry_run.json").write_text(
        json.dumps(rewards_log, indent=2)
    )
    print(f"[drift dry-run] wrote {len(rewards_log)} episodes to {output_dir}")


def _parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--output_dir", required=True)
    parser.add_argument("--total_episodes", type=int, default=100)
    parser.add_argument("--group_size", type=int, default=4)
    parser.add_argument("--seed", type=int, default=0)
    # Dry run is on by default; pass --no-dry_run to request the full GRPO path
    # (not implemented in this file).
    parser.add_argument(
        "--dry_run", action=argparse.BooleanOptionalAction, default=True
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = _parse_args()
    if args.dry_run:
        run_drift_grpo_dry_run(
            output_dir=args.output_dir,
            total_episodes=args.total_episodes,
            group_size=args.group_size,
            seed=args.seed,
        )
    else:
        raise NotImplementedError(
            "Full LLM Drift GRPO requires both roles loaded; use the Colab notebook"
        )
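

# Example invocation (the filename below is illustrative; use whatever path this
# trainer has in the repo). Only the dry-run path is implemented here:
#
#   python train_drift_grpo.py \
#       --output_dir runs/drift_dry_run \
#       --total_episodes 50 \
#       --group_size 4 \
#       --dry_run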