| |
| """Safely merge LoRA adapters into a base model artifact.""" |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| from pathlib import Path |
| from typing import Any |
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Build and parse the command-line options for the adapter-merge script."""
    cli = argparse.ArgumentParser(description="Safely merge PEFT adapter into base model.")
    cli.add_argument("--adapter-dir", default="checkpoints/sft_adapter")
    cli.add_argument("--base-model", default="")
    cli.add_argument("--output-dir", default="checkpoints/merged")
    cli.add_argument(
        "--merge-dtype",
        choices=["float16", "bfloat16", "float32"],
        default="float16",
    )
    cli.add_argument("--device-map", default="auto")
    cli.add_argument("--load-in-4bit", action="store_true")
    cli.add_argument("--allow-unsafe-merge", action="store_true")
    return cli.parse_args()
|
|
|
|
| def _resolve_dtype(name: str): |
| import torch |
|
|
| return { |
| "float16": torch.float16, |
| "bfloat16": torch.bfloat16, |
| "float32": torch.float32, |
| }[name] |
|
|
|
|
| def _discover_base_model(adapter_dir: Path) -> str: |
| cfg_path = adapter_dir / "adapter_config.json" |
| if not cfg_path.exists(): |
| return "" |
| try: |
| payload = json.loads(cfg_path.read_text(encoding="utf-8")) |
| except json.JSONDecodeError: |
| return "" |
| base = payload.get("base_model_name_or_path") |
| return str(base) if isinstance(base, str) else "" |
|
|
|
|
| def _write_report(path: Path, payload: dict[str, Any]) -> None: |
| path.parent.mkdir(parents=True, exist_ok=True) |
| path.write_text(json.dumps(payload, ensure_ascii=True, indent=2), encoding="utf-8") |
|
|
|
|
def main() -> None:
    """Merge a LoRA/PEFT adapter into its base model and save the merged artifact.

    Workflow: resolve paths relative to the repo root (the parent of this
    script's directory), determine the base model identity, guard against a
    lossy 4-bit merge, load base model + adapter, merge, then save the model,
    tokenizer, and a JSON merge report into the output directory.

    Raises:
        SystemExit: with a machine-readable code string when the adapter
            directory is missing, the base model cannot be determined, or a
            4-bit merge is attempted without the explicit override flag.
    """
    args = parse_args()
    # Repo root is assumed to be one level above the directory holding this
    # script; CLI paths are resolved relative to it.
    root = Path(__file__).resolve().parents[1]
    adapter_dir = (root / args.adapter_dir).resolve()
    output_dir = (root / args.output_dir).resolve()
    report_path = output_dir / "merge_report.json"

    if not adapter_dir.exists():
        raise SystemExit(f"adapter_dir_not_found:{adapter_dir}")

    # Explicit CLI value wins; otherwise fall back to the base-model id that
    # PEFT recorded in adapter_config.json at training time.
    base_model = args.base_model.strip() or _discover_base_model(adapter_dir)
    if not base_model:
        raise SystemExit(
            "base_model_not_found: pass --base-model or ensure adapter_config.json has base_model_name_or_path"
        )

    # Merging adapter weights directly into a 4-bit-quantized base degrades
    # numerical fidelity, so it is blocked unless the user explicitly opts in.
    if args.load_in_4bit and not args.allow_unsafe_merge:
        raise SystemExit(
            "unsafe_merge_blocked: refusing naive 4bit merge. Re-run without --load-in-4bit "
            "or pass --allow-unsafe-merge if you accept degraded fidelity risk."
        )

    # Heavy third-party imports are deferred until after the cheap
    # argument/path validation above has passed.
    # NOTE(review): torch is not referenced directly below — presumably
    # imported to fail fast before model loading; confirm.
    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_kwargs: dict[str, Any] = {
        "device_map": args.device_map,
        "low_cpu_mem_usage": True,
        "torch_dtype": _resolve_dtype(args.merge_dtype),
    }

    if args.load_in_4bit:
        from transformers import BitsAndBytesConfig

        model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)
        # The explicit torch_dtype is dropped on the 4-bit path so the
        # quantization config governs compute dtype instead.
        model_kwargs.pop("torch_dtype", None)

    base_model_obj = AutoModelForCausalLM.from_pretrained(base_model, **model_kwargs)
    tokenizer = AutoTokenizer.from_pretrained(base_model)

    # Attach the adapter, then fold its weights into the base layers and
    # discard the PEFT wrappers.
    peft_model = PeftModel.from_pretrained(base_model_obj, str(adapter_dir))
    merged = peft_model.merge_and_unload(progressbar=False)

    output_dir.mkdir(parents=True, exist_ok=True)
    merged.save_pretrained(str(output_dir), safe_serialization=True)
    tokenizer.save_pretrained(str(output_dir))

    # Record provenance and settings alongside the artifact for auditability.
    param_count = sum(param.numel() for param in merged.parameters())
    payload = {
        "status": "ok",
        "adapter_dir": str(adapter_dir),
        "output_dir": str(output_dir),
        "base_model": base_model,
        "merge_dtype": args.merge_dtype,
        "load_in_4bit": bool(args.load_in_4bit),
        "unsafe_override": bool(args.allow_unsafe_merge),
        "parameters": int(param_count),
        "precision_warning": (
            "4bit merge override enabled; validate numerics before deployment."
            if args.load_in_4bit
            else "none"
        ),
    }
    _write_report(report_path, payload)
    print("merge_done")
|
|
|
|
# Script entry point: run the merge when executed directly (not on import).
if __name__ == "__main__":
    main()
|
|