#!/usr/bin/env python3
"""Train a GRPO policy with TRL using environment-backed verifier rewards.

The result returned by ``run_grpo_trl`` is written as JSON to ``--report-path``.
"""

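from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path

# Make the repository root (the parent of this script's directory) importable
# so that ``app.*`` resolves when the script is run directly.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from app.training.grpo_trl import GRPOTrlConfig, run_grpo_trl  # noqa: E402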


def parse_args() -> argparse.Namespace:
    """Parse CLI options; relative paths are later resolved against the repository root."""
    parser = argparse.ArgumentParser(description="Run TRL GRPO with env-backed rewards.")
    parser.add_argument("--model-id", default="Qwen/Qwen2.5-1.5B-Instruct")
    parser.add_argument("--prompts-path", default="data/processed/training_corpus_grpo_prompts.jsonl")
    parser.add_argument("--output-dir", default="checkpoints")
    parser.add_argument("--report-path", default="outputs/reports/grpo_trl_run.json")
    parser.add_argument("--max-prompts", type=int, default=256)
    parser.add_argument("--max-steps", type=int, default=30)
    parser.add_argument("--epochs", type=float, default=1.0)
    parser.add_argument(
        "--episodes",
        type=int,
        default=0,
        help="Backward-compatible alias for --max-steps; overrides it when > 0.",
    )
    parser.add_argument("--batch-size", type=int, default=2, help="Per-device batch size.")
    parser.add_argument("--grad-accum", type=int, default=1, help="Gradient accumulation steps.")
    parser.add_argument(
        "--num-generations", type=int, default=2, help="Completions sampled per prompt (GRPO group size)."
    )
    parser.add_argument("--max-prompt-length", type=int, default=512)
    parser.add_argument("--max-completion-length", type=int, default=96)
    parser.add_argument("--learning-rate", type=float, default=1e-6)
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--use-unsloth", action="store_true")
    parser.add_argument("--allow-fallback", action="store_true")
    parser.add_argument("--force-fallback", action="store_true")
    return parser.parse_args()


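def main() -> None:
    """Build the GRPO config from CLI args, run training, and write the JSON report."""
    args = parse_args()
    root = ROOT  # repository root, computed once at import time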

    cfg = GRPOTrlConfig(
        model_id=args.model_id,
        prompts_path=root / args.prompts_path,
        output_dir=root / args.output_dir,
        max_prompts=args.max_prompts,
        max_steps=args.episodes if args.episodes > 0 else args.max_steps,
        epochs=args.epochs,
        per_device_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        num_generations=args.num_generations,
        learning_rate=args.learning_rate,
        max_prompt_length=args.max_prompt_length,
        max_completion_length=args.max_completion_length,
        temperature=args.temperature,
        seed=args.seed,
        use_unsloth=args.use_unsloth,
        allow_fallback=args.allow_fallback,
        force_fallback=args.force_fallback,
    )

    result = run_grpo_trl(cfg)

    report_path = root / args.report_path
    report_path.parent.mkdir(parents=True, exist_ok=True)
    report_path.write_text(json.dumps(result, ensure_ascii=True, indent=2), encoding="utf-8")
    print("grpo_trl_done")


if __name__ == "__main__":
    main()