"""
Codette LoRA Adapter Merger
===========================

Merge one or more LoRA adapters into the base model to produce
a standalone fine-tuned model. Adapters are applied and merged
sequentially in the order specified.

Usage:
    python -m training.merge_adapters \
        --base-model meta-llama/Llama-3.1-8B-Instruct \
        --adapters adapters/newton/final adapters/davinci/final \
        --output merged_model

    python -m training.merge_adapters \
        --base-model meta-llama/Llama-3.1-8B-Instruct \
        --adapters adapters/rcxi/final \
        --output merged_model \
        --dtype bfloat16
"""
|
|
| import argparse |
| import json |
| import logging |
| import os |
| import sys |
| import time |
| from datetime import datetime |
| from pathlib import Path |
|
|
| import torch |
|
|
|
|
def setup_logging(output_dir: str) -> logging.Logger:
    """Set up dual-destination logging for a merge run.

    A timestamped DEBUG-level log file is written under ``output_dir``
    (created if missing), while an INFO-level stream handler mirrors
    progress to stdout.

    Args:
        output_dir: Directory for log output.

    Returns:
        Configured logger instance.
    """
    out_path = Path(output_dir)
    out_path.mkdir(parents=True, exist_ok=True)

    logger = logging.getLogger("codette.merge")
    logger.setLevel(logging.DEBUG)
    # Drop any handlers from a previous invocation so lines aren't duplicated.
    logger.handlers.clear()

    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    file_handler = logging.FileHandler(
        str(out_path / f"merge_{stamp}.log"), encoding="utf-8"
    )
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(logging.Formatter(
        "%(asctime)s | %(levelname)-8s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    ))

    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(logging.Formatter(
        "%(asctime)s | %(levelname)-8s | %(message)s",
        datefmt="%H:%M:%S",
    ))

    for handler in (file_handler, console_handler):
        logger.addHandler(handler)

    return logger
|
|
|
|
def resolve_dtype(dtype_str: str) -> torch.dtype:
    """Map a dtype name (long or short form) to the torch dtype object.

    Args:
        dtype_str: One of 'float32', 'float16', 'bfloat16'.

    Returns:
        Corresponding torch.dtype.

    Raises:
        ValueError: If the string is not a recognized dtype.
    """
    # Both the canonical names and the common short aliases are accepted.
    aliases = {
        "float32": torch.float32, "fp32": torch.float32,
        "float16": torch.float16, "fp16": torch.float16,
        "bfloat16": torch.bfloat16, "bf16": torch.bfloat16,
    }
    if dtype_str in aliases:
        return aliases[dtype_str]
    raise ValueError(
        f"Unknown dtype: {dtype_str}. "
        f"Choose from: {list(aliases.keys())}"
    )
|
|
|
|
def validate_adapter_paths(adapter_paths: list[str], logger: logging.Logger) -> None:
    """Check every adapter directory exists and looks like a PEFT adapter.

    Args:
        adapter_paths: List of adapter directory paths.
        logger: Logger instance.

    Raises:
        FileNotFoundError: If any adapter path is invalid.
    """
    for raw_path in adapter_paths:
        adapter_dir = Path(raw_path)
        if not adapter_dir.exists():
            raise FileNotFoundError(f"Adapter directory not found: {raw_path}")

        # adapter_config.json is the marker file every PEFT adapter carries.
        if not (adapter_dir / "adapter_config.json").exists():
            raise FileNotFoundError(
                f"No adapter_config.json found in {raw_path}. "
                f"Is this a valid PEFT adapter directory?"
            )

        logger.info(f"Validated adapter: {raw_path}")
|
|
|
|
def load_base_model(
    model_name: str,
    dtype: torch.dtype,
    device_map: str,
    logger: logging.Logger,
):
    """Load the base model and tokenizer that adapters will merge into.

    Args:
        model_name: HuggingFace model identifier.
        dtype: Torch dtype for model weights.
        device_map: Device map strategy.
        logger: Logger instance.

    Returns:
        Tuple of (model, tokenizer).
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer

    logger.info(f"Loading base model: {model_name}")
    logger.info(f"  dtype: {dtype}, device_map: {device_map}")

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    # Some base models ship without a pad token; reuse EOS so padding works.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=dtype,
        device_map=device_map,
        trust_remote_code=True,
    )

    total_params = sum(p.numel() for p in model.parameters())
    logger.info(f"Base model loaded: {total_params:,} parameters")

    return model, tokenizer
|
|
|
|
def apply_and_merge_adapter(
    model,
    adapter_path: str,
    adapter_index: int,
    total_adapters: int,
    logger: logging.Logger,
):
    """Apply a single LoRA adapter and merge it into the base weights.

    Wraps the model in a PeftModel and calls merge_and_unload, which
    folds the LoRA deltas into the base weights and returns a plain
    (non-PEFT) model.

    Note: because merge_and_unload always returns a plain model, EVERY
    adapter — not only the first — must be attached via
    PeftModel.from_pretrained. The previous version branched on
    adapter_index and used model.load_adapter()/set_adapter() for
    later adapters, but at that point `model` is the plain model from
    the previous merge, and the subsequent merge_and_unload() call
    would fail (that method only exists on PeftModel).

    Args:
        model: The current model (base or previously merged).
        adapter_path: Path to the PEFT adapter directory.
        adapter_index: Index of this adapter (for logging).
        total_adapters: Total number of adapters to merge.
        logger: Logger instance.

    Returns:
        Model with the adapter merged in.
    """
    from peft import PeftModel

    # e.g. "adapters/newton/final" -> "newton" (parent dir names the persona).
    adapter_name = Path(adapter_path).parent.name
    logger.info(
        f"[{adapter_index}/{total_adapters}] "
        f"Applying adapter: {adapter_name} ({adapter_path})"
    )

    # Surface the LoRA hyperparameters for traceability in the logs.
    config_path = Path(adapter_path) / "adapter_config.json"
    with open(config_path, "r", encoding="utf-8") as f:
        adapter_config = json.load(f)

    lora_rank = adapter_config.get("r", "unknown")
    lora_alpha = adapter_config.get("lora_alpha", "unknown")
    target_modules = adapter_config.get("target_modules", [])

    logger.info(
        f"  LoRA config: rank={lora_rank}, alpha={lora_alpha}, "
        f"modules={target_modules}"
    )

    # Each call receives a plain model (the base model, or the plain model
    # returned by the previous merge), so the adapter is always attached
    # the same way regardless of adapter_index.
    model = PeftModel.from_pretrained(
        model,
        adapter_path,
        is_trainable=False,
    )

    logger.info("  Merging adapter weights into base model...")
    model = model.merge_and_unload()

    param_count = sum(p.numel() for p in model.parameters())
    logger.info(f"  Merged successfully. Model params: {param_count:,}")

    return model
|
|
|
|
def save_merged_model(
    model,
    tokenizer,
    output_dir: str,
    logger: logging.Logger,
) -> None:
    """Persist the fully merged model and tokenizer to disk.

    Args:
        model: The merged model.
        tokenizer: The tokenizer.
        output_dir: Directory to save the model.
        logger: Logger instance.
    """
    dest = Path(output_dir)
    dest.mkdir(parents=True, exist_ok=True)

    logger.info(f"Saving merged model to: {output_dir}")

    model.save_pretrained(output_dir, safe_serialization=True)
    tokenizer.save_pretrained(output_dir)

    # Report the on-disk footprint of the weight shards (both formats).
    weight_files = list(dest.glob("*.safetensors")) + list(dest.glob("*.bin"))
    total_size = sum(f.stat().st_size for f in weight_files)

    size_gb = total_size / (1024 ** 3)
    logger.info(f"Model saved: {size_gb:.2f} GB")
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse command-line arguments for the merge script."""
    parser = argparse.ArgumentParser(
        description="Merge LoRA adapters into the base model",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--base-model", type=str,
        default="meta-llama/Llama-3.1-8B-Instruct",
        help="Base model to merge adapters into",
    )
    parser.add_argument(
        "--adapters", nargs="+", required=True,
        help="Paths to PEFT adapter directories (applied in order)",
    )
    parser.add_argument(
        "--output", type=str, required=True,
        help="Output directory for merged model",
    )
    parser.add_argument(
        "--dtype", type=str, default="bfloat16",
        choices=["float32", "fp32", "float16", "fp16", "bfloat16", "bf16"],
        help="Model dtype for merging",
    )
    parser.add_argument(
        "--device-map", type=str, default="auto",
        help="Device map strategy (auto, cpu, cuda:0, etc.)",
    )
    return parser.parse_args()
|
|
|
|
def main():
    """CLI entry point: validate, load, merge sequentially, save, record metadata."""
    args = parse_args()

    logger = setup_logging(args.output)
    logger.info("=== Codette LoRA Adapter Merger ===")
    logger.info(f"Base model: {args.base_model}")
    logger.info(f"Adapters to merge ({len(args.adapters)}): {args.adapters}")
    logger.info(f"Output: {args.output}")
    logger.info(f"dtype: {args.dtype}")

    dtype = resolve_dtype(args.dtype)

    # Fail fast on bad adapter paths before paying for a large model load.
    try:
        validate_adapter_paths(args.adapters, logger)
    except FileNotFoundError as err:
        logger.error(str(err))
        sys.exit(1)

    start_time = time.time()

    try:
        model, tokenizer = load_base_model(
            args.base_model, dtype, args.device_map, logger
        )

        # Merge adapters one at a time, in the order given on the CLI.
        for index, adapter_path in enumerate(args.adapters, 1):
            model = apply_and_merge_adapter(
                model=model,
                adapter_path=adapter_path,
                adapter_index=index,
                total_adapters=len(args.adapters),
                logger=logger,
            )

        save_merged_model(model, tokenizer, args.output, logger)

        elapsed = time.time() - start_time

        # Drop a provenance record next to the weights for reproducibility.
        metadata = {
            "base_model": args.base_model,
            "adapters_merged": args.adapters,
            "adapter_count": len(args.adapters),
            "dtype": args.dtype,
            "merge_time_seconds": elapsed,
            "timestamp": datetime.now().isoformat(),
        }
        metadata_path = Path(args.output) / "merge_metadata.json"
        with open(metadata_path, "w", encoding="utf-8") as f:
            json.dump(metadata, f, indent=2)

        logger.info(f"=== Merge complete in {elapsed:.1f}s ===")
        logger.info(f"Merged model saved to: {args.output}")

    except Exception as err:
        logger.error(f"Merge failed: {err}", exc_info=True)
        sys.exit(1)


if __name__ == "__main__":
    main()
|
|