Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """Train GRPO policy with TRL using environment-backed verifier rewards.""" | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| from pathlib import Path | |
| import sys | |
# Directory one level above this script's folder; used as the import root.
ROOT = Path(__file__).resolve().parents[1]

# Put ROOT at the front of the module search path (once) so that the
# ``app`` package import below resolves when this file is run as a script.
_ROOT_ENTRY = str(ROOT)
if _ROOT_ENTRY not in sys.path:
    sys.path.insert(0, _ROOT_ENTRY)
| from app.training.grpo_trl import GRPOTrlConfig, run_grpo_trl | |
| def parse_args() -> argparse.Namespace: | |
| parser = argparse.ArgumentParser(description="Run TRL GRPO with env-backed rewards.") | |
| parser.add_argument("--model-id", default="Qwen/Qwen2.5-1.5B-Instruct") | |
| parser.add_argument("--prompts-path", default="data/processed/training_corpus_grpo_prompts.jsonl") | |
| parser.add_argument("--output-dir", default="checkpoints") | |
| parser.add_argument("--report-path", default="outputs/reports/grpo_trl_run.json") | |
| parser.add_argument("--max-prompts", type=int, default=256) | |
| parser.add_argument("--max-steps", type=int, default=30) | |
| parser.add_argument("--epochs", type=float, default=1.0) | |
| parser.add_argument("--episodes", type=int, default=0, help="Backward-compatible alias for --max-steps.") | |
| parser.add_argument("--batch-size", type=int, default=2) | |
| parser.add_argument("--grad-accum", type=int, default=1) | |
| parser.add_argument("--num-generations", type=int, default=2) | |
| parser.add_argument("--max-prompt-length", type=int, default=512) | |
| parser.add_argument("--max-completion-length", type=int, default=96) | |
| parser.add_argument("--learning-rate", type=float, default=1e-6) | |
| parser.add_argument("--temperature", type=float, default=0.7) | |
| parser.add_argument("--seed", type=int, default=42) | |
| parser.add_argument("--use-unsloth", action="store_true") | |
| parser.add_argument("--allow-fallback", action="store_true") | |
| parser.add_argument("--force-fallback", action="store_true") | |
| return parser.parse_args() | |
def main() -> None:
    """Run GRPO training from CLI options and save the result as a JSON report.

    The ``--prompts-path``, ``--output-dir`` and ``--report-path`` values are
    joined onto the directory one level above this script's folder. The value
    returned by ``run_grpo_trl`` is serialized to the report path, and
    ``grpo_trl_done`` is printed once the file has been written.
    """
    cli = parse_args()
    project_root = Path(__file__).resolve().parents[1]

    # A positive --episodes value takes precedence over --max-steps.
    if cli.episodes > 0:
        step_budget = cli.episodes
    else:
        step_budget = cli.max_steps

    config = GRPOTrlConfig(
        model_id=cli.model_id,
        prompts_path=project_root / cli.prompts_path,
        output_dir=project_root / cli.output_dir,
        max_prompts=cli.max_prompts,
        max_steps=step_budget,
        epochs=cli.epochs,
        per_device_batch_size=cli.batch_size,
        gradient_accumulation_steps=cli.grad_accum,
        num_generations=cli.num_generations,
        learning_rate=cli.learning_rate,
        max_prompt_length=cli.max_prompt_length,
        max_completion_length=cli.max_completion_length,
        temperature=cli.temperature,
        seed=cli.seed,
        use_unsloth=cli.use_unsloth,
        allow_fallback=cli.allow_fallback,
        force_fallback=cli.force_fallback,
    )
    outcome = run_grpo_trl(config)

    # Create the report directory only after training has returned.
    destination = project_root / cli.report_path
    destination.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(outcome, ensure_ascii=True, indent=2)
    destination.write_text(serialized, encoding="utf-8")

    print("grpo_trl_done")


if __name__ == "__main__":
    main()