| """GRPO trainer for the Drift Generator. |
| |
| Uses R-Zero's composite Challenger reward: max(0, uncertainty - repetition). |
| Each prompt is sampled `group_size` times; for every breakage we run K |
| independent Repair Agent rollouts to estimate p_hat (success rate). |
| |
| Heavy and brittle on a single GPU — keep group_size small for hackathon |
| budgets. Provides a `--dry_run` mode that just exercises the reward function |
| without any LLM calls. |
| """ |
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import os |
| import random |
| from pathlib import Path |
| from typing import Optional |
|
|
| from forgeenv.env.actions import BreakageAction, ForgeAction, RepairAction |
| from forgeenv.env.forge_environment import ForgeEnvironment |
| from forgeenv.roles.drift_generator import ( |
| BaselineDriftGenerator, |
| parse_drift_output, |
| ) |
| from forgeenv.roles.prompts import ( |
| DRIFT_GENERATOR_SYSTEM_PROMPT, |
| render_drift_generator_prompt, |
| ) |
| from forgeenv.training.rollout import ( |
| GenerateFn, |
| rollout_one_episode, |
| baseline_oracle_repair_generate, |
| _baseline_repair_generate, |
| ) |
| from forgeenv.training.reward_functions import ( |
| compute_drift_gen_reward, |
| compute_uncertainty_reward, |
| compute_repetition_penalty, |
| ) |
|
|
|
|
def evaluate_drift_batch(
    env_factory,
    breakages: list[dict],
    repair_generate: GenerateFn,
    n_repair_attempts_per_breakage: int = 4,
    seed: int = 0,
) -> list[float]:
    """For each breakage spec, run K Repair-Agent attempts and compute
    R-Zero's composite Challenger reward.

    Args:
        env_factory: Zero-argument callable returning a fresh environment.
        breakages: Breakage specs; each dict may carry ``primitive_type``
            (str) and ``params`` (dict).
        repair_generate: ``GenerateFn`` mapping (system, user) prompts to a
            raw completion containing a unified diff.
        n_repair_attempts_per_breakage: K independent repair rollouts used
            to estimate the success rate p_hat for each breakage.
        seed: Base seed; every (breakage, attempt) pair gets a distinct
            deterministic offset.

    Returns:
        One composite reward per breakage, in the same order as `breakages`.
    """
    # Kept as function-local imports (presumably to avoid module import
    # cycles — TODO confirm), but hoisted out of the per-attempt loop so
    # they execute once per call instead of once per rollout.
    from forgeenv.roles.prompts import render_repair_agent_prompt
    from forgeenv.roles.repair_agent import extract_diff

    # Canonical text form of each breakage; params are serialized with
    # sorted keys so the repetition penalty compares specs deterministically.
    breakage_texts = [
        f"{b.get('primitive_type','')}::{json.dumps(b.get('params', {}), sort_keys=True)}"
        for b in breakages
    ]

    rewards: list[float] = []
    for idx, breakage_spec in enumerate(breakages):
        successes: list[bool] = []
        for k in range(n_repair_attempts_per_breakage):
            env = env_factory()
            env.reset(seed=seed + idx * 100 + k, difficulty="easy")
            try:
                obs2 = env.step(
                    ForgeAction(
                        breakage=BreakageAction(
                            primitive_type=breakage_spec.get("primitive_type", ""),
                            params=breakage_spec.get("params", {}) or {},
                        )
                    )
                )
            except Exception:
                # An unusable breakage spec counts as a failed repair attempt.
                successes.append(False)
                continue

            try:
                user = render_repair_agent_prompt(
                    broken_script=obs2.script_content,
                    error_trace=obs2.error_trace or "",
                    library_versions=obs2.library_versions,
                    target_category=obs2.target_category,
                )
                raw = repair_generate("", user)
                diff = extract_diff(raw or "")
                obs3 = env.step(ForgeAction(repair=RepairAction(unified_diff=diff)))
            except Exception:
                # A crashing repair rollout is a failed attempt; it must not
                # abort the evaluation of the remaining attempts/breakages
                # (mirrors the handling of the breakage step above).
                successes.append(False)
                continue

            successes.append(
                bool(obs3.held_out_breakdown.get("executed_cleanly", 0.0) > 0.5)
            )

        # Composite Challenger reward: max(0, uncertainty - repetition),
        # computed against the whole batch for the repetition term.
        reward = compute_drift_gen_reward(
            breakage_text=breakage_texts[idx],
            repair_successes=successes,
            batch_breakages=breakage_texts,
        )
        rewards.append(reward)
    return rewards
|
|
|
|
def run_drift_grpo_dry_run(
    output_dir: str, total_episodes: int = 100, group_size: int = 4, seed: int = 0
) -> None:
    """Pure-CPU exercise of the drift-side reward loop. Writes per-step rewards."""
    rng = random.Random(seed)
    drift_gen = BaselineDriftGenerator(seed=seed)
    rewards_log: list[dict] = []
    report_every = max(1, total_episodes // 10)

    for ep in range(total_episodes):
        episode_env = ForgeEnvironment(seed=seed + ep)
        episode_env.reset(difficulty="easy")
        category = episode_env.state["target_category"]
        source_script = episode_env._original_script

        # Sample a GRPO group of candidate breakages from the baseline generator.
        group = [
            drift_gen.propose(target_category=category, script=source_script)
            for _ in range(group_size)
        ]

        # Score every candidate via K repair rollouts in fresh environments.
        group_rewards = evaluate_drift_batch(
            env_factory=lambda: ForgeEnvironment(seed=rng.randint(0, 1_000_000)),
            breakages=group,
            repair_generate=baseline_oracle_repair_generate(episode_env),
            n_repair_attempts_per_breakage=2,
            seed=seed + ep,
        )
        rewards_log.append(
            {"episode": ep, "rewards": group_rewards, "candidates": group}
        )

        if ep % report_every == 0:
            mean_r = sum(group_rewards) / max(1, len(group_rewards))
            print(f"[drift dry-run] ep={ep} mean_reward={mean_r:.3f}")

    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    (out_dir / "drift_dry_run.json").write_text(json.dumps(rewards_log, indent=2))
    print(f"[drift dry-run] wrote {len(rewards_log)} episodes to {output_dir}")
|
|
|
|
| def _parse_args() -> argparse.Namespace: |
| parser = argparse.ArgumentParser(description=__doc__) |
| parser.add_argument("--output_dir", required=True) |
| parser.add_argument("--total_episodes", type=int, default=100) |
| parser.add_argument("--group_size", type=int, default=4) |
| parser.add_argument("--seed", type=int, default=0) |
| parser.add_argument("--dry_run", action="store_true", default=True) |
| return parser.parse_args() |
|
|
|
|
if __name__ == "__main__":
    cli = _parse_args()
    # Guard clause: only the CPU-only dry-run path is implemented here.
    if not cli.dry_run:
        raise NotImplementedError(
            "Full LLM Drift GRPO requires both roles loaded — use the Colab notebook"
        )
    run_drift_grpo_dry_run(
        output_dir=cli.output_dir,
        total_episodes=cli.total_episodes,
        group_size=cli.group_size,
        seed=cli.seed,
    )
|
|