"""GRPO trainer for the Drift Generator.
Uses R-Zero's composite Challenger reward: max(0, uncertainty - repetition).
Each prompt is sampled `group_size` times; for every breakage we run K
independent Repair Agent rollouts to estimate p_hat (success rate).
Heavy and brittle on a single GPU — keep group_size small for hackathon
budgets. Provides a `--dry_run` mode that just exercises the reward function
without any LLM calls.
"""
from __future__ import annotations
import argparse
import json
import os
import random
from pathlib import Path
from typing import Optional
from forgeenv.env.actions import BreakageAction, ForgeAction, RepairAction
from forgeenv.env.forge_environment import ForgeEnvironment
from forgeenv.roles.drift_generator import (
BaselineDriftGenerator,
parse_drift_output,
)
from forgeenv.roles.prompts import (
DRIFT_GENERATOR_SYSTEM_PROMPT,
render_drift_generator_prompt,
)
from forgeenv.training.rollout import (
GenerateFn,
rollout_one_episode,
baseline_oracle_repair_generate,
_baseline_repair_generate,
)
from forgeenv.training.reward_functions import (
compute_drift_gen_reward,
compute_uncertainty_reward,
compute_repetition_penalty,
)
def evaluate_drift_batch(
    env_factory,
    breakages: list[dict],
    repair_generate: GenerateFn,
    n_repair_attempts_per_breakage: int = 4,
    seed: int = 0,
) -> list[float]:
    """Score a batch of candidate breakages with R-Zero's composite Challenger reward.

    For each breakage spec, ``n_repair_attempts_per_breakage`` (K) fresh
    environments are obtained from ``env_factory``. Each one is reset, has the
    breakage applied, and is then given one Repair-Agent attempt. The K
    per-attempt success flags, together with the text fingerprints of the
    whole batch, are passed to ``compute_drift_gen_reward``.

    Args:
        env_factory: Zero-argument callable returning a fresh environment
            exposing ``reset(seed=..., difficulty=...)`` and ``step(action)``.
        breakages: Breakage specs. Each is a dict with optional
            ``"primitive_type"`` and ``"params"`` keys.
        repair_generate: Called as ``repair_generate(system, user)``. Its
            raw text output is parsed with ``extract_diff``.
        n_repair_attempts_per_breakage: Number of independent repair
            attempts (K) per breakage. Must be >= 1.
        seed: Base seed. Attempt ``k`` of breakage ``idx`` resets its env
            with ``seed + idx * max(100, K) + k``, so seeds never collide
            across breakages within one call.

    Returns:
        One reward per breakage, in the same order as ``breakages``.

    Raises:
        ValueError: If ``n_repair_attempts_per_breakage`` is less than 1.
    """
    if n_repair_attempts_per_breakage < 1:
        raise ValueError(
            "n_repair_attempts_per_breakage must be >= 1, got "
            f"{n_repair_attempts_per_breakage}"
        )
    if not breakages:
        return []

    # Kept as function-scope (lazy) imports, but executed once per call
    # rather than once per repair attempt.
    from forgeenv.roles.prompts import render_repair_agent_prompt
    from forgeenv.roles.repair_agent import extract_diff

    # Canonical text fingerprint per breakage. sort_keys makes it independent
    # of params insertion order. The full batch is handed to the reward,
    # presumably for its repetition term.
    breakage_texts = [
        f"{b.get('primitive_type','')}::{json.dumps(b.get('params', {}), sort_keys=True)}"
        for b in breakages
    ]
    # Stride between breakages' seed ranges. It equals the historical 100
    # for K <= 100 and widens for larger K so attempt seeds never overlap.
    seed_stride = max(100, n_repair_attempts_per_breakage)

    rewards: list[float] = []
    for idx, breakage_spec in enumerate(breakages):
        successes: list[bool] = []
        for k in range(n_repair_attempts_per_breakage):
            env = env_factory()
            env.reset(seed=seed + idx * seed_stride + k, difficulty="easy")
            try:
                broken_obs = env.step(
                    ForgeAction(
                        breakage=BreakageAction(
                            primitive_type=breakage_spec.get("primitive_type", ""),
                            params=breakage_spec.get("params", {}) or {},
                        )
                    )
                )
            except Exception:
                # Deliberate best-effort: a breakage that cannot be built or
                # applied counts as an unrepaired attempt instead of aborting
                # the whole batch.
                successes.append(False)
                continue

            user_prompt = render_repair_agent_prompt(
                broken_script=broken_obs.script_content,
                error_trace=broken_obs.error_trace or "",
                library_versions=broken_obs.library_versions,
                target_category=broken_obs.target_category,
            )
            # NOTE(review): the system prompt is passed as "" — confirm that
            # repair_generate supplies its own system prompt when needed.
            diff = extract_diff(repair_generate("", user_prompt) or "")
            repaired_obs = env.step(
                ForgeAction(repair=RepairAction(unified_diff=diff))
            )
            successes.append(
                bool(repaired_obs.held_out_breakdown.get("executed_cleanly", 0.0) > 0.5)
            )

        rewards.append(
            compute_drift_gen_reward(
                breakage_text=breakage_texts[idx],
                repair_successes=successes,
                batch_breakages=breakage_texts,
            )
        )
    return rewards
def run_drift_grpo_dry_run(
    output_dir: str, total_episodes: int = 100, group_size: int = 4, seed: int = 0
) -> None:
    """Exercise the drift-side reward loop and log per-episode rewards.

    For each episode:
    1. Reset a fresh ``ForgeEnvironment`` seeded with ``seed + ep``.
    2. Ask ``BaselineDriftGenerator`` for ``group_size`` candidate breakages
       of the episode's original script.
    3. Score the candidates with ``evaluate_drift_batch``, using the oracle
       repair generator and 2 repair attempts per candidate.

    A progress line is printed about every 10% of episodes. All episodes'
    rewards and candidates are written to ``<output_dir>/drift_dry_run.json``.

    Args:
        output_dir: Directory for the JSON log. Created (with parents)
            before any episode runs.
        total_episodes: Number of episodes to run.
        group_size: Candidate breakages proposed per episode.
        seed: Base seed for the drift generator, the per-episode envs, and
            the RNG that seeds the repair envs.
    """
    out_dir = Path(output_dir)
    # Create the output directory before the slow episode loop, so an
    # unusable path fails immediately rather than after all the work is done.
    out_dir.mkdir(parents=True, exist_ok=True)

    rng = random.Random(seed)
    drift_gen = BaselineDriftGenerator(seed=seed)
    rewards_log: list[dict] = []
    log_every = max(1, total_episodes // 10)
    for ep in range(total_episodes):
        env = ForgeEnvironment(seed=seed + ep)
        env.reset(difficulty="easy")
        target_category = env.state["target_category"]
        script = env._original_script  # noqa: SLF001 — read-only convenience
        # Sample group_size candidate breakages
        candidates = [
            drift_gen.propose(target_category=target_category, script=script)
            for _ in range(group_size)
        ]
        # Use the oracle as repair (so we get a meaningful uncertainty signal:
        # an "unbreakable" breakage gives p_hat=1, an "always-fails" one gives 0).
        # NOTE(review): the oracle is built from this episode's `env`, while the
        # repair attempts run in separate envs from `env_factory` with unrelated
        # seeds — confirm the oracle's diffs are valid for those envs' scripts.
        rewards = evaluate_drift_batch(
            env_factory=lambda: ForgeEnvironment(seed=rng.randint(0, 1_000_000)),
            breakages=candidates,
            repair_generate=baseline_oracle_repair_generate(env),
            n_repair_attempts_per_breakage=2,
            seed=seed + ep,
        )
        rewards_log.append(
            {"episode": ep, "rewards": rewards, "candidates": candidates}
        )
        if ep % log_every == 0:
            mean_r = sum(rewards) / max(1, len(rewards))
            print(f"[drift dry-run] ep={ep} mean_reward={mean_r:.3f}")

    (out_dir / "drift_dry_run.json").write_text(
        json.dumps(rewards_log, indent=2), encoding="utf-8"
    )
    print(f"[drift dry-run] wrote {len(rewards_log)} episodes to {output_dir}")
def _parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--output_dir", required=True)
parser.add_argument("--total_episodes", type=int, default=100)
parser.add_argument("--group_size", type=int, default=4)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--dry_run", action="store_true", default=True)
return parser.parse_args()
if __name__ == "__main__":
    # Script entry point. Only the CPU dry run is implemented here; the full
    # LLM training path is rejected up front.
    cli = _parse_args()
    if not cli.dry_run:
        raise NotImplementedError(
            "Full LLM Drift GRPO requires both roles loaded — use the Colab notebook"
        )
    run_drift_grpo_dry_run(
        output_dir=cli.output_dir,
        total_episodes=cli.total_episodes,
        group_size=cli.group_size,
        seed=cli.seed,
    )
|