SpringWang08 committed on
Commit 39688a1 · verified · 1 Parent(s): bd9252d

Add LoRA adapter merge utility

Files changed (1)
  1. scripts/merge_lora_adapters.py +215 -0
scripts/merge_lora_adapters.py ADDED
@@ -0,0 +1,215 @@
+ #!/usr/bin/env python3
+ """Merge PEFT/LoRA adapters with linear model soup or SLERP.
+
+ This is intended for adapters that share the same base model and LoRA config,
+ for example B2, DPO, and PPO adapters trained from the same LLaVA-Med base.
+ """
+
+ from __future__ import annotations
+
+ import argparse
+ import json
+ import math
+ import shutil
+ from pathlib import Path
+
+ import torch
+ from safetensors.torch import load_file, save_file
+
+
+ COPY_FILES = [
+     "adapter_config.json",
+     "tokenizer_config.json",
+     "tokenizer.json",
+     "processor_config.json",
+     "chat_template.jinja",
+ ]
+
+
+ def parse_weighted_adapter(raw: str) -> tuple[Path, float]:
+     if "=" not in raw:
+         return Path(raw), 1.0
+     path, weight = raw.rsplit("=", 1)
+     return Path(path), float(weight)
+
+
+ def normalize_weights(weights: list[float]) -> list[float]:
+     total = sum(weights)
+     if not math.isfinite(total) or total <= 0:
+         raise ValueError("Adapter weights must sum to a positive finite value.")
+     return [weight / total for weight in weights]
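+ # For example, "--adapter a=0.4 --adapter b=0.4" normalizes to [0.5, 0.5];
+ # adapters given without "=weight" default to 1.0 each, i.e. a uniform soup.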
+
+
+ def adapter_model_path(adapter_dir: Path) -> Path:
+     path = adapter_dir / "adapter_model.safetensors"
+     if not path.exists():
+         raise FileNotFoundError(f"Missing adapter weights: {path}")
+     return path
+
+
+ def load_adapter_state(adapter_dir: Path) -> dict[str, torch.Tensor]:
+     return load_file(str(adapter_model_path(adapter_dir)), device="cpu")
+
+
+ def validate_states(states: list[dict[str, torch.Tensor]]) -> list[str]:
+     keys = sorted(states[0].keys())
+     expected = set(keys)
+     for idx, state in enumerate(states[1:], start=2):
+         if set(state.keys()) != expected:
+             missing = sorted(expected - set(state.keys()))[:5]
+             extra = sorted(set(state.keys()) - expected)[:5]
+             raise ValueError(f"Adapter {idx} has different keys. Missing={missing}, extra={extra}")
+         for key in keys:
+             if state[key].shape != states[0][key].shape:
+                 raise ValueError(f"Shape mismatch for {key}: {state[key].shape} != {states[0][key].shape}")
+     return keys
+
+
+ def linear_soup(states: list[dict[str, torch.Tensor]], weights: list[float], keys: list[str]) -> dict[str, torch.Tensor]:
+     merged: dict[str, torch.Tensor] = {}
+     for key in keys:
+         ref = states[0][key]
+         tensor = torch.zeros_like(ref, dtype=torch.float32)
+         for state, weight in zip(states, weights):
+             tensor += state[key].float() * weight
+         merged[key] = tensor.to(dtype=ref.dtype)
+     return merged
+
+
+ def flatten_state(state: dict[str, torch.Tensor], keys: list[str]) -> torch.Tensor:
+     return torch.cat([state[key].float().reshape(-1) for key in keys])
+
+
+ def unflatten_state(vector: torch.Tensor, reference: dict[str, torch.Tensor], keys: list[str]) -> dict[str, torch.Tensor]:
+     merged: dict[str, torch.Tensor] = {}
+     offset = 0
+     for key in keys:
+         ref = reference[key]
+         size = ref.numel()
+         merged[key] = vector[offset : offset + size].reshape(ref.shape).to(dtype=ref.dtype)
+         offset += size
+     return merged
+
+
+ def slerp_vectors(a: torch.Tensor, b: torch.Tensor, t: float, eps: float = 1e-8) -> torch.Tensor:
+     a_norm = torch.linalg.vector_norm(a)
+     b_norm = torch.linalg.vector_norm(b)
+     if a_norm < eps or b_norm < eps:
+         return (1.0 - t) * a + t * b
+
+     a_unit = a / a_norm
+     b_unit = b / b_norm
+     dot = torch.clamp(torch.dot(a_unit, b_unit), -1.0, 1.0)
+     omega = torch.acos(dot)
+     sin_omega = torch.sin(omega)
+     if torch.abs(sin_omega) < eps:
+         return (1.0 - t) * a + t * b
+
+     direction = (torch.sin((1.0 - t) * omega) / sin_omega) * a + (torch.sin(t * omega) / sin_omega) * b
+     target_norm = (1.0 - t) * a_norm + t * b_norm
+     direction_norm = torch.linalg.vector_norm(direction)
+     if direction_norm < eps:
+         return direction
+     return direction / direction_norm * target_norm
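+ # slerp_vectors blends a and b with the standard SLERP coefficients
+ # sin((1-t)*omega)/sin(omega) and sin(t*omega)/sin(omega), where omega is the
+ # angle between their directions, then rescales the result so its norm
+ # interpolates linearly between |a| and |b|; near-zero norms and nearly
+ # parallel vectors fall back to plain linear interpolation.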
+
+
+ def iterative_slerp(states: list[dict[str, torch.Tensor]], weights: list[float], keys: list[str]) -> dict[str, torch.Tensor]:
+     vectors = [flatten_state(state, keys) for state in states]
+     merged = vectors[0]
+     accumulated = weights[0]
+     for vector, weight in zip(vectors[1:], weights[1:]):
+         t = weight / (accumulated + weight)
+         merged = slerp_vectors(merged, vector, t)
+         accumulated += weight
+     return unflatten_state(merged, states[0], keys)
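+ # With more than two adapters, SLERP is applied pairwise: each new adapter is
+ # folded into the running merge with t = weight / (accumulated + weight), so
+ # the result respects the normalized weights but is mildly order-dependent.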
+
+
+ def copy_adapter_files(reference_dir: Path, output_dir: Path) -> None:
+     for filename in COPY_FILES:
+         src = reference_dir / filename
+         if src.exists():
+             shutil.copy2(src, output_dir / filename)
+
+
+ def write_metadata(
+     output_dir: Path,
+     method: str,
+     adapters: list[Path],
+     weights: list[float],
+     base_model: str,
+ ) -> None:
+     metadata = {
+         "method": method,
+         "base_model": base_model,
+         "adapters": [str(path) for path in adapters],
+         "weights": weights,
+     }
+     (output_dir / "merge_config.json").write_text(json.dumps(metadata, indent=2, ensure_ascii=False), encoding="utf-8")
+     readme = [
+         "---",
+         "library_name: peft",
+         "tags:",
+         "- peft",
+         "- lora",
+         "- visual-question-answering",
+         f"base_model: {base_model}",
+         "---",
+         "",
+         "# Medical VQA Merged Adapter",
+         "",
+         f"Merge method: `{method}`",
+         "",
+         "| Adapter | Weight |",
+         "|---|---:|",
+     ]
+     for path, weight in zip(adapters, weights):
+         readme.append(f"| `{path}` | {weight:.4f} |")
+     readme.extend(
+         [
+             "",
+             "This adapter was created by merging fine-tuned LoRA adapters without additional training.",
+         ]
+     )
+     (output_dir / "README.md").write_text("\n".join(readme) + "\n", encoding="utf-8")
+
+
+ def main() -> None:
+     parser = argparse.ArgumentParser(description="Merge PEFT LoRA adapters via model soup or SLERP.")
+     parser.add_argument(
+         "--adapter",
+         action="append",
+         required=True,
+         help="Adapter directory, optionally with weight: path=0.4. Repeat for each adapter.",
+     )
+     parser.add_argument("--output", required=True, help="Output adapter directory.")
+     parser.add_argument("--method", choices=["linear", "slerp"], default="linear")
+     parser.add_argument("--base-model", default="chaoyinshe/llava-med-v1.5-mistral-7b-hf")
+     parser.add_argument(
+         "--reference-index",
+         type=int,
+         default=0,
+         help="Adapter index used as source for adapter_config/tokenizer files.",
+     )
+     args = parser.parse_args()
+
+     adapters, raw_weights = zip(*(parse_weighted_adapter(raw) for raw in args.adapter))
+     adapters = [path.expanduser().resolve() for path in adapters]
+     weights = normalize_weights(list(raw_weights))
+     output_dir = Path(args.output).expanduser().resolve()
+     output_dir.mkdir(parents=True, exist_ok=True)
+
+     if not 0 <= args.reference_index < len(adapters):
+         raise ValueError("--reference-index is out of range.")
+
+     states = [load_adapter_state(path) for path in adapters]
+     keys = validate_states(states)
+     merged = linear_soup(states, weights, keys) if args.method == "linear" else iterative_slerp(states, weights, keys)
+
+     save_file(merged, str(output_dir / "adapter_model.safetensors"))
+     copy_adapter_files(adapters[args.reference_index], output_dir)
+     write_metadata(output_dir, args.method, adapters, weights, args.base_model)
+     print(f"Saved merged adapter to: {output_dir}")
+
+
+ if __name__ == "__main__":
+     main()
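
For a quick sanity check of the merge math, here is a minimal sketch, assuming the functions above are importable as merge_lora_adapters (e.g. when run from the scripts/ directory); the adapter keys and tensor values are made up for illustration:

import torch
from merge_lora_adapters import linear_soup, validate_states

# Two toy adapter states with identical keys and shapes.
state_a = {"lora_A.weight": torch.ones(2, 2), "lora_B.weight": torch.zeros(2, 2)}
state_b = {"lora_A.weight": torch.zeros(2, 2), "lora_B.weight": torch.ones(2, 2)}

keys = validate_states([state_a, state_b])
merged = linear_soup([state_a, state_b], [0.5, 0.5], keys)
print(merged["lora_A.weight"])  # every entry is 0.5, the equal-weight average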