"""Convert a PEFT LoRA adapter into the adapter format expected by mlx-lm."""

from __future__ import annotations

import argparse
import json
import re
import shutil
from pathlib import Path

import mlx.core as mx
from safetensors import safe_open


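# PEFT checkpoint keys look like
#   base_model.model.model.<prefix>.layers.<N>.<module>.lora_{A,B}.weight
# (e.g. <module> = self_attn.q_proj). Capture the layer index, the module path,
# and whether the tensor is the LoRA A or B matrix.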
KEY_RE = re.compile(
    r"^base_model\.model\.model\.(?P<prefix>.+?)\.layers\."
    r"(?P<layer>\d+)\.(?P<module>.+?)\.lora_(?P<ab>[AB])\.weight$"
)


def main() -> None:
    parser = argparse.ArgumentParser(description="Convert PEFT LoRA adapter to MLX LoRA adapter.")
    parser.add_argument("peft_adapter", help="Path containing PEFT adapter_model.safetensors")
    parser.add_argument("mlx_adapter", help="Output path for MLX adapters.safetensors")
    args = parser.parse_args()

    peft_path = Path(args.peft_adapter)
    mlx_path = Path(args.mlx_adapter)
    mlx_path.mkdir(parents=True, exist_ok=True)

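    # Pull the LoRA hyperparameters from the PEFT config; mlx-lm takes a single
    # "scale" value, which corresponds to PEFT's lora_alpha / r.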
    peft_config = json.loads((peft_path / "adapter_config.json").read_text())
    rank = int(peft_config["r"])
    alpha = float(peft_config["lora_alpha"])
    scale = alpha / rank
    target_modules = list(peft_config["target_modules"])

    weights = {}
    layer_ids = set()
    module_keys = set()
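    # Walk every tensor in the PEFT checkpoint, remapping matching LoRA keys
    # to the flat key layout used by mlx-lm adapters.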
    with safe_open(peft_path / "adapter_model.safetensors", framework="numpy") as handle:
        for key in handle.keys():
            match = KEY_RE.match(key)
            if not match:
                continue
            layer = int(match.group("layer"))
            module = match.group("module")
            ab = match.group("ab")
            layer_ids.add(layer)
            module_keys.add(module)
            tensor = handle.get_tensor(key)
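            # PEFT stores lora_A as (r, in_features) and lora_B as (out_features, r);
            # transpose to the (in_features, r) / (r, out_features) layout that
            # mlx-lm's LoRA layers expect.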
            mlx_key = f"language_model.model.layers.{layer}.{module}.lora_{ab.lower()}"
            weights[mlx_key] = mx.array(tensor.T)

    if not weights:
        raise SystemExit(f"No PEFT LoRA weights found in {peft_path}")

    mx.save_safetensors(str(mlx_path / "adapters.safetensors"), weights)
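    # Emit an mlx-lm style adapter_config.json; num_layers is inferred from the
    # highest layer index seen, and "keys" lists the adapted module paths.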
    config = {
        "fine_tune_type": "lora",
        "num_layers": max(layer_ids) + 1,
        "lora_parameters": {
            "rank": rank,
            "scale": scale,
            "dropout": float(peft_config.get("lora_dropout", 0.0)),
            "keys": sorted(module_keys),
        },
    }
    (mlx_path / "adapter_config.json").write_text(json.dumps(config, indent=2) + "\n")

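    # Copy tokenizer/processor files next to the adapter when present so the
    # output directory can be used on its own.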
    for filename in [
        "tokenizer.json",
        "tokenizer_config.json",
        "chat_template.jinja",
        "processor_config.json",
        "README.md",
    ]:
        source = peft_path / filename
        if source.exists():
            shutil.copy2(source, mlx_path / filename)

    print(
        json.dumps(
            {
                "output": str(mlx_path),
                "weights": len(weights),
                "num_layers": config["num_layers"],
                "rank": rank,
                "scale": scale,
                "keys": sorted(module_keys),
                "target_modules": target_modules,
            },
            indent=2,
        )
    )


if __name__ == "__main__":
    main()