"""GRPO trainer for the Repair Agent. This wires TRL's GRPOTrainer to ForgeEnvironment via a per-prompt rollout function. Each prompt is sampled K times (group size); each sample is executed in the env and gets a scalar reward from the visible verifier. Usage: python -m forgeenv.training.grpo_repair \\ --base_model unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit \\ --adapter_path artifacts/checkpoints/repair_agent_sft \\ --output_dir artifacts/checkpoints/repair_agent_grpo \\ --total_episodes 200 --group_size 4 """ from __future__ import annotations import argparse import json import os from pathlib import Path from typing import Any, Optional from forgeenv.env.forge_environment import ForgeEnvironment from forgeenv.roles.drift_generator import BaselineDriftGenerator from forgeenv.roles.prompts import ( DRIFT_GENERATOR_SYSTEM_PROMPT, REPAIR_AGENT_SYSTEM_PROMPT, render_drift_generator_prompt, render_repair_agent_prompt, ) from forgeenv.roles.repair_agent import extract_diff from forgeenv.training.rollout import rollout_one_episode def _build_repair_prompt(env: ForgeEnvironment) -> dict[str, Any]: """Reset env, run baseline drift generator, return a repair-prompt dict ready to feed to TRL's GRPOTrainer.""" drift_gen = BaselineDriftGenerator() obs = env.reset(difficulty="easy") drift_user = render_drift_generator_prompt( script=obs.script_content, target_category=obs.target_category, library_versions=obs.library_versions, ) spec = drift_gen.propose( target_category=obs.target_category, script=obs.script_content ) from forgeenv.env.actions import BreakageAction, ForgeAction obs2 = env.step( ForgeAction( breakage=BreakageAction( primitive_type=spec["primitive_type"], params=spec["params"] ) ) ) user = render_repair_agent_prompt( broken_script=obs2.script_content, error_trace=obs2.error_trace or "", library_versions=obs2.library_versions, target_category=obs2.target_category, ) return { "prompt": [ {"role": "system", "content": REPAIR_AGENT_SYSTEM_PROMPT}, {"role": "user", "content": user}, ], "task_id": obs.task_id, "primitive_type": spec["primitive_type"], "broken_script": obs2.script_content, "drift_user_prompt": drift_user, } def reward_repair_function( completions: list, prompts: list = None, **kwargs ) -> list[float]: """TRL-compatible reward fn: scores a batch of completions against a (broken_script, breakage_spec) tuple stored on each example.""" from forgeenv.env.actions import RepairAction, ForgeAction from forgeenv.env.diff_utils import apply_unified_diff from forgeenv.sandbox.simulation_mode import SimulationExecutor from forgeenv.tasks.task_sampler import TaskSampler from forgeenv.verifier.visible_verifier import compute_visible_reward sampler = TaskSampler() executor = SimulationExecutor() task_ids = kwargs.get("task_id", [None] * len(completions)) broken_scripts = kwargs.get("broken_script", [""] * len(completions)) rewards: list[float] = [] for completion, task_id, broken in zip(completions, task_ids, broken_scripts): if isinstance(completion, list): # chat format completion = completion[-1]["content"] diff = extract_diff(completion or "") repaired = apply_unified_diff(broken, diff) if diff else broken task = sampler.get_by_id(task_id) if task_id else None if task is None and sampler.tasks: task = sampler.tasks[0] result = executor.execute(repaired, task) result.script_content = repaired reward, _ = compute_visible_reward(result, task) rewards.append(float(reward)) return rewards def run_grpo( base_model: str, adapter_path: Optional[str], output_dir: str, total_episodes: int = 200, group_size: int = 4, learning_rate: float = 5e-6, seed: int = 0, use_unsloth: Optional[bool] = None, ) -> None: """Launch GRPO training (lazy imports to keep this module importable on CPU).""" if use_unsloth is None: use_unsloth = os.environ.get("FORGEENV_USE_UNSLOTH", "1") == "1" if not use_unsloth: # Dry-run mode: just exercise the prompt building loop and dump rewards. env = ForgeEnvironment(seed=seed) rewards = [] for ep in range(total_episodes): result = rollout_one_episode(env) rewards.append(result.visible_reward) if ep % max(1, total_episodes // 10) == 0: print( f"[grpo dry-run] ep={ep} reward={result.visible_reward:.3f} " f"primitive={result.primitive_type}" ) Path(output_dir).mkdir(parents=True, exist_ok=True) (Path(output_dir) / "dry_run_rewards.json").write_text( json.dumps(rewards, indent=2) ) print(f"[grpo dry-run] wrote {len(rewards)} rewards to {output_dir}") return from datasets import Dataset from trl import GRPOConfig, GRPOTrainer from unsloth import FastLanguageModel from peft import PeftModel model, tokenizer = FastLanguageModel.from_pretrained( model_name=base_model, max_seq_length=4096, dtype=None, load_in_4bit=True, ) if adapter_path: model = PeftModel.from_pretrained(model, adapter_path, is_trainable=True) env = ForgeEnvironment(seed=seed) examples = [_build_repair_prompt(env) for _ in range(total_episodes)] dataset = Dataset.from_list(examples) grpo_config = GRPOConfig( output_dir=output_dir, per_device_train_batch_size=1, gradient_accumulation_steps=4, learning_rate=learning_rate, max_steps=total_episodes, num_generations=group_size, max_completion_length=1024, logging_steps=5, save_steps=max(50, total_episodes // 4), save_total_limit=2, seed=seed, report_to="none", beta=0.04, ) trainer = GRPOTrainer( model=model, processing_class=tokenizer, args=grpo_config, train_dataset=dataset, reward_funcs=[reward_repair_function], ) trainer.train() Path(output_dir).mkdir(parents=True, exist_ok=True) model.save_pretrained(output_dir) tokenizer.save_pretrained(output_dir) print(f"[grpo] saved adapter to {output_dir}") def _parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument("--base_model", default="unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit") parser.add_argument("--adapter_path", default=None) parser.add_argument("--output_dir", required=True) parser.add_argument("--total_episodes", type=int, default=200) parser.add_argument("--group_size", type=int, default=4) parser.add_argument("--learning_rate", type=float, default=5e-6) parser.add_argument("--seed", type=int, default=0) parser.add_argument("--dry_run", action="store_true") return parser.parse_args() if __name__ == "__main__": args = _parse_args() run_grpo( base_model=args.base_model, adapter_path=args.adapter_path, output_dir=args.output_dir, total_episodes=args.total_episodes, group_size=args.group_size, learning_rate=args.learning_rate, seed=args.seed, use_unsloth=not args.dry_run, )