#!/usr/bin/env python3
"""Safely merge LoRA adapters into a base model artifact."""

from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Any


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Safely merge PEFT adapter into base model.")
    parser.add_argument("--adapter-dir", default="checkpoints/sft_adapter")
    parser.add_argument("--base-model", default="")
    parser.add_argument("--output-dir", default="checkpoints/merged")
    parser.add_argument("--merge-dtype", choices=["float16", "bfloat16", "float32"], default="float16")
    parser.add_argument("--device-map", default="auto")
    parser.add_argument("--load-in-4bit", action="store_true")
    parser.add_argument("--allow-unsafe-merge", action="store_true")
    return parser.parse_args()


def _resolve_dtype(name: str):
    import torch

    return {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }[name]


def _discover_base_model(adapter_dir: Path) -> str:
    """Read base_model_name_or_path from the adapter's adapter_config.json, if present."""
    cfg_path = adapter_dir / "adapter_config.json"
    if not cfg_path.exists():
        return ""
    try:
        payload = json.loads(cfg_path.read_text(encoding="utf-8"))
    except json.JSONDecodeError:
        return ""
    base = payload.get("base_model_name_or_path")
    return str(base) if isinstance(base, str) else ""


def _write_report(path: Path, payload: dict[str, Any]) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(payload, ensure_ascii=True, indent=2), encoding="utf-8")


def main() -> None:
    args = parse_args()
    root = Path(__file__).resolve().parents[1]
    adapter_dir = (root / args.adapter_dir).resolve()
    output_dir = (root / args.output_dir).resolve()
    report_path = output_dir / "merge_report.json"

    if not adapter_dir.exists():
        raise SystemExit(f"adapter_dir_not_found:{adapter_dir}")

    # Prefer an explicit --base-model; otherwise fall back to the adapter config.
    base_model = args.base_model.strip() or _discover_base_model(adapter_dir)
    if not base_model:
        raise SystemExit(
            "base_model_not_found: pass --base-model or ensure adapter_config.json has base_model_name_or_path"
        )

    # Merging into a 4-bit quantized base degrades weight fidelity; require an explicit override.
    if args.load_in_4bit and not args.allow_unsafe_merge:
        raise SystemExit(
            "unsafe_merge_blocked: refusing naive 4bit merge. Re-run without --load-in-4bit "
            "or pass --allow-unsafe-merge if you accept degraded fidelity risk."
        )

    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_kwargs: dict[str, Any] = {
        "device_map": args.device_map,
        "low_cpu_mem_usage": True,
        "torch_dtype": _resolve_dtype(args.merge_dtype),
    }
    if args.load_in_4bit:
        from transformers import BitsAndBytesConfig

        # Quantized load: bitsandbytes controls compute dtype, so drop the explicit torch_dtype.
        model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)
        model_kwargs.pop("torch_dtype", None)

    base_model_obj = AutoModelForCausalLM.from_pretrained(base_model, **model_kwargs)
    tokenizer = AutoTokenizer.from_pretrained(base_model)

    # Apply the adapter, fold its weights into the base model, and drop the PEFT wrapper.
    peft_model = PeftModel.from_pretrained(base_model_obj, str(adapter_dir))
    merged = peft_model.merge_and_unload(progressbar=False)

    output_dir.mkdir(parents=True, exist_ok=True)
    merged.save_pretrained(str(output_dir), safe_serialization=True)
    tokenizer.save_pretrained(str(output_dir))

    param_count = sum(param.numel() for param in merged.parameters())
    payload = {
        "status": "ok",
        "adapter_dir": str(adapter_dir),
        "output_dir": str(output_dir),
        "base_model": base_model,
        "merge_dtype": args.merge_dtype,
        "load_in_4bit": bool(args.load_in_4bit),
        "unsafe_override": bool(args.allow_unsafe_merge),
        "parameters": int(param_count),
        "precision_warning": (
            "4bit merge override enabled; validate numerics before deployment."
            if args.load_in_4bit
            else "none"
        ),
    }
    _write_report(report_path, payload)
    print("merge_done")


if __name__ == "__main__":
    main()
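
# Example invocation (a sketch only; the script path and model id below are
# assumptions for illustration, not values defined in this repository):
#
#   python scripts/merge_adapter.py \
#       --adapter-dir checkpoints/sft_adapter \
#       --base-model <hf-model-id-or-local-path> \
#       --merge-dtype bfloat16
#
# On success the merged weights, tokenizer files, and merge_report.json land in
# --output-dir (checkpoints/merged by default).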