Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """Safely merge LoRA adapters into a base model artifact.""" | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| from pathlib import Path | |
| from typing import Any | |
| def parse_args() -> argparse.Namespace: | |
| parser = argparse.ArgumentParser(description="Safely merge PEFT adapter into base model.") | |
| parser.add_argument("--adapter-dir", default="checkpoints/sft_adapter") | |
| parser.add_argument("--base-model", default="") | |
| parser.add_argument("--output-dir", default="checkpoints/merged") | |
| parser.add_argument("--merge-dtype", choices=["float16", "bfloat16", "float32"], default="float16") | |
| parser.add_argument("--device-map", default="auto") | |
| parser.add_argument("--load-in-4bit", action="store_true") | |
| parser.add_argument("--allow-unsafe-merge", action="store_true") | |
| return parser.parse_args() | |
def _resolve_dtype(name: str):
    """Translate a ``--merge-dtype`` choice into the matching torch dtype.

    torch is imported inside the function, so importing this module does not
    require torch to be installed.

    Raises:
        KeyError: if *name* is not one of ``float16``, ``bfloat16``, ``float32``.
    """
    import torch

    dtype_by_name = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }
    return dtype_by_name[name]
def _discover_base_model(adapter_dir: Path) -> str:
    """Best-effort lookup of the base model recorded in ``adapter_config.json``.

    Args:
        adapter_dir: Directory expected to contain ``adapter_config.json``.

    Returns:
        The ``base_model_name_or_path`` value with surrounding whitespace
        removed, or ``""`` when the file is missing or unreadable, is not
        valid UTF-8 JSON, is not a JSON object, or does not hold a string
        under that key.
    """
    cfg_path = adapter_dir / "adapter_config.json"
    try:
        payload = json.loads(cfg_path.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # OSError: missing file, directory in its place, permission problems.
        # ValueError: json.JSONDecodeError and UnicodeDecodeError are subclasses.
        return ""
    # Valid JSON that is not an object (list, number, ...) has no .get().
    if not isinstance(payload, dict):
        return ""
    base = payload.get("base_model_name_or_path")
    # Strip so a whitespace-only entry counts as "not found" for the caller.
    return base.strip() if isinstance(base, str) else ""
def _write_report(path: Path, payload: dict[str, Any]) -> None:
    """Write *payload* to *path* as indented, ASCII-escaped JSON.

    Missing parent directories of *path* are created first; an existing file
    at *path* is overwritten.
    """
    report_dir = path.parent
    report_dir.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(payload, indent=2, ensure_ascii=True)
    path.write_text(serialized, encoding="utf-8")
def main() -> None:
    """Merge a PEFT adapter into its base model and save the merged artifact.

    Steps: resolve the adapter and output paths, determine the base model
    (``--base-model`` takes precedence over the value recorded in
    ``adapter_config.json``), refuse 4-bit merges unless explicitly allowed,
    load base model and adapter, merge, save model and tokenizer to the output
    directory, then write ``merge_report.json`` there and print ``merge_done``.

    Raises:
        SystemExit: if the adapter directory is missing, no base model can be
            determined, or ``--load-in-4bit`` is given without
            ``--allow-unsafe-merge``.
    """
    args = parse_args()

    # Relative CLI paths are anchored at the parent of this script's directory;
    # absolute paths pass through unchanged because `/` discards the left side.
    root = Path(__file__).resolve().parents[1]
    adapter_dir = (root / args.adapter_dir).resolve()
    output_dir = (root / args.output_dir).resolve()
    report_path = output_dir / "merge_report.json"

    # is_dir() rather than exists(): a regular file at this path is not a
    # usable adapter directory, so fail here with the script's own message.
    if not adapter_dir.is_dir():
        raise SystemExit(f"adapter_dir_not_found:{adapter_dir}")

    base_model = args.base_model.strip() or _discover_base_model(adapter_dir)
    if not base_model:
        raise SystemExit(
            "base_model_not_found: pass --base-model or ensure adapter_config.json has base_model_name_or_path"
        )

    if args.load_in_4bit and not args.allow_unsafe_merge:
        raise SystemExit(
            "unsafe_merge_blocked: refusing naive 4bit merge. Re-run without --load-in-4bit "
            "or pass --allow-unsafe-merge if you accept degraded fidelity risk."
        )

    # Third-party imports are deferred until the CLI checks above have passed.
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_kwargs: dict[str, Any] = {
        "device_map": args.device_map,
        "low_cpu_mem_usage": True,
    }
    if args.load_in_4bit:
        from transformers import BitsAndBytesConfig

        # In 4-bit mode --merge-dtype is not forwarded to from_pretrained.
        model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)
    else:
        model_kwargs["torch_dtype"] = _resolve_dtype(args.merge_dtype)

    base_model_obj = AutoModelForCausalLM.from_pretrained(base_model, **model_kwargs)
    tokenizer = AutoTokenizer.from_pretrained(base_model)

    peft_model = PeftModel.from_pretrained(base_model_obj, str(adapter_dir))
    merged = peft_model.merge_and_unload(progressbar=False)

    output_dir.mkdir(parents=True, exist_ok=True)
    merged.save_pretrained(str(output_dir), safe_serialization=True)
    tokenizer.save_pretrained(str(output_dir))

    param_count = sum(param.numel() for param in merged.parameters())
    payload = {
        "status": "ok",
        "adapter_dir": str(adapter_dir),
        "output_dir": str(output_dir),
        "base_model": base_model,
        "merge_dtype": args.merge_dtype,
        "load_in_4bit": bool(args.load_in_4bit),
        "unsafe_override": bool(args.allow_unsafe_merge),
        "parameters": int(param_count),
        "precision_warning": (
            "4bit merge override enabled; validate numerics before deployment."
            if args.load_in_4bit
            else "none"
        ),
    }
    # The report is written only after model and tokenizer were saved.
    _write_report(report_path, payload)
    print("merge_done")
# Run the merge only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()