#!/usr/bin/env python3
"""Merge PEFT/LoRA adapters with linear model soup or SLERP.

This is intended for adapters that share the same base model and LoRA config,
for example B2, DPO, and PPO adapters trained from the same LLaVA-Med base.
"""
from __future__ import annotations

import argparse
import json
import math
import shutil
from pathlib import Path

import torch
from safetensors.torch import load_file, save_file
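
# Example invocation (script and run names are illustrative; per-adapter
# weights use the `path=weight` syntax and are normalized to sum to 1):
#
#   python merge_adapters.py \
#       --adapter runs/sft=0.5 --adapter runs/dpo=0.3 --adapter runs/ppo=0.2 \
#       --output runs/merged --method slerp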

COPY_FILES = [
    "adapter_config.json",
    "tokenizer_config.json",
    "tokenizer.json",
    "processor_config.json",
    "chat_template.jinja",
]


def parse_weighted_adapter(raw: str) -> tuple[Path, float]:
    if "=" not in raw:
        return Path(raw), 1.0
    path, weight = raw.rsplit("=", 1)
    return Path(path), float(weight)
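

# Rescale raw weights so they sum to 1, e.g. [2.0, 1.0, 1.0] -> [0.5, 0.25, 0.25].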
def normalize_weights(weights: list[float]) -> list[float]:
    total = sum(weights)
    if not math.isfinite(total) or total <= 0:
        raise ValueError("Adapter weights must sum to a positive finite value.")
    return [weight / total for weight in weights]


def adapter_model_path(adapter_dir: Path) -> Path:
    path = adapter_dir / "adapter_model.safetensors"
    if not path.exists():
        raise FileNotFoundError(f"Missing adapter weights: {path}")
    return path


def load_adapter_state(adapter_dir: Path) -> dict[str, torch.Tensor]:
    return load_file(str(adapter_model_path(adapter_dir)), device="cpu")


def validate_states(states: list[dict[str, torch.Tensor]]) -> list[str]:
    keys = sorted(states[0].keys())
    expected = set(keys)
    for idx, state in enumerate(states[1:], start=2):
        if set(state.keys()) != expected:
            missing = sorted(expected - set(state.keys()))[:5]
            extra = sorted(set(state.keys()) - expected)[:5]
            raise ValueError(f"Adapter {idx} has different keys. Missing={missing}, extra={extra}")
        for key in keys:
            if state[key].shape != states[0][key].shape:
                raise ValueError(f"Shape mismatch for {key}: {state[key].shape} != {states[0][key].shape}")
    return keys
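

# Linear "model soup": merged = sum_i w_i * theta_i, elementwise per tensor.
# Accumulation happens in float32 to limit fp16/bf16 rounding, then the result
# is cast back to each tensor's original dtype.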
def linear_soup(states: list[dict[str, torch.Tensor]], weights: list[float], keys: list[str]) -> dict[str, torch.Tensor]:
    merged: dict[str, torch.Tensor] = {}
    for key in keys:
        ref = states[0][key]
        tensor = torch.zeros_like(ref, dtype=torch.float32)
        for state, weight in zip(states, weights):
            tensor += state[key].float() * weight
        merged[key] = tensor.to(dtype=ref.dtype)
    return merged


def flatten_state(state: dict[str, torch.Tensor], keys: list[str]) -> torch.Tensor:
    return torch.cat([state[key].float().reshape(-1) for key in keys])


def unflatten_state(vector: torch.Tensor, reference: dict[str, torch.Tensor], keys: list[str]) -> dict[str, torch.Tensor]:
    merged: dict[str, torch.Tensor] = {}
    offset = 0
    for key in keys:
        ref = reference[key]
        size = ref.numel()
        merged[key] = vector[offset : offset + size].reshape(ref.shape).to(dtype=ref.dtype)
        offset += size
    return merged
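

# SLERP between two parameter vectors. The angle omega is measured between the
# unit directions:
#     omega = arccos(<a/|a|, b/|b|>)
# and the sine weights sin((1-t)*omega)/sin(omega) and sin(t*omega)/sin(omega)
# are applied to the raw vectors; the result is then rescaled so its norm is
# the linear interpolation (1-t)*|a| + t*|b|. Near-zero norms or nearly
# parallel vectors fall back to plain LERP.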
def slerp_vectors(a: torch.Tensor, b: torch.Tensor, t: float, eps: float = 1e-8) -> torch.Tensor:
    a_norm = torch.linalg.vector_norm(a)
    b_norm = torch.linalg.vector_norm(b)
    if a_norm < eps or b_norm < eps:
        return (1.0 - t) * a + t * b
    a_unit = a / a_norm
    b_unit = b / b_norm
    dot = torch.clamp(torch.dot(a_unit, b_unit), -1.0, 1.0)
    omega = torch.acos(dot)
    sin_omega = torch.sin(omega)
    if torch.abs(sin_omega) < eps:
        return (1.0 - t) * a + t * b
    direction = (torch.sin((1.0 - t) * omega) / sin_omega) * a + (torch.sin(t * omega) / sin_omega) * b
    target_norm = (1.0 - t) * a_norm + t * b_norm
    direction_norm = torch.linalg.vector_norm(direction)
    if direction_norm < eps:
        return direction
    return direction / direction_norm * target_norm
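

# Multi-adapter SLERP by folding adapters in one at a time: at step i the
# interpolation factor is t = w_i / (w_1 + ... + w_i), so each adapter ends up
# contributing in proportion to its normalized weight (exact for LERP, and a
# common approximation when chaining pairwise SLERP).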
def iterative_slerp(states: list[dict[str, torch.Tensor]], weights: list[float], keys: list[str]) -> dict[str, torch.Tensor]:
    vectors = [flatten_state(state, keys) for state in states]
    merged = vectors[0]
    accumulated = weights[0]
    for vector, weight in zip(vectors[1:], weights[1:]):
        t = weight / (accumulated + weight)
        merged = slerp_vectors(merged, vector, t)
        accumulated += weight
    return unflatten_state(merged, states[0], keys)


def copy_adapter_files(reference_dir: Path, output_dir: Path) -> None:
    for filename in COPY_FILES:
        src = reference_dir / filename
        if src.exists():
            shutil.copy2(src, output_dir / filename)


def write_metadata(
    output_dir: Path,
    method: str,
    adapters: list[Path],
    weights: list[float],
    base_model: str,
) -> None:
    metadata = {
        "method": method,
        "base_model": base_model,
        "adapters": [str(path) for path in adapters],
        "weights": weights,
    }
    (output_dir / "merge_config.json").write_text(json.dumps(metadata, indent=2, ensure_ascii=False), encoding="utf-8")
    readme = [
        "---",
        "library_name: peft",
        "tags:",
        "- peft",
        "- lora",
        "- visual-question-answering",
        f"base_model: {base_model}",
        "---",
        "",
        "# Medical VQA Merged Adapter",
        "",
        f"Merge method: `{method}`",
        "",
        "| Adapter | Weight |",
        "|---|---:|",
    ]
    for path, weight in zip(adapters, weights):
        readme.append(f"| `{path}` | {weight:.4f} |")
    readme.extend(
        [
            "",
            "This adapter was created by merging fine-tuned LoRA adapters without additional training.",
        ]
    )
    (output_dir / "README.md").write_text("\n".join(readme) + "\n", encoding="utf-8")


def main() -> None:
    parser = argparse.ArgumentParser(description="Merge PEFT LoRA adapters via model soup or SLERP.")
    parser.add_argument(
        "--adapter",
        action="append",
        required=True,
        help="Adapter directory, optionally with weight: path=0.4. Repeat for each adapter.",
    )
    parser.add_argument("--output", required=True, help="Output adapter directory.")
    parser.add_argument("--method", choices=["linear", "slerp"], default="linear")
    parser.add_argument("--base-model", default="chaoyinshe/llava-med-v1.5-mistral-7b-hf")
    parser.add_argument(
        "--reference-index",
        type=int,
        default=0,
        help="Adapter index used as source for adapter_config/tokenizer files.",
    )
    args = parser.parse_args()

    adapters, raw_weights = zip(*(parse_weighted_adapter(raw) for raw in args.adapter))
    adapters = [path.expanduser().resolve() for path in adapters]
    weights = normalize_weights(list(raw_weights))

    output_dir = Path(args.output).expanduser().resolve()
    output_dir.mkdir(parents=True, exist_ok=True)

    if not 0 <= args.reference_index < len(adapters):
        raise ValueError("--reference-index is out of range.")

    states = [load_adapter_state(path) for path in adapters]
    keys = validate_states(states)
    merged = linear_soup(states, weights, keys) if args.method == "linear" else iterative_slerp(states, weights, keys)

    save_file(merged, str(output_dir / "adapter_model.safetensors"))
    copy_adapter_files(adapters[args.reference_index], output_dir)
    write_metadata(output_dir, args.method, adapters, weights, args.base_model)
    print(f"Saved merged adapter to: {output_dir}")


if __name__ == "__main__":
    main()
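
# A minimal sketch of loading the merged adapter with PEFT (class choice and
# paths are assumptions, not part of this script):
#
#   from transformers import LlavaForConditionalGeneration
#   from peft import PeftModel
#
#   base = LlavaForConditionalGeneration.from_pretrained(
#       "chaoyinshe/llava-med-v1.5-mistral-7b-hf"
#   )
#   model = PeftModel.from_pretrained(base, "runs/merged")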