Spaces:
Paused
Paused
Add LoRA adapter merge utility
Browse files- scripts/merge_lora_adapters.py +215 -0
scripts/merge_lora_adapters.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Merge PEFT/LoRA adapters with linear model soup or SLERP.
|
| 3 |
+
|
| 4 |
+
This is intended for adapters that share the same base model and LoRA config,
|
| 5 |
+
for example B2, DPO, and PPO adapters trained from the same LLaVA-Med base.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import argparse
|
| 11 |
+
import json
|
| 12 |
+
import math
|
| 13 |
+
import shutil
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
from safetensors.torch import load_file, save_file
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# Sidecar files copied verbatim from the reference adapter directory into the
# merged output when present; missing files are silently skipped.
COPY_FILES = [
    "adapter_config.json",
    "tokenizer_config.json",
    "tokenizer.json",
    "processor_config.json",
    "chat_template.jinja",
]
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def parse_weighted_adapter(raw: str) -> tuple[Path, float]:
    """Parse an ``--adapter`` argument of the form ``path`` or ``path=weight``.

    A trailing ``=<number>`` sets the adapter's merge weight; without it the
    weight defaults to 1.0.  If the text after the last ``=`` is not a valid
    float, the whole string is treated as a path — this keeps paths that
    happen to contain ``=`` working instead of crashing on ``float()``.
    """
    path, sep, weight = raw.rpartition("=")
    if sep:
        try:
            return Path(path), float(weight)
        except ValueError:
            pass  # the "=" was part of the path, not a weight separator
    return Path(raw), 1.0
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def normalize_weights(weights: list[float]) -> list[float]:
    """Scale *weights* so they sum to 1.0.

    Raises:
        ValueError: if the sum is non-positive, NaN, or infinite.
    """
    total = sum(weights)
    if math.isfinite(total) and total > 0:
        return [w / total for w in weights]
    raise ValueError("Adapter weights must sum to a positive finite value.")
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def adapter_model_path(adapter_dir: Path) -> Path:
    """Return the ``adapter_model.safetensors`` path inside *adapter_dir*.

    Raises:
        FileNotFoundError: when the weights file does not exist.
    """
    weights_file = adapter_dir / "adapter_model.safetensors"
    if weights_file.exists():
        return weights_file
    raise FileNotFoundError(f"Missing adapter weights: {weights_file}")
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def load_adapter_state(adapter_dir: Path) -> dict[str, torch.Tensor]:
    """Load an adapter's safetensors weights onto the CPU."""
    weights_file = adapter_model_path(adapter_dir)
    return load_file(str(weights_file), device="cpu")
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def validate_states(states: list[dict[str, torch.Tensor]]) -> list[str]:
    """Check that all adapter states share identical keys and tensor shapes.

    Returns the sorted key list of the first state.

    Raises:
        ValueError: on any key-set or shape mismatch against the first state.
    """
    reference = states[0]
    keys = sorted(reference)
    expected = set(keys)
    for idx, state in enumerate(states[1:], start=2):
        actual = set(state)
        if actual != expected:
            # Show at most five examples per side to keep the message readable.
            missing = sorted(expected - actual)[:5]
            extra = sorted(actual - expected)[:5]
            raise ValueError(f"Adapter {idx} has different keys. Missing={missing}, extra={extra}")
        for key in keys:
            if state[key].shape != reference[key].shape:
                raise ValueError(f"Shape mismatch for {key}: {state[key].shape} != {reference[key].shape}")
    return keys
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def linear_soup(states: list[dict[str, torch.Tensor]], weights: list[float], keys: list[str]) -> dict[str, torch.Tensor]:
    """Weighted average ("model soup") of the adapter tensors, key by key.

    Accumulation happens in float32 for precision; each merged tensor is
    cast back to the reference tensor's dtype before being returned.
    """
    merged: dict[str, torch.Tensor] = {}
    for key in keys:
        reference = states[0][key]
        accumulated = sum(
            (state[key].float() * weight for state, weight in zip(states, weights)),
            start=torch.zeros_like(reference, dtype=torch.float32),
        )
        merged[key] = accumulated.to(dtype=reference.dtype)
    return merged
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def flatten_state(state: dict[str, torch.Tensor], keys: list[str]) -> torch.Tensor:
    """Concatenate the tensors for *keys* into one flat float32 vector."""
    flat_parts = []
    for key in keys:
        flat_parts.append(state[key].float().reshape(-1))
    return torch.cat(flat_parts)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def unflatten_state(vector: torch.Tensor, reference: dict[str, torch.Tensor], keys: list[str]) -> dict[str, torch.Tensor]:
    """Inverse of :func:`flatten_state`.

    Slices *vector* back into tensors with the shapes and dtypes of the
    corresponding entries in *reference*, consumed in *keys* order.
    """
    merged: dict[str, torch.Tensor] = {}
    cursor = 0
    for key in keys:
        template = reference[key]
        count = template.numel()
        chunk = vector[cursor : cursor + count]
        merged[key] = chunk.reshape(template.shape).to(dtype=template.dtype)
        cursor += count
    return merged
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def slerp_vectors(a: torch.Tensor, b: torch.Tensor, t: float, eps: float = 1e-8) -> torch.Tensor:
    """Spherical interpolation between flat vectors *a* and *b* at fraction *t*.

    Directions are slerped on the unit sphere while the magnitude is
    linearly interpolated between the two norms.  Falls back to plain
    linear interpolation when either vector is (near) zero or the two
    directions are (anti-)parallel, where SLERP is degenerate.
    """
    norm_a = torch.linalg.vector_norm(a)
    norm_b = torch.linalg.vector_norm(b)
    if norm_a < eps or norm_b < eps:
        return (1.0 - t) * a + t * b

    unit_a = a / norm_a
    unit_b = b / norm_b
    cos_angle = torch.clamp(torch.dot(unit_a, unit_b), -1.0, 1.0)
    angle = torch.acos(cos_angle)
    sin_angle = torch.sin(angle)
    if torch.abs(sin_angle) < eps:
        return (1.0 - t) * a + t * b

    coeff_a = torch.sin((1.0 - t) * angle) / sin_angle
    coeff_b = torch.sin(t * angle) / sin_angle
    direction = coeff_a * a + coeff_b * b
    target_norm = (1.0 - t) * norm_a + t * norm_b
    direction_norm = torch.linalg.vector_norm(direction)
    if direction_norm < eps:
        return direction
    return direction / direction_norm * target_norm
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def iterative_slerp(states: list[dict[str, torch.Tensor]], weights: list[float], keys: list[str]) -> dict[str, torch.Tensor]:
    """Fold the adapters together with repeated pairwise SLERP.

    Each subsequent adapter is slerped into the running merge with
    t = weight / (accumulated_weight + weight), so later adapters enter
    in proportion to their normalized weight.
    """
    flat = [flatten_state(state, keys) for state in states]
    running = flat[0]
    total_weight = weights[0]
    for candidate, weight in zip(flat[1:], weights[1:]):
        fraction = weight / (total_weight + weight)
        running = slerp_vectors(running, candidate, fraction)
        total_weight += weight
    return unflatten_state(running, states[0], keys)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def copy_adapter_files(reference_dir: Path, output_dir: Path) -> None:
    """Copy the config/tokenizer sidecar files from *reference_dir*.

    Files listed in COPY_FILES that are absent are silently skipped.
    """
    for filename in COPY_FILES:
        candidate = reference_dir / filename
        if not candidate.exists():
            continue
        shutil.copy2(candidate, output_dir / filename)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def write_metadata(
    output_dir: Path,
    method: str,
    adapters: list[Path],
    weights: list[float],
    base_model: str,
) -> None:
    """Write ``merge_config.json`` and a model-card ``README.md``.

    The JSON file records how the merge was produced (method, base model,
    source adapters, normalized weights); the README is a Hugging Face
    model card with a weight table.
    """
    config = {
        "method": method,
        "base_model": base_model,
        "adapters": [str(adapter) for adapter in adapters],
        "weights": weights,
    }
    config_path = output_dir / "merge_config.json"
    config_path.write_text(json.dumps(config, indent=2, ensure_ascii=False), encoding="utf-8")

    # YAML front matter + card header, then one table row per adapter.
    header = [
        "---",
        "library_name: peft",
        "tags:",
        "- peft",
        "- lora",
        "- visual-question-answering",
        f"base_model: {base_model}",
        "---",
        "",
        "# Medical VQA Merged Adapter",
        "",
        f"Merge method: `{method}`",
        "",
        "| Adapter | Weight |",
        "|---|---:|",
    ]
    rows = [f"| `{adapter}` | {weight:.4f} |" for adapter, weight in zip(adapters, weights)]
    footer = [
        "",
        "This adapter was created by merging fine-tuned LoRA adapters without additional training.",
    ]
    readme_text = "\n".join(header + rows + footer) + "\n"
    (output_dir / "README.md").write_text(readme_text, encoding="utf-8")
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def main() -> None:
    """CLI entry point: parse arguments, merge the adapters, save the result.

    Loads every adapter's safetensors state, validates that they share keys
    and shapes, merges them with the chosen method, and writes the merged
    adapter plus sidecar files and metadata to the output directory.
    """
    parser = argparse.ArgumentParser(description="Merge PEFT LoRA adapters via model soup or SLERP.")
    parser.add_argument(
        "--adapter",
        action="append",
        required=True,
        help="Adapter directory, optionally with weight: path=0.4. Repeat for each adapter.",
    )
    parser.add_argument("--output", required=True, help="Output adapter directory.")
    parser.add_argument("--method", choices=["linear", "slerp"], default="linear")
    parser.add_argument("--base-model", default="chaoyinshe/llava-med-v1.5-mistral-7b-hf")
    parser.add_argument(
        "--reference-index",
        type=int,
        default=0,
        help="Adapter index used as source for adapter_config/tokenizer files.",
    )
    args = parser.parse_args()

    parsed = [parse_weighted_adapter(raw) for raw in args.adapter]
    adapters = [path.expanduser().resolve() for path, _ in parsed]
    weights = normalize_weights([weight for _, weight in parsed])
    output_dir = Path(args.output).expanduser().resolve()
    output_dir.mkdir(parents=True, exist_ok=True)

    if not 0 <= args.reference_index < len(adapters):
        raise ValueError("--reference-index is out of range.")

    states = [load_adapter_state(path) for path in adapters]
    keys = validate_states(states)
    if args.method == "linear":
        merged = linear_soup(states, weights, keys)
    else:
        merged = iterative_slerp(states, weights, keys)

    save_file(merged, str(output_dir / "adapter_model.safetensors"))
    copy_adapter_files(adapters[args.reference_index], output_dir)
    write_metadata(output_dir, args.method, adapters, weights, args.base_model)
    print(f"Saved merged adapter to: {output_dir}")
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
# Script entry point: python scripts/merge_lora_adapters.py --adapter ... --output ...
if __name__ == "__main__":
    main()
|