#!/usr/bin/env python3
"""Safely merge LoRA adapters into a base model artifact."""
from __future__ import annotations
import argparse
import json
from pathlib import Path
from typing import Any


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Safely merge PEFT adapter into base model.")
    parser.add_argument("--adapter-dir", default="checkpoints/sft_adapter")
    parser.add_argument("--base-model", default="")
    parser.add_argument("--output-dir", default="checkpoints/merged")
    parser.add_argument("--merge-dtype", choices=["float16", "bfloat16", "float32"], default="float16")
    parser.add_argument("--device-map", default="auto")
    parser.add_argument("--load-in-4bit", action="store_true")
    parser.add_argument("--allow-unsafe-merge", action="store_true")
    return parser.parse_args()


def _resolve_dtype(name: str):
    import torch

    return {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }[name]


def _discover_base_model(adapter_dir: Path) -> str:
    cfg_path = adapter_dir / "adapter_config.json"
    if not cfg_path.exists():
        return ""
    try:
        payload = json.loads(cfg_path.read_text(encoding="utf-8"))
    except json.JSONDecodeError:
        return ""
    base = payload.get("base_model_name_or_path")
    return str(base) if isinstance(base, str) else ""


def _write_report(path: Path, payload: dict[str, Any]) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(payload, ensure_ascii=True, indent=2), encoding="utf-8")


def main() -> None:
    args = parse_args()

    # Resolve all paths relative to the repository root (one level above this script).
    root = Path(__file__).resolve().parents[1]
    adapter_dir = (root / args.adapter_dir).resolve()
    output_dir = (root / args.output_dir).resolve()
    report_path = output_dir / "merge_report.json"

    if not adapter_dir.exists():
        raise SystemExit(f"adapter_dir_not_found:{adapter_dir}")

    # Prefer the explicit CLI value; otherwise read the base model name from adapter_config.json.
    base_model = args.base_model.strip() or _discover_base_model(adapter_dir)
    if not base_model:
        raise SystemExit(
            "base_model_not_found: pass --base-model or ensure adapter_config.json has base_model_name_or_path"
        )

    # Merging into 4-bit quantized weights loses precision, so require an explicit override.
    if args.load_in_4bit and not args.allow_unsafe_merge:
        raise SystemExit(
            "unsafe_merge_blocked: refusing naive 4bit merge. Re-run without --load-in-4bit "
            "or pass --allow-unsafe-merge if you accept degraded fidelity risk."
        )

    # Heavy imports are deferred until the cheap validation above has passed.
    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_kwargs: dict[str, Any] = {
        "device_map": args.device_map,
        "low_cpu_mem_usage": True,
        "torch_dtype": _resolve_dtype(args.merge_dtype),
    }
    if args.load_in_4bit:
        from transformers import BitsAndBytesConfig

        # Quantized loading controls the dtype itself, so drop the explicit torch_dtype.
        model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)
        model_kwargs.pop("torch_dtype", None)

    base_model_obj = AutoModelForCausalLM.from_pretrained(base_model, **model_kwargs)
    tokenizer = AutoTokenizer.from_pretrained(base_model)

    # Attach the LoRA adapter, fold its weights into the base model, and drop the PEFT wrappers.
    peft_model = PeftModel.from_pretrained(base_model_obj, str(adapter_dir))
    merged = peft_model.merge_and_unload(progressbar=False)

    output_dir.mkdir(parents=True, exist_ok=True)
    merged.save_pretrained(str(output_dir), safe_serialization=True)
    tokenizer.save_pretrained(str(output_dir))

    # Record what was merged, with what settings, alongside the artifact.
    param_count = sum(param.numel() for param in merged.parameters())
    payload = {
        "status": "ok",
        "adapter_dir": str(adapter_dir),
        "output_dir": str(output_dir),
        "base_model": base_model,
        "merge_dtype": args.merge_dtype,
        "load_in_4bit": bool(args.load_in_4bit),
        "unsafe_override": bool(args.allow_unsafe_merge),
        "parameters": int(param_count),
        "precision_warning": (
            "4bit merge override enabled; validate numerics before deployment."
            if args.load_in_4bit
            else "none"
        ),
    }
    _write_report(report_path, payload)
    print("merge_done")


if __name__ == "__main__":
    main()
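# Post-merge sanity check (a sketch, run separately; the path below assumes the
# default --output-dir relative to the repo root):
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#   tok = AutoTokenizer.from_pretrained("checkpoints/merged")
#   model = AutoModelForCausalLM.from_pretrained("checkpoints/merged")
#   out = model.generate(**tok("hello", return_tensors="pt"), max_new_tokens=8)
#   print(tok.decode(out[0]))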