# Page header captured alongside this file (not part of the module):
#   author: TheJackBright
#   commit message: Deploy PolyGuard OpenEnv Space
#   commit: 877add7 (verified)
"""TRL + Unsloth SFT training utilities."""
from __future__ import annotations
from dataclasses import dataclass
import json
import os
from pathlib import Path
from typing import Any
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from app.training.checkpointing import save_checkpoint
from app.training.lora_utils import build_lora_config
from app.training.lora_utils import build_qlora_config
from app.training.model_registry import register_model_run
from app.training.unsloth_loader import load_unsloth_model
@dataclass(slots=True)
class SFTRunConfig:
model_id: str
output_dir: Path
dataset_path: Path
max_seq_len: int = 1024
epochs: int = 1
learning_rate: float = 2e-5
batch_size: int = 2
max_steps: int = 30
use_unsloth: bool = True
allow_fallback: bool = False
def _to_text_record(example: dict[str, Any]) -> str:
prompt = example.get("prompt", {})
meds = prompt.get("medications", [])
candidates = prompt.get("candidates", prompt.get("candidate_set", []))
target = example.get("target_candidate_id", "cand_01")
return json.dumps(
{
"instruction": "Select the safest legal medication action candidate_id.",
"medications": meds,
"candidates": candidates,
"answer": target,
},
ensure_ascii=True,
)
def _load_examples(path: Path) -> list[dict[str, Any]]:
if not path.exists():
return []
payload = json.loads(path.read_text(encoding="utf-8"))
if isinstance(payload, list):
return [item for item in payload if isinstance(item, dict)]
return []
def _fallback_train(config: SFTRunConfig, examples: list[dict[str, Any]]) -> dict[str, Any]:
    """Fit a scikit-learn random forest as a stand-in for the TRL run.

    Each example is reduced to four numeric features (medication count,
    candidate count, ``uncertainty`` and ``severe_pair_count`` from the
    prompt) and the target candidate id is used as the class label.

    Args:
        config: Run settings; ``output_dir`` and ``model_id`` are used.
        examples: Training examples; an empty list yields a ``no_data`` result.

    Returns:
        A summary dict with ``status``, ``backend``, ``examples_used`` and
        ``model_id``; successful runs also report ``train_accuracy`` and
        ``artifact_path``. The same dict is passed to ``save_checkpoint``.
    """
    # Make the function usable on its own, not only after the caller's mkdir.
    config.output_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_path = config.output_dir / "sft_checkpoint.json"

    if not examples:
        out = {
            "status": "no_data",
            "backend": "fallback_sklearn",
            "examples_used": 0,
            "model_id": config.model_id,
        }
        save_checkpoint(checkpoint_path, out)
        return out

    def _features(example: dict[str, Any]) -> list[float]:
        """Map one example to its four-element numeric feature vector."""
        prompt = example.get("prompt", {})
        meds = prompt.get("medications", [])
        candidates = prompt.get("candidates", prompt.get("candidate_set", []))
        uncertainty = float(prompt.get("uncertainty", 0.5))
        severe_pairs = float(prompt.get("severe_pair_count", 0.0))
        return [float(len(meds)), float(len(candidates)), uncertainty, severe_pairs]

    x = np.array([_features(example) for example in examples], dtype=float)

    # Encode labels through a sorted vocabulary. The built-in hash() of a str
    # is salted per process, so hashing labels made runs non-reproducible and
    # could merge distinct candidate ids into one class.
    targets = [str(example.get("target_candidate_id", "cand_00")) for example in examples]
    label_index = {label: idx for idx, label in enumerate(sorted(set(targets)))}
    y = np.array([label_index[target] for target in targets], dtype=int)

    model = RandomForestClassifier(n_estimators=120, random_state=42)
    model.fit(x, y)
    # Accuracy is measured on the training data itself; it is not a held-out score.
    acc = float((model.predict(x) == y).mean())

    artifact = config.output_dir / "sft_policy_fallback.json"
    artifact.write_text(
        json.dumps({"train_accuracy": round(acc, 4)}, ensure_ascii=True, indent=2),
        encoding="utf-8",
    )
    out = {
        "status": "ok",
        "backend": "fallback_sklearn",
        "examples_used": len(examples),
        "train_accuracy": round(acc, 4),
        "artifact_path": str(artifact),
        "model_id": config.model_id,
    }
    save_checkpoint(checkpoint_path, out)
    return out
def _try_load_unsloth(config: SFTRunConfig) -> tuple[Any, Any, str | None]:
    """Try to load a 4-bit Unsloth model with LoRA adapters attached.

    Returns ``(model, tokenizer, None)`` on success, or ``(None, None, error)``
    where ``error`` describes the failure, so the caller can fall back to plain
    transformers without losing the reason.
    """
    try:
        from unsloth import FastLanguageModel  # type: ignore

        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=config.model_id,
            max_seq_length=config.max_seq_len,
            dtype=None,
            load_in_4bit=True,
        )
        qlora = build_qlora_config(rank=16, alpha=32, dropout=0.05)
        model = FastLanguageModel.get_peft_model(
            model,
            r=int(qlora["r"]),
            target_modules=["q_proj", "v_proj"],
            lora_alpha=int(qlora["lora_alpha"]),
            lora_dropout=float(qlora["lora_dropout"]),
            bias="none",
            use_gradient_checkpointing="unsloth",
        )
    except Exception as exc:  # noqa: BLE001 - Unsloth is optional, keep best-effort
        return None, None, f"{type(exc).__name__}: {exc}"
    return model, tokenizer, None


def _wandb_report_targets() -> list[str]:
    """Return trainer ``report_to`` targets: wandb only if configured and importable."""
    if not os.getenv("WANDB_API_KEY"):
        return []
    try:
        import wandb  # noqa: F401
    except Exception:  # noqa: BLE001 - reporting is best-effort
        return []
    return ["wandb"]


def _write_generation_stubs(config: SFTRunConfig, examples: list[dict[str, Any]], backend: str) -> None:
    """Write ``sft_generations.json`` with placeholder entries for the first five examples."""
    generations = [
        {
            "prompt": _to_text_record(item)[:240],
            "generation": "<stored_with_training_artifacts>",
            "backend": backend,
        }
        for item in examples[:5]
    ]
    (config.output_dir / "sft_generations.json").write_text(
        json.dumps(generations, ensure_ascii=True, indent=2), encoding="utf-8"
    )


def run_sft_trl(config: SFTRunConfig) -> dict[str, Any]:
    """Run supervised fine-tuning with TRL, optionally loading the model via Unsloth.

    The model is loaded through Unsloth when ``config.use_unsloth`` is set and
    that load succeeds; otherwise it is loaded with plain transformers and a
    LoRA config is handed to the trainer. The adapter, tokenizer, a checkpoint
    summary and a model-registry entry are written under ``config.output_dir``.

    Args:
        config: Run settings.

    Returns:
        A summary dict. ``status`` is ``"no_data"`` when the dataset yields no
        examples and ``"ok"`` otherwise. When the Unsloth load was attempted
        and failed, the reason is reported under ``unsloth_error``. Fallback
        results additionally carry ``trl_error`` or ``trl_runtime_error``.

    Raises:
        RuntimeError: If the TRL dependencies cannot be imported or training
            fails while ``config.allow_fallback`` is false.
    """
    config.output_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_path = config.output_dir / "sft_checkpoint.json"

    examples = _load_examples(config.dataset_path)
    if not examples:
        result = {
            "status": "no_data",
            "backend": "trl_unsloth",
            "examples_used": 0,
            "model_id": config.model_id,
        }
        save_checkpoint(checkpoint_path, result)
        return result

    unsloth_probe = load_unsloth_model(config.model_id) if config.use_unsloth else {"available": False}

    # Heavy dependencies are imported lazily so the module stays importable
    # on machines without the training stack.
    try:
        from datasets import Dataset
        from peft import LoraConfig
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer
        from trl import SFTConfig, SFTTrainer
    except Exception as exc:  # noqa: BLE001
        if not config.allow_fallback:
            raise RuntimeError(
                "TRL SFTTrainer import failed. Training is configured to require Hugging Face TRL. "
                f"Install TRL dependencies or rerun with allow_fallback=True. Details: {exc}"
            ) from exc
        result = _fallback_train(config=config, examples=examples)
        result["trl_error"] = str(exc)
        return result

    dataset = Dataset.from_dict({"text": [_to_text_record(item) for item in examples]})

    # Kept outside the try so the fallback branch can report it as well.
    unsloth_error: str | None = None
    try:
        model = None
        tokenizer = None
        backend = "trl_transformers"
        if config.use_unsloth:
            model, tokenizer, unsloth_error = _try_load_unsloth(config)
            if model is not None and tokenizer is not None:
                backend = "trl_unsloth"

        if model is None or tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(config.model_id)
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            model = AutoModelForCausalLM.from_pretrained(
                config.model_id,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                low_cpu_mem_usage=True,
            )

        cuda_available = torch.cuda.is_available()
        lora_cfg = LoraConfig(**build_lora_config(rank=16, alpha=32, dropout=0.05))
        args = SFTConfig(
            output_dir=str(config.output_dir / "sft_artifacts"),
            per_device_train_batch_size=config.batch_size,
            gradient_accumulation_steps=1,
            learning_rate=config.learning_rate,
            num_train_epochs=float(config.epochs),
            max_steps=config.max_steps,
            logging_steps=1,
            save_steps=max(1, config.max_steps),
            report_to=_wandb_report_targets(),
            remove_unused_columns=False,
            dataset_text_field="text",
            max_length=config.max_seq_len,
            # NOTE(review): fp16 is requested whenever CUDA is present; confirm
            # this matches the dtype Unsloth selects with dtype=None on
            # bf16-capable GPUs.
            fp16=cuda_available,
            use_cpu=not cuda_available,
        )
        trainer = SFTTrainer(
            model=model,
            args=args,
            train_dataset=dataset,
            processing_class=tokenizer,
            # The Unsloth model already carries its adapters.
            peft_config=None if backend == "trl_unsloth" else lora_cfg,
        )
        train_output = trainer.train()

        adapter_dir = str(config.output_dir / "sft_adapter")
        trainer.save_model(adapter_dir)
        tokenizer.save_pretrained(adapter_dir)
        _write_generation_stubs(config, examples, backend)

        metrics = getattr(train_output, "metrics", {})
        result = {
            "status": "ok",
            "backend": backend,
            "examples_used": len(examples),
            "model_id": config.model_id,
            "unsloth_available": bool(unsloth_probe.get("available", False)),
            "train_runtime": float(metrics.get("train_runtime", 0.0)),
            "train_loss": float(metrics.get("train_loss", 0.0)),
            "artifact_path": adapter_dir,
        }
        if unsloth_error is not None:
            # Surface why the run used plain transformers instead of Unsloth.
            result["unsloth_error"] = unsloth_error
        save_checkpoint(checkpoint_path, result)
        register_model_run(
            config.output_dir / "model_registry.json",
            {
                "stage": "sft",
                "model_id": config.model_id,
                "backend": backend,
                "artifact_path": adapter_dir,
                "examples_used": len(examples),
            },
        )
        return result
    except Exception as exc:  # noqa: BLE001
        if not config.allow_fallback:
            raise RuntimeError(
                "TRL SFTTrainer runtime failed. Training is configured to require Hugging Face TRL. "
                f"Fix the TRL runtime issue or rerun with allow_fallback=True. Details: {exc}"
            ) from exc
        result = _fallback_train(config=config, examples=examples)
        result["trl_runtime_error"] = str(exc)
        if unsloth_error is not None:
            result["unsloth_error"] = unsloth_error
        return result