"""TRL + Unsloth SFT training utilities.""" from __future__ import annotations from dataclasses import dataclass import json import os from pathlib import Path from typing import Any import numpy as np from sklearn.ensemble import RandomForestClassifier from app.training.checkpointing import save_checkpoint from app.training.lora_utils import build_lora_config from app.training.lora_utils import build_qlora_config from app.training.model_registry import register_model_run from app.training.unsloth_loader import load_unsloth_model @dataclass(slots=True) class SFTRunConfig: model_id: str output_dir: Path dataset_path: Path max_seq_len: int = 1024 epochs: int = 1 learning_rate: float = 2e-5 batch_size: int = 2 max_steps: int = 30 use_unsloth: bool = True allow_fallback: bool = False def _to_text_record(example: dict[str, Any]) -> str: prompt = example.get("prompt", {}) meds = prompt.get("medications", []) candidates = prompt.get("candidates", prompt.get("candidate_set", [])) target = example.get("target_candidate_id", "cand_01") return json.dumps( { "instruction": "Select the safest legal medication action candidate_id.", "medications": meds, "candidates": candidates, "answer": target, }, ensure_ascii=True, ) def _load_examples(path: Path) -> list[dict[str, Any]]: if not path.exists(): return [] payload = json.loads(path.read_text(encoding="utf-8")) if isinstance(payload, list): return [item for item in payload if isinstance(item, dict)] return [] def _fallback_train(config: SFTRunConfig, examples: list[dict[str, Any]]) -> dict[str, Any]: if not examples: out = { "status": "no_data", "backend": "fallback_sklearn", "examples_used": 0, "model_id": config.model_id, } save_checkpoint(config.output_dir / "sft_checkpoint.json", out) return out def _features(example: dict[str, Any]) -> list[float]: prompt = example.get("prompt", {}) meds = prompt.get("medications", []) candidates = prompt.get("candidates", prompt.get("candidate_set", [])) uncertainty = float(prompt.get("uncertainty", 0.5)) severe_pairs = float(prompt.get("severe_pair_count", 0.0)) return [float(len(meds)), float(len(candidates)), uncertainty, severe_pairs] x = np.array([_features(example) for example in examples], dtype=float) y = np.array([hash(str(example.get("target_candidate_id", "cand_00"))) % 97 for example in examples], dtype=int) model = RandomForestClassifier(n_estimators=120, random_state=42) model.fit(x, y) acc = float((model.predict(x) == y).mean()) artifact = config.output_dir / "sft_policy_fallback.json" artifact.write_text(json.dumps({"train_accuracy": round(acc, 4)}, ensure_ascii=True, indent=2), encoding="utf-8") out = { "status": "ok", "backend": "fallback_sklearn", "examples_used": len(examples), "train_accuracy": round(acc, 4), "artifact_path": str(artifact), "model_id": config.model_id, } save_checkpoint(config.output_dir / "sft_checkpoint.json", out) return out def run_sft_trl(config: SFTRunConfig) -> dict[str, Any]: config.output_dir.mkdir(parents=True, exist_ok=True) examples = _load_examples(config.dataset_path) if not examples: result = { "status": "no_data", "backend": "trl_unsloth", "examples_used": 0, "model_id": config.model_id, } save_checkpoint(config.output_dir / "sft_checkpoint.json", result) return result unsloth_probe = load_unsloth_model(config.model_id) if config.use_unsloth else {"available": False} try: from datasets import Dataset from peft import LoraConfig import torch from transformers import AutoModelForCausalLM, AutoTokenizer from trl import SFTConfig, SFTTrainer except Exception as exc: # noqa: BLE001 if not config.allow_fallback: raise RuntimeError( "TRL SFTTrainer import failed. Training is configured to require Hugging Face TRL. " f"Install TRL dependencies or rerun with allow_fallback=True. Details: {exc}" ) from exc result = _fallback_train(config=config, examples=examples) result["trl_error"] = str(exc) return result dataset = Dataset.from_dict({"text": [_to_text_record(item) for item in examples]}) try: model = None tokenizer = None backend = "trl_transformers" if config.use_unsloth: try: from unsloth import FastLanguageModel # type: ignore model, tokenizer = FastLanguageModel.from_pretrained( model_name=config.model_id, max_seq_length=config.max_seq_len, dtype=None, load_in_4bit=True, ) qlora = build_qlora_config(rank=16, alpha=32, dropout=0.05) model = FastLanguageModel.get_peft_model( model, r=int(qlora["r"]), target_modules=["q_proj", "v_proj"], lora_alpha=int(qlora["lora_alpha"]), lora_dropout=float(qlora["lora_dropout"]), bias="none", use_gradient_checkpointing="unsloth", ) backend = "trl_unsloth" except Exception: model = None tokenizer = None if model is None or tokenizer is None: tokenizer = AutoTokenizer.from_pretrained(config.model_id) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token model = AutoModelForCausalLM.from_pretrained( config.model_id, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, low_cpu_mem_usage=True, ) report_to = [] if os.getenv("WANDB_API_KEY"): try: import wandb # noqa: F401 report_to = ["wandb"] except Exception: report_to = [] lora_cfg = LoraConfig(**build_lora_config(rank=16, alpha=32, dropout=0.05)) args = SFTConfig( output_dir=str(config.output_dir / "sft_artifacts"), per_device_train_batch_size=config.batch_size, gradient_accumulation_steps=1, learning_rate=config.learning_rate, num_train_epochs=float(config.epochs), max_steps=config.max_steps, logging_steps=1, save_steps=max(1, config.max_steps), report_to=report_to, remove_unused_columns=False, dataset_text_field="text", max_length=config.max_seq_len, fp16=torch.cuda.is_available(), use_cpu=not torch.cuda.is_available(), ) trainer = SFTTrainer( model=model, args=args, train_dataset=dataset, processing_class=tokenizer, peft_config=None if backend == "trl_unsloth" else lora_cfg, ) train_output = trainer.train() trainer.save_model(str(config.output_dir / "sft_adapter")) tokenizer.save_pretrained(str(config.output_dir / "sft_adapter")) sample_rows = [_to_text_record(item) for item in examples[:5]] generations = [] for row in sample_rows: generations.append({"prompt": row[:240], "generation": "", "backend": backend}) (config.output_dir / "sft_generations.json").write_text( json.dumps(generations, ensure_ascii=True, indent=2), encoding="utf-8" ) result = { "status": "ok", "backend": backend, "examples_used": len(examples), "model_id": config.model_id, "unsloth_available": bool(unsloth_probe.get("available", False)), "train_runtime": float(getattr(train_output, "metrics", {}).get("train_runtime", 0.0)), "train_loss": float(getattr(train_output, "metrics", {}).get("train_loss", 0.0)), "artifact_path": str(config.output_dir / "sft_adapter"), } save_checkpoint(config.output_dir / "sft_checkpoint.json", result) register_model_run( config.output_dir / "model_registry.json", { "stage": "sft", "model_id": config.model_id, "backend": backend, "artifact_path": str(config.output_dir / "sft_adapter"), "examples_used": len(examples), }, ) return result except Exception as exc: # noqa: BLE001 if not config.allow_fallback: raise RuntimeError( "TRL SFTTrainer runtime failed. Training is configured to require Hugging Face TRL. " f"Fix the TRL runtime issue or rerun with allow_fallback=True. Details: {exc}" ) from exc result = _fallback_train(config=config, examples=examples) result["trl_runtime_error"] = str(exc) return result