# forgeenv-source / forgeenv/training/grpo_repair.py
# Snapshot of forgeenv source for a training job (commit a15535e, verified; uploaded by akhiilll).
"""GRPO trainer for the Repair Agent.
This wires TRL's GRPOTrainer to ForgeEnvironment via a per-prompt rollout
function. Each prompt is sampled K times (group size); each sample is
executed in the env and gets a scalar reward from the visible verifier.
Usage:
python -m forgeenv.training.grpo_repair \\
--base_model unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit \\
--adapter_path artifacts/checkpoints/repair_agent_sft \\
--output_dir artifacts/checkpoints/repair_agent_grpo \\
--total_episodes 200 --group_size 4
"""
from __future__ import annotations
import argparse
import json
import os
from pathlib import Path
from typing import Any, Optional
from forgeenv.env.forge_environment import ForgeEnvironment
from forgeenv.roles.drift_generator import BaselineDriftGenerator
from forgeenv.roles.prompts import (
DRIFT_GENERATOR_SYSTEM_PROMPT,
REPAIR_AGENT_SYSTEM_PROMPT,
render_drift_generator_prompt,
render_repair_agent_prompt,
)
from forgeenv.roles.repair_agent import extract_diff
from forgeenv.training.rollout import rollout_one_episode
def _build_repair_prompt(env: ForgeEnvironment) -> dict[str, Any]:
    """Reset the environment, apply a baseline drift, and package the
    resulting repair task as a prompt example for TRL's GRPOTrainer."""
    generator = BaselineDriftGenerator()
    clean_obs = env.reset(difficulty="easy")

    # Rendered for logging/provenance only; the baseline generator itself is
    # rule-based and does not consume this prompt.
    generator_prompt = render_drift_generator_prompt(
        script=clean_obs.script_content,
        target_category=clean_obs.target_category,
        library_versions=clean_obs.library_versions,
    )
    breakage_spec = generator.propose(
        target_category=clean_obs.target_category, script=clean_obs.script_content
    )

    from forgeenv.env.actions import BreakageAction, ForgeAction

    breakage = BreakageAction(
        primitive_type=breakage_spec["primitive_type"],
        params=breakage_spec["params"],
    )
    broken_obs = env.step(ForgeAction(breakage=breakage))

    repair_user = render_repair_agent_prompt(
        broken_script=broken_obs.script_content,
        error_trace=broken_obs.error_trace or "",
        library_versions=broken_obs.library_versions,
        target_category=broken_obs.target_category,
    )
    messages = [
        {"role": "system", "content": REPAIR_AGENT_SYSTEM_PROMPT},
        {"role": "user", "content": repair_user},
    ]
    return {
        "prompt": messages,
        "task_id": clean_obs.task_id,
        "primitive_type": breakage_spec["primitive_type"],
        "broken_script": broken_obs.script_content,
        "drift_user_prompt": generator_prompt,
    }
def reward_repair_function(
    completions: list, prompts: Optional[list] = None, **kwargs
) -> list[float]:
    """TRL-compatible reward function for repair completions.

    Each completion is converted to a unified diff (via ``extract_diff``),
    applied to the broken script stored on its example, executed in the
    simulation sandbox, and scored with the visible verifier.

    Args:
        completions: Model outputs; each item is either a raw string or a
            chat-format message list (the last message's content is used).
        prompts: Unused here; accepted because TRL passes it to reward funcs.
        **kwargs: Extra dataset columns broadcast per completion; this
            function reads ``task_id`` and ``broken_script``.

    Returns:
        One float reward per completion, in order.
    """
    from forgeenv.env.diff_utils import apply_unified_diff
    from forgeenv.sandbox.simulation_mode import SimulationExecutor
    from forgeenv.tasks.task_sampler import TaskSampler
    from forgeenv.verifier.visible_verifier import compute_visible_reward

    sampler = TaskSampler()
    executor = SimulationExecutor()
    # Fall back to per-completion defaults so a batch missing these columns
    # is still scored instead of raising.
    task_ids = kwargs.get("task_id", [None] * len(completions))
    broken_scripts = kwargs.get("broken_script", [""] * len(completions))
    rewards: list[float] = []
    for completion, task_id, broken in zip(completions, task_ids, broken_scripts):
        if isinstance(completion, list):  # chat format: use final message text
            completion = completion[-1]["content"]
        diff = extract_diff(completion or "")
        # No parseable diff means "no repair": score the broken script as-is.
        repaired = apply_unified_diff(broken, diff) if diff else broken
        task = sampler.get_by_id(task_id) if task_id else None
        if task is None and sampler.tasks:
            # Unknown/missing task id: score against the first sampled task
            # rather than skipping the completion.
            task = sampler.tasks[0]
        result = executor.execute(repaired, task)
        result.script_content = repaired
        reward, _ = compute_visible_reward(result, task)
        rewards.append(float(reward))
    return rewards
def run_grpo(
    base_model: str,
    adapter_path: Optional[str],
    output_dir: str,
    total_episodes: int = 200,
    group_size: int = 4,
    learning_rate: float = 5e-6,
    seed: int = 0,
    use_unsloth: Optional[bool] = None,
) -> None:
    """Launch GRPO training (lazy imports to keep this module importable on CPU).

    Args:
        base_model: Model name/path handed to ``FastLanguageModel.from_pretrained``.
        adapter_path: Optional PEFT adapter to load on top of the base model
            (loaded with ``is_trainable=True``).
        output_dir: Destination for checkpoints (or the dry-run reward dump).
        total_episodes: Number of prompt examples built; also used as ``max_steps``.
        group_size: GRPO group size, i.e. ``num_generations`` per prompt.
        learning_rate: Optimizer learning rate for GRPO.
        seed: Seed for the environment and trainer config.
        use_unsloth: When False, run a CPU-only dry-run (rollouts + reward dump)
            instead of training; when None, read FORGEENV_USE_UNSLOTH (default "1").
    """
    if use_unsloth is None:
        use_unsloth = os.environ.get("FORGEENV_USE_UNSLOTH", "1") == "1"
    if not use_unsloth:
        # Dry-run mode: just exercise the prompt building loop and dump rewards.
        env = ForgeEnvironment(seed=seed)
        rewards = []
        for ep in range(total_episodes):
            result = rollout_one_episode(env)
            rewards.append(result.visible_reward)
            # Print roughly 10 progress lines regardless of episode count.
            if ep % max(1, total_episodes // 10) == 0:
                print(
                    f"[grpo dry-run] ep={ep} reward={result.visible_reward:.3f} "
                    f"primitive={result.primitive_type}"
                )
        Path(output_dir).mkdir(parents=True, exist_ok=True)
        (Path(output_dir) / "dry_run_rewards.json").write_text(
            json.dumps(rewards, indent=2)
        )
        print(f"[grpo dry-run] wrote {len(rewards)} rewards to {output_dir}")
        return
    # Heavy GPU-path imports, deferred so the dry-run branch above needs none of them.
    from datasets import Dataset
    from trl import GRPOConfig, GRPOTrainer
    from unsloth import FastLanguageModel
    from peft import PeftModel
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=base_model,
        max_seq_length=4096,
        dtype=None,  # let unsloth auto-select the dtype
        load_in_4bit=True,
    )
    if adapter_path:
        # Resume from an SFT adapter and keep its weights trainable under GRPO.
        model = PeftModel.from_pretrained(model, adapter_path, is_trainable=True)
    env = ForgeEnvironment(seed=seed)
    # Prompts are pre-built sequentially against a single env instance;
    # each call resets the env and injects a fresh baseline breakage.
    examples = [_build_repair_prompt(env) for _ in range(total_episodes)]
    dataset = Dataset.from_list(examples)
    grpo_config = GRPOConfig(
        output_dir=output_dir,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        learning_rate=learning_rate,
        max_steps=total_episodes,
        num_generations=group_size,  # K samples per prompt (group size)
        max_completion_length=1024,
        logging_steps=5,
        save_steps=max(50, total_episodes // 4),  # at most ~4 checkpoints, min every 50
        save_total_limit=2,
        seed=seed,
        report_to="none",
        beta=0.04,  # KL penalty coefficient
    )
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        args=grpo_config,
        train_dataset=dataset,
        # Completions are scored by executing the repaired script in the env.
        reward_funcs=[reward_repair_function],
    )
    trainer.train()
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    print(f"[grpo] saved adapter to {output_dir}")
def _parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--base_model", default="unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit")
parser.add_argument("--adapter_path", default=None)
parser.add_argument("--output_dir", required=True)
parser.add_argument("--total_episodes", type=int, default=200)
parser.add_argument("--group_size", type=int, default=4)
parser.add_argument("--learning_rate", type=float, default=5e-6)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--dry_run", action="store_true")
return parser.parse_args()
if __name__ == "__main__":
    # CLI entry point: forward parsed flags to run_grpo. The --dry_run flag
    # maps to use_unsloth=False (CPU-only rollout/reward exercise).
    cli = _parse_args()
    run_grpo(
        base_model=cli.base_model,
        adapter_path=cli.adapter_path,
        output_dir=cli.output_dir,
        total_episodes=cli.total_episodes,
        group_size=cli.group_size,
        learning_rate=cli.learning_rate,
        seed=cli.seed,
        use_unsloth=not cli.dry_run,
    )