#!/usr/bin/env python3 """Explicit TRL SFT entrypoint for small/scale profiles.""" from __future__ import annotations import argparse import json from pathlib import Path import sys ROOT = Path(__file__).resolve().parents[1] if str(ROOT) not in sys.path: sys.path.insert(0, str(ROOT)) from app.training.sft_trl import SFTRunConfig, run_sft_trl def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Train SFT adapter with TRL + Unsloth.") parser.add_argument("--model-id", default="Qwen/Qwen2.5-1.5B-Instruct") parser.add_argument("--dataset-path", default="data/processed/sft_examples.json") parser.add_argument("--output-dir", default="checkpoints") parser.add_argument("--report-path", default="outputs/reports/sft_trl_run.json") parser.add_argument("--epochs", type=int, default=1) parser.add_argument("--batch-size", type=int, default=2) parser.add_argument("--max-steps", type=int, default=30) parser.add_argument("--max-seq-len", type=int, default=1024) parser.add_argument("--learning-rate", type=float, default=2e-5) parser.add_argument("--use-unsloth", action="store_true") parser.add_argument("--allow-fallback", action="store_true") return parser.parse_args() def main() -> None: args = parse_args() root = Path(__file__).resolve().parents[1] cfg = SFTRunConfig( model_id=args.model_id, output_dir=root / args.output_dir, dataset_path=root / args.dataset_path, epochs=args.epochs, batch_size=args.batch_size, max_steps=args.max_steps, max_seq_len=args.max_seq_len, learning_rate=args.learning_rate, use_unsloth=args.use_unsloth, allow_fallback=args.allow_fallback, ) result = run_sft_trl(cfg) report_path = root / args.report_path report_path.parent.mkdir(parents=True, exist_ok=True) report_path.write_text(json.dumps(result, ensure_ascii=True, indent=2), encoding="utf-8") print("sft_trl_done") if __name__ == "__main__": main()