File size: 1,628 Bytes
fd0c71a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#!/usr/bin/env python3
"""Train SFT policy using TRL + Unsloth with fallback."""

from __future__ import annotations

import json
import os
from pathlib import Path

import sys

# Directory one level above the folder that contains this (resolved) file.
ROOT = Path(__file__).resolve().parents[1]
# Prepend ROOT to the import search path, at most once, before the `app.*`
# import below -- presumably so the package resolves when this file is run
# directly as a script (confirm against the repository layout).
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from app.training.sft_trl import SFTRunConfig, run_sft_trl


# Spellings (compared case-insensitively) that switch a boolean env flag on.
_TRUTHY_ENV_VALUES = frozenset({"1", "true", "yes", "on"})


def _env_flag(name: str, default: str) -> bool:
    """Return whether env var *name* (or *default* if unset) is a truthy spelling."""
    return os.getenv(name, default).lower() in _TRUTHY_ENV_VALUES


def _env_int(name: str, default: str) -> int:
    """Parse env var *name* (or *default* if unset) as int, naming it on failure."""
    raw = os.getenv(name, default)
    try:
        return int(raw)
    except ValueError as err:
        raise ValueError(f"{name} must be an integer, got {raw!r}") from err


def _env_float(name: str, default: str) -> float:
    """Parse env var *name* (or *default* if unset) as float, naming it on failure."""
    raw = os.getenv(name, default)
    try:
        return float(raw)
    except ValueError as err:
        raise ValueError(f"{name} must be a number, got {raw!r}") from err


def main() -> None:
    """Run one SFT training job configured from ``POLYGUARD_*`` environment variables.

    Creates ``<root>/checkpoints`` if needed, builds an ``SFTRunConfig`` from
    the environment (falling back to the defaults below), calls
    ``run_sft_trl`` with it, writes the returned value as JSON to
    ``<root>/outputs/reports/sft_run.json`` and prints ``sft_done``.

    Raises:
        ValueError: If a numeric ``POLYGUARD_SFT_*`` variable is set to a
            value that cannot be parsed; the message names the variable.
    """
    root = Path(__file__).resolve().parents[1]
    out = root / "checkpoints"
    out.mkdir(parents=True, exist_ok=True)

    model_id = os.getenv("POLYGUARD_SFT_MODEL_ID", "Qwen/Qwen2.5-1.5B-Instruct")
    dataset_path = root / "data" / "processed" / "sft_examples.json"
    run_cfg = SFTRunConfig(
        model_id=model_id,
        output_dir=out,
        dataset_path=dataset_path,
        max_seq_len=_env_int("POLYGUARD_SFT_MAX_SEQ_LEN", "1024"),
        epochs=_env_int("POLYGUARD_SFT_EPOCHS", "1"),
        learning_rate=_env_float("POLYGUARD_SFT_LEARNING_RATE", "2e-5"),
        batch_size=_env_int("POLYGUARD_SFT_BATCH_SIZE", "2"),
        max_steps=_env_int("POLYGUARD_SFT_MAX_STEPS", "30"),
        use_unsloth=_env_flag("POLYGUARD_USE_UNSLOTH", "true"),
        allow_fallback=_env_flag("POLYGUARD_ALLOW_TRAIN_FALLBACK", "false"),
    )
    result = run_sft_trl(run_cfg)

    # NOTE(review): json.dumps requires `result` to be JSON-serializable;
    # confirm that run_sft_trl returns only plain dict/list/str/number values.
    report_dir = root / "outputs" / "reports"
    report_dir.mkdir(parents=True, exist_ok=True)
    (report_dir / "sft_run.json").write_text(json.dumps(result, ensure_ascii=True, indent=2), encoding="utf-8")
    print("sft_done")

# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()