#!/usr/bin/env python3
"""Train SFT policy using TRL + Unsloth with fallback."""
from __future__ import annotations
import json
import os
from pathlib import Path
import sys
# ROOT is the directory one level above this script's own folder. It is put at
# the front of sys.path (once) so the `app` package import that follows can be
# resolved when this file is executed directly.
ROOT = Path(__file__).resolve().parents[1]
if not any(entry == str(ROOT) for entry in sys.path):
    sys.path.insert(0, str(ROOT))
from app.training.sft_trl import SFTRunConfig, run_sft_trl
def _env_flag(name: str, default: str) -> bool:
    """Return True when environment variable *name* holds a truthy token.

    The variable's value (or *default* when it is unset) is stripped of
    surrounding whitespace, lower-cased, and compared against the accepted
    tokens "1", "true", "yes" and "on". Anything else is False.
    """
    return os.getenv(name, default).strip().lower() in {"1", "true", "yes", "on"}


def main() -> None:
    """Run one SFT job configured from ``POLYGUARD_*`` environment variables.

    Paths are resolved relative to the directory one level above this
    script's folder:

    * ``checkpoints/`` is created and passed as the output directory.
    * ``data/processed/sft_examples.json`` is passed as the dataset path.
    * ``outputs/reports/sft_run.json`` receives the value returned by
      ``run_sft_trl`` serialized as JSON.

    Environment variables and their defaults:

    * ``POLYGUARD_SFT_MODEL_ID`` - ``Qwen/Qwen2.5-1.5B-Instruct``
    * ``POLYGUARD_SFT_MAX_SEQ_LEN`` - ``1024``
    * ``POLYGUARD_SFT_EPOCHS`` - ``1``
    * ``POLYGUARD_SFT_LEARNING_RATE`` - ``2e-5``
    * ``POLYGUARD_SFT_BATCH_SIZE`` - ``2``
    * ``POLYGUARD_SFT_MAX_STEPS`` - ``30``
    * ``POLYGUARD_USE_UNSLOTH`` - ``true``
    * ``POLYGUARD_ALLOW_TRAIN_FALLBACK`` - ``false``

    Raises:
        ValueError: if a numeric variable cannot be parsed by ``int`` or
            ``float``.
    """
    root = Path(__file__).resolve().parents[1]

    # The checkpoint directory is created before any environment parsing.
    out = root / "checkpoints"
    out.mkdir(parents=True, exist_ok=True)

    model_id = os.getenv("POLYGUARD_SFT_MODEL_ID", "Qwen/Qwen2.5-1.5B-Instruct")
    dataset_path = root / "data" / "processed" / "sft_examples.json"

    run_cfg = SFTRunConfig(
        model_id=model_id,
        output_dir=out,
        dataset_path=dataset_path,
        max_seq_len=int(os.getenv("POLYGUARD_SFT_MAX_SEQ_LEN", "1024")),
        epochs=int(os.getenv("POLYGUARD_SFT_EPOCHS", "1")),
        learning_rate=float(os.getenv("POLYGUARD_SFT_LEARNING_RATE", "2e-5")),
        batch_size=int(os.getenv("POLYGUARD_SFT_BATCH_SIZE", "2")),
        max_steps=int(os.getenv("POLYGUARD_SFT_MAX_STEPS", "30")),
        use_unsloth=_env_flag("POLYGUARD_USE_UNSLOTH", "true"),
        allow_fallback=_env_flag("POLYGUARD_ALLOW_TRAIN_FALLBACK", "false"),
    )
    result = run_sft_trl(run_cfg)

    # NOTE(review): json.dumps requires `result` to be JSON-serializable;
    # confirm against run_sft_trl's return value.
    report_dir = root / "outputs" / "reports"
    report_dir.mkdir(parents=True, exist_ok=True)
    (report_dir / "sft_run.json").write_text(
        json.dumps(result, ensure_ascii=True, indent=2),
        encoding="utf-8",
    )
    print("sft_done")
# Run the training job only when this file is executed as a script, so that
# importing the module has no training side effect.
if __name__ == "__main__":
    main()