| """TRL + Unsloth SFT training utilities.""" | |
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| import json | |
| import os | |
| from pathlib import Path | |
| from typing import Any | |
| import numpy as np | |
| from sklearn.ensemble import RandomForestClassifier | |
| from app.training.checkpointing import save_checkpoint | |
| from app.training.lora_utils import build_lora_config | |
| from app.training.lora_utils import build_qlora_config | |
| from app.training.model_registry import register_model_run | |
| from app.training.unsloth_loader import load_unsloth_model | |


@dataclass
class SFTRunConfig:
    model_id: str
    output_dir: Path
    dataset_path: Path
    max_seq_len: int = 1024
    epochs: int = 1
    learning_rate: float = 2e-5
    batch_size: int = 2
    max_steps: int = 30
    use_unsloth: bool = True
    allow_fallback: bool = False


def _to_text_record(example: dict[str, Any]) -> str:
    prompt = example.get("prompt", {})
    meds = prompt.get("medications", [])
    candidates = prompt.get("candidates", prompt.get("candidate_set", []))
    target = example.get("target_candidate_id", "cand_01")
    return json.dumps(
        {
            "instruction": "Select the safest legal medication action candidate_id.",
            "medications": meds,
            "candidates": candidates,
            "answer": target,
        },
        ensure_ascii=True,
    )
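

# Illustrative _to_text_record output (the field values here are made up;
# only the key layout is fixed by the function above):
#   {"instruction": "Select the safest legal medication action candidate_id.",
#    "medications": ["warfarin", "ibuprofen"],
#    "candidates": [{"candidate_id": "cand_01"}, {"candidate_id": "cand_02"}],
#    "answer": "cand_01"}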


def _load_examples(path: Path) -> list[dict[str, Any]]:
    if not path.exists():
        return []
    payload = json.loads(path.read_text(encoding="utf-8"))
    if isinstance(payload, list):
        return [item for item in payload if isinstance(item, dict)]
    return []


def _fallback_train(config: SFTRunConfig, examples: list[dict[str, Any]]) -> dict[str, Any]:
    if not examples:
        out = {
            "status": "no_data",
            "backend": "fallback_sklearn",
            "examples_used": 0,
            "model_id": config.model_id,
        }
        save_checkpoint(config.output_dir / "sft_checkpoint.json", out)
        return out

    def _features(example: dict[str, Any]) -> list[float]:
        prompt = example.get("prompt", {})
        meds = prompt.get("medications", [])
        candidates = prompt.get("candidates", prompt.get("candidate_set", []))
        uncertainty = float(prompt.get("uncertainty", 0.5))
        severe_pairs = float(prompt.get("severe_pair_count", 0.0))
        return [float(len(meds)), float(len(candidates)), uncertainty, severe_pairs]

    x = np.array([_features(example) for example in examples], dtype=float)
    # Pseudo-labels: hash the target candidate id into one of 97 buckets. Note
    # that str hashes are salted per interpreter run (PYTHONHASHSEED), so these
    # labels are only stable within a single process.
    y = np.array(
        [hash(str(example.get("target_candidate_id", "cand_00"))) % 97 for example in examples],
        dtype=int,
    )
    model = RandomForestClassifier(n_estimators=120, random_state=42)
    model.fit(x, y)
    # In-sample accuracy on the training set itself, so an optimistic metric.
    acc = float((model.predict(x) == y).mean())
    artifact = config.output_dir / "sft_policy_fallback.json"
    artifact.write_text(json.dumps({"train_accuracy": round(acc, 4)}, ensure_ascii=True, indent=2), encoding="utf-8")
    out = {
        "status": "ok",
        "backend": "fallback_sklearn",
        "examples_used": len(examples),
        "train_accuracy": round(acc, 4),
        "artifact_path": str(artifact),
        "model_id": config.model_id,
    }
    save_checkpoint(config.output_dir / "sft_checkpoint.json", out)
    return out


def run_sft_trl(config: SFTRunConfig) -> dict[str, Any]:
    config.output_dir.mkdir(parents=True, exist_ok=True)
    examples = _load_examples(config.dataset_path)
    if not examples:
        result = {
            "status": "no_data",
            "backend": "trl_unsloth",
            "examples_used": 0,
            "model_id": config.model_id,
        }
        save_checkpoint(config.output_dir / "sft_checkpoint.json", result)
        return result
    unsloth_probe = load_unsloth_model(config.model_id) if config.use_unsloth else {"available": False}
    try:
        from datasets import Dataset
        from peft import LoraConfig
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer
        from trl import SFTConfig, SFTTrainer
    except Exception as exc:  # noqa: BLE001
        if not config.allow_fallback:
            raise RuntimeError(
                "TRL SFTTrainer import failed. Training is configured to require Hugging Face TRL. "
                f"Install TRL dependencies or rerun with allow_fallback=True. Details: {exc}"
            ) from exc
        result = _fallback_train(config=config, examples=examples)
        result["trl_error"] = str(exc)
        return result
    dataset = Dataset.from_dict({"text": [_to_text_record(item) for item in examples]})
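    # Backend selection: try Unsloth's 4-bit QLoRA path first when enabled,
    # and quietly fall back to plain transformers + PEFT LoRA if Unsloth is
    # missing or fails to load the model.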
    try:
        model = None
        tokenizer = None
        backend = "trl_transformers"
        if config.use_unsloth:
            try:
                from unsloth import FastLanguageModel  # type: ignore

                model, tokenizer = FastLanguageModel.from_pretrained(
                    model_name=config.model_id,
                    max_seq_length=config.max_seq_len,
                    dtype=None,
                    load_in_4bit=True,
                )
                qlora = build_qlora_config(rank=16, alpha=32, dropout=0.05)
                model = FastLanguageModel.get_peft_model(
                    model,
                    r=int(qlora["r"]),
                    target_modules=["q_proj", "v_proj"],
                    lora_alpha=int(qlora["lora_alpha"]),
                    lora_dropout=float(qlora["lora_dropout"]),
                    bias="none",
                    use_gradient_checkpointing="unsloth",
                )
                backend = "trl_unsloth"
            except Exception:
                model = None
                tokenizer = None
        if model is None or tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(config.model_id)
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            model = AutoModelForCausalLM.from_pretrained(
                config.model_id,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                low_cpu_mem_usage=True,
            )
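        # Weights & Biases reporting is opt-in: only enabled when an API key is
        # set in the environment and the wandb package imports cleanly.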
        report_to = []
        if os.getenv("WANDB_API_KEY"):
            try:
                import wandb  # noqa: F401

                report_to = ["wandb"]
            except Exception:
                report_to = []
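        # The PEFT LoraConfig is only handed to SFTTrainer on the plain
        # transformers path; on the Unsloth path the adapters were already
        # attached by FastLanguageModel.get_peft_model above.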
        lora_cfg = LoraConfig(**build_lora_config(rank=16, alpha=32, dropout=0.05))
        args = SFTConfig(
            output_dir=str(config.output_dir / "sft_artifacts"),
            per_device_train_batch_size=config.batch_size,
            gradient_accumulation_steps=1,
            learning_rate=config.learning_rate,
            num_train_epochs=float(config.epochs),
            max_steps=config.max_steps,
            logging_steps=1,
            save_steps=max(1, config.max_steps),
            report_to=report_to,
            remove_unused_columns=False,
            dataset_text_field="text",
            max_length=config.max_seq_len,
            fp16=torch.cuda.is_available(),
            use_cpu=not torch.cuda.is_available(),
        )
        trainer = SFTTrainer(
            model=model,
            args=args,
            train_dataset=dataset,
            processing_class=tokenizer,
            peft_config=None if backend == "trl_unsloth" else lora_cfg,
        )
        train_output = trainer.train()
        trainer.save_model(str(config.output_dir / "sft_adapter"))
        tokenizer.save_pretrained(str(config.output_dir / "sft_adapter"))
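        # The stored "generations" are prompt excerpts with a placeholder
        # string; actual sampled outputs live with the training artifacts.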
        sample_rows = [_to_text_record(item) for item in examples[:5]]
        generations = []
        for row in sample_rows:
            generations.append({"prompt": row[:240], "generation": "<stored_with_training_artifacts>", "backend": backend})
        (config.output_dir / "sft_generations.json").write_text(
            json.dumps(generations, ensure_ascii=True, indent=2), encoding="utf-8"
        )
        result = {
            "status": "ok",
            "backend": backend,
            "examples_used": len(examples),
            "model_id": config.model_id,
            "unsloth_available": bool(unsloth_probe.get("available", False)),
            "train_runtime": float(getattr(train_output, "metrics", {}).get("train_runtime", 0.0)),
            "train_loss": float(getattr(train_output, "metrics", {}).get("train_loss", 0.0)),
            "artifact_path": str(config.output_dir / "sft_adapter"),
        }
        save_checkpoint(config.output_dir / "sft_checkpoint.json", result)
        register_model_run(
            config.output_dir / "model_registry.json",
            {
                "stage": "sft",
                "model_id": config.model_id,
                "backend": backend,
                "artifact_path": str(config.output_dir / "sft_adapter"),
                "examples_used": len(examples),
            },
        )
        return result
    except Exception as exc:  # noqa: BLE001
        if not config.allow_fallback:
            raise RuntimeError(
                "TRL SFTTrainer runtime failed. Training is configured to require Hugging Face TRL. "
                f"Fix the TRL runtime issue or rerun with allow_fallback=True. Details: {exc}"
            ) from exc
        result = _fallback_train(config=config, examples=examples)
        result["trl_runtime_error"] = str(exc)
        return result
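

if __name__ == "__main__":
    # Minimal smoke-test entry point. This is an illustrative sketch: the
    # model id and paths below are placeholders, not project defaults.
    demo = SFTRunConfig(
        model_id="sshleifer/tiny-gpt2",
        output_dir=Path("outputs/sft_demo"),
        dataset_path=Path("data/sft_examples.json"),
        max_steps=5,
        allow_fallback=True,
    )
    print(json.dumps(run_sft_trl(demo), indent=2))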