Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """Train SFT policy using TRL + Unsloth with fallback.""" | |
| from __future__ import annotations | |
| import json | |
| import os | |
| from pathlib import Path | |
| import sys | |
# Project root is one directory above the folder containing this script.
# It is put at the front of sys.path so the `app` package import below
# resolves when the script is run directly.
ROOT = Path(__file__).resolve().parents[1]
_root_entry = str(ROOT)
if _root_entry not in sys.path:
    sys.path.insert(0, _root_entry)
| from app.training.sft_trl import SFTRunConfig, run_sft_trl | |
def _env_flag(name: str, default: str) -> bool:
    """Return True if env var *name* (or *default* when unset) is 1/true/yes/on.

    The comparison is case-insensitive.
    """
    return os.getenv(name, default).lower() in {"1", "true", "yes", "on"}


def main() -> None:
    """Run one SFT job configured from ``POLYGUARD_*`` env vars.

    Creates ``<project root>/checkpoints`` and ``<project root>/outputs/reports``
    if missing, builds an ``SFTRunConfig`` from environment variables (with the
    defaults shown below), passes it to ``run_sft_trl``, writes the returned
    result as JSON to ``outputs/reports/sft_run.json`` and prints ``sft_done``.

    Raises:
        ValueError: if a numeric environment variable cannot be parsed by
            ``int`` / ``float``.
    """
    project_root = Path(__file__).resolve().parents[1]

    # The checkpoint directory is created before any env parsing can fail.
    checkpoint_dir = project_root / "checkpoints"
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    env = os.getenv
    settings = SFTRunConfig(
        model_id=env("POLYGUARD_SFT_MODEL_ID", "Qwen/Qwen2.5-1.5B-Instruct"),
        output_dir=checkpoint_dir,
        dataset_path=project_root / "data" / "processed" / "sft_examples.json",
        max_seq_len=int(env("POLYGUARD_SFT_MAX_SEQ_LEN", "1024")),
        epochs=int(env("POLYGUARD_SFT_EPOCHS", "1")),
        learning_rate=float(env("POLYGUARD_SFT_LEARNING_RATE", "2e-5")),
        batch_size=int(env("POLYGUARD_SFT_BATCH_SIZE", "2")),
        max_steps=int(env("POLYGUARD_SFT_MAX_STEPS", "30")),
        use_unsloth=_env_flag("POLYGUARD_USE_UNSLOTH", "true"),
        allow_fallback=_env_flag("POLYGUARD_ALLOW_TRAIN_FALLBACK", "false"),
    )
    outcome = run_sft_trl(settings)

    # Persist the run result; it must be JSON-serializable.
    reports = project_root / "outputs" / "reports"
    reports.mkdir(parents=True, exist_ok=True)
    payload = json.dumps(outcome, ensure_ascii=True, indent=2)
    report_file = reports / "sft_run.json"
    report_file.write_text(payload, encoding="utf-8")

    print("sft_done")
# Start training only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()