| """GRPO trainer for the Repair Agent. |
| |
| This wires TRL's GRPOTrainer to ForgeEnvironment via a per-prompt rollout |
| function. Each prompt is sampled K times (group size); each sample is |
| executed in the env and gets a scalar reward from the visible verifier. |
| |
| Usage: |
| python -m forgeenv.training.grpo_repair \\ |
| --base_model unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit \\ |
| --adapter_path artifacts/checkpoints/repair_agent_sft \\ |
| --output_dir artifacts/checkpoints/repair_agent_grpo \\ |
| --total_episodes 200 --group_size 4 |
| """ |
from __future__ import annotations

import argparse
import json
import os
from pathlib import Path
from typing import Any, Optional

from forgeenv.env.actions import BreakageAction, ForgeAction
from forgeenv.env.forge_environment import ForgeEnvironment
from forgeenv.roles.drift_generator import BaselineDriftGenerator
from forgeenv.roles.prompts import (
    REPAIR_AGENT_SYSTEM_PROMPT,
    render_drift_generator_prompt,
    render_repair_agent_prompt,
)
from forgeenv.roles.repair_agent import extract_diff
from forgeenv.training.rollout import rollout_one_episode


def _build_repair_prompt(env: ForgeEnvironment) -> dict[str, Any]:
    """Reset the env, run the baseline drift generator, and return a
    repair-prompt dict ready to feed to TRL's GRPOTrainer."""
    drift_gen = BaselineDriftGenerator()

    obs = env.reset(difficulty="easy")
    drift_user = render_drift_generator_prompt(
        script=obs.script_content,
        target_category=obs.target_category,
        library_versions=obs.library_versions,
    )
    spec = drift_gen.propose(
        target_category=obs.target_category, script=obs.script_content
    )

    # Apply the proposed breakage so the next observation carries the broken
    # script and its error trace.
    obs2 = env.step(
        ForgeAction(
            breakage=BreakageAction(
                primitive_type=spec["primitive_type"], params=spec["params"]
            )
        )
    )

    user = render_repair_agent_prompt(
        broken_script=obs2.script_content,
        error_trace=obs2.error_trace or "",
        library_versions=obs2.library_versions,
        target_category=obs2.target_category,
    )
    return {
        "prompt": [
            {"role": "system", "content": REPAIR_AGENT_SYSTEM_PROMPT},
            {"role": "user", "content": user},
        ],
        "task_id": obs.task_id,
        "primitive_type": spec["primitive_type"],
        "broken_script": obs2.script_content,
        "drift_user_prompt": drift_user,
    }
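
# For reference, the dict built above has this shape (values are illustrative,
# not real task output). TRL keeps "prompt" for generation and forwards the
# remaining columns to reward functions as keyword lists:
#
#   {
#       "prompt": [{"role": "system", "content": "..."},
#                  {"role": "user", "content": "..."}],
#       "task_id": "task-0042",          # hypothetical id
#       "primitive_type": "api_rename",  # hypothetical primitive
#       "broken_script": "...",
#       "drift_user_prompt": "...",
#   }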


def reward_repair_function(
    completions: list, prompts: Optional[list] = None, **kwargs: Any
) -> list[float]:
    """TRL-compatible reward fn: scores a batch of completions against the
    broken_script/task_id columns stored on each example."""
    from forgeenv.env.diff_utils import apply_unified_diff
    from forgeenv.sandbox.simulation_mode import SimulationExecutor
    from forgeenv.tasks.task_sampler import TaskSampler
    from forgeenv.verifier.visible_verifier import compute_visible_reward

    sampler = TaskSampler()
    executor = SimulationExecutor()
    task_ids = kwargs.get("task_id", [None] * len(completions))
    broken_scripts = kwargs.get("broken_script", [""] * len(completions))

    rewards: list[float] = []
    for completion, task_id, broken in zip(completions, task_ids, broken_scripts):
        # Conversational completions arrive as message lists; keep the last turn.
        if isinstance(completion, list):
            completion = completion[-1]["content"]
        diff = extract_diff(completion or "")
        repaired = apply_unified_diff(broken, diff) if diff else broken
        task = sampler.get_by_id(task_id) if task_id else None
        if task is None and sampler.tasks:
            task = sampler.tasks[0]
        if task is None:
            # Nothing to verify against; score the completion zero.
            rewards.append(0.0)
            continue
        result = executor.execute(repaired, task)
        result.script_content = repaired
        reward, _ = compute_visible_reward(result, task)
        rewards.append(float(reward))
    return rewards
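
# Minimal sketch of a call as GRPOTrainer would make it for one prompt group
# of size 2 (the diff payloads here are invented for illustration):
#
#   rewards = reward_repair_function(
#       completions=[
#           [{"role": "assistant",
#             "content": "--- a/script.py\n+++ b/script.py\n@@ ..."}],
#           [{"role": "assistant", "content": "no diff here"}],
#       ],
#       task_id=["task-0042", "task-0042"],
#       broken_script=[broken, broken],
#   )
#
# The second completion yields no diff, so the broken script is executed
# unchanged and receives whatever score the visible verifier assigns it.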


def run_grpo(
    base_model: str,
    adapter_path: Optional[str],
    output_dir: str,
    total_episodes: int = 200,
    group_size: int = 4,
    learning_rate: float = 5e-6,
    seed: int = 0,
    use_unsloth: Optional[bool] = None,
) -> None:
    """Launch GRPO training (lazy imports keep this module importable on CPU)."""
    if use_unsloth is None:
        use_unsloth = os.environ.get("FORGEENV_USE_UNSLOTH", "1") == "1"

    if not use_unsloth:
        # Dry run: roll out baseline episodes and log rewards, no training.
        env = ForgeEnvironment(seed=seed)
        rewards = []
        for ep in range(total_episodes):
            result = rollout_one_episode(env)
            rewards.append(result.visible_reward)
            if ep % max(1, total_episodes // 10) == 0:
                print(
                    f"[grpo dry-run] ep={ep} reward={result.visible_reward:.3f} "
                    f"primitive={result.primitive_type}"
                )
        Path(output_dir).mkdir(parents=True, exist_ok=True)
        (Path(output_dir) / "dry_run_rewards.json").write_text(
            json.dumps(rewards, indent=2)
        )
        print(f"[grpo dry-run] wrote {len(rewards)} rewards to {output_dir}")
        return
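
    # Past this point we need GPU-backed dependencies; they are imported
    # lazily so the dry-run path above works on CPU-only machines.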
    from unsloth import FastLanguageModel  # import before trl/peft so Unsloth's patches apply
    from datasets import Dataset
    from trl import GRPOConfig, GRPOTrainer
    from peft import PeftModel

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=base_model,
        max_seq_length=4096,
        dtype=None,  # auto-detect (bf16 where supported)
        load_in_4bit=True,
    )
    if adapter_path:
        # Resume from the SFT adapter and keep its weights trainable.
        model = PeftModel.from_pretrained(model, adapter_path, is_trainable=True)

    # One prompt per episode; each env.reset() inside samples a fresh task.
    env = ForgeEnvironment(seed=seed)
    examples = [_build_repair_prompt(env) for _ in range(total_episodes)]
    dataset = Dataset.from_list(examples)

    grpo_config = GRPOConfig(
        output_dir=output_dir,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        learning_rate=learning_rate,
        max_steps=total_episodes,
        num_generations=group_size,
        max_completion_length=1024,
        logging_steps=5,
        save_steps=max(50, total_episodes // 4),
        save_total_limit=2,
        seed=seed,
        report_to="none",
        beta=0.04,  # KL penalty toward the reference policy
    )
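
    # Note: TRL requires the effective batch to be divisible by
    # num_generations (the exact accounting varies across TRL versions); the
    # settings above (batch 1 x 4 accumulation steps) line up with the default
    # group_size=4, so other group sizes may need those values adjusted.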
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        args=grpo_config,
        train_dataset=dataset,
        reward_funcs=[reward_repair_function],
    )
    trainer.train()

    Path(output_dir).mkdir(parents=True, exist_ok=True)
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    print(f"[grpo] saved adapter to {output_dir}")


def _parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--base_model", default="unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit"
    )
    parser.add_argument("--adapter_path", default=None)
    parser.add_argument("--output_dir", required=True)
    parser.add_argument("--total_episodes", type=int, default=200)
    parser.add_argument("--group_size", type=int, default=4)
    parser.add_argument("--learning_rate", type=float, default=5e-6)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument(
        "--dry_run",
        action="store_true",
        help="Skip TRL/Unsloth and roll out baseline episodes instead.",
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = _parse_args()
    run_grpo(
        base_model=args.base_model,
        adapter_path=args.adapter_path,
        output_dir=args.output_dir,
        total_episodes=args.total_episodes,
        group_size=args.group_size,
        learning_rate=args.learning_rate,
        seed=args.seed,
        # --dry_run forces the CPU path; otherwise defer to the
        # FORGEENV_USE_UNSLOTH env var handled inside run_grpo.
        use_unsloth=False if args.dry_run else None,
    )