File size: 4,144 Bytes
877add7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
#!/usr/bin/env python3
"""Safely merge LoRA adapters into a base model artifact."""

from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Any


def parse_args() -> argparse.Namespace:
    """Build and evaluate the command-line interface for the merge script."""
    cli = argparse.ArgumentParser(description="Safely merge PEFT adapter into base model.")
    cli.add_argument("--adapter-dir", default="checkpoints/sft_adapter")
    cli.add_argument("--base-model", default="")
    cli.add_argument("--output-dir", default="checkpoints/merged")
    cli.add_argument(
        "--merge-dtype",
        choices=["float16", "bfloat16", "float32"],
        default="float16",
    )
    cli.add_argument("--device-map", default="auto")
    cli.add_argument("--load-in-4bit", action="store_true")
    cli.add_argument("--allow-unsafe-merge", action="store_true")
    return cli.parse_args()


def _resolve_dtype(name: str):
    import torch

    return {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }[name]


def _discover_base_model(adapter_dir: Path) -> str:
    cfg_path = adapter_dir / "adapter_config.json"
    if not cfg_path.exists():
        return ""
    try:
        payload = json.loads(cfg_path.read_text(encoding="utf-8"))
    except json.JSONDecodeError:
        return ""
    base = payload.get("base_model_name_or_path")
    return str(base) if isinstance(base, str) else ""


def _write_report(path: Path, payload: dict[str, Any]) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(payload, ensure_ascii=True, indent=2), encoding="utf-8")


def main() -> None:
    """Merge a PEFT/LoRA adapter into its base model and save the result.

    Resolves adapter and output paths relative to the repository root (the
    parent of this script's directory), refuses a naive 4-bit merge unless
    explicitly overridden, merges the adapter weights into the base model,
    and writes the merged model, tokenizer, and a JSON merge report into
    the output directory.

    Raises:
        SystemExit: when the adapter directory is missing, no base model
            can be determined, or a 4-bit merge is requested without the
            --allow-unsafe-merge override.
    """
    args = parse_args()
    # Repo root: one level above the directory containing this script.
    root = Path(__file__).resolve().parents[1]
    adapter_dir = (root / args.adapter_dir).resolve()
    output_dir = (root / args.output_dir).resolve()
    report_path = output_dir / "merge_report.json"

    if not adapter_dir.exists():
        raise SystemExit(f"adapter_dir_not_found:{adapter_dir}")

    # Explicit --base-model wins; otherwise fall back to the id recorded
    # in the adapter's own adapter_config.json.
    base_model = args.base_model.strip() or _discover_base_model(adapter_dir)
    if not base_model:
        raise SystemExit(
            "base_model_not_found: pass --base-model or ensure adapter_config.json has base_model_name_or_path"
        )

    # Merging adapter deltas into 4-bit-quantized weights risks degraded
    # fidelity, so it requires an explicit opt-in.
    if args.load_in_4bit and not args.allow_unsafe_merge:
        raise SystemExit(
            "unsafe_merge_blocked: refusing naive 4bit merge. Re-run without --load-in-4bit "
            "or pass --allow-unsafe-merge if you accept degraded fidelity risk."
        )

    # Heavy imports are deferred so argument/path validation fails fast.
    # (No direct torch import needed here; _resolve_dtype imports it itself.)
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_kwargs: dict[str, Any] = {
        "device_map": args.device_map,
        "low_cpu_mem_usage": True,
        "torch_dtype": _resolve_dtype(args.merge_dtype),
    }

    if args.load_in_4bit:
        from transformers import BitsAndBytesConfig

        # The quantization config supersedes an explicit torch_dtype.
        model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)
        model_kwargs.pop("torch_dtype", None)

    base_model_obj = AutoModelForCausalLM.from_pretrained(base_model, **model_kwargs)
    tokenizer = AutoTokenizer.from_pretrained(base_model)

    peft_model = PeftModel.from_pretrained(base_model_obj, str(adapter_dir))
    merged = peft_model.merge_and_unload(progressbar=False)

    output_dir.mkdir(parents=True, exist_ok=True)
    # safetensors serialization avoids pickle-based .bin checkpoints.
    merged.save_pretrained(str(output_dir), safe_serialization=True)
    tokenizer.save_pretrained(str(output_dir))

    param_count = sum(param.numel() for param in merged.parameters())
    payload = {
        "status": "ok",
        "adapter_dir": str(adapter_dir),
        "output_dir": str(output_dir),
        "base_model": base_model,
        "merge_dtype": args.merge_dtype,
        "load_in_4bit": bool(args.load_in_4bit),
        "unsafe_override": bool(args.allow_unsafe_merge),
        "parameters": int(param_count),
        "precision_warning": (
            "4bit merge override enabled; validate numerics before deployment."
            if args.load_in_4bit
            else "none"
        ),
    }
    _write_report(report_path, payload)
    print("merge_done")


# Script entry point: run the merge only when executed directly, not on import.
if __name__ == "__main__":
    main()