#!/usr/bin/env python3
"""Merge a PEFT LoRA adapter into the base model for faster inference/deployment."""
import argparse
from pathlib import Path

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer


def main():
    """Merge a LoRA adapter into its base model and save (optionally push) the result.

    CLI arguments:
        --base_model   Hub id or local path of the base causal-LM (required).
        --adapter      Hub id or local path of the PEFT adapter (required).
        --output_dir   Directory the merged model + tokenizer are written to (required).
        --push_to_hub  If set, also push the merged artifacts to the Hub.
        --hub_model_id Repo id to push to; defaults to the basename of --output_dir.
    """
    p = argparse.ArgumentParser(description=__doc__)
    p.add_argument("--base_model", required=True)
    p.add_argument("--adapter", required=True)
    p.add_argument("--output_dir", required=True)
    p.add_argument("--push_to_hub", action="store_true")
    p.add_argument("--hub_model_id", default=None)
    args = p.parse_args()

    model = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    # Attach the adapter, then bake its weights into the base model so the
    # result is a plain (non-PEFT) checkpoint with no inference-time overhead.
    model = PeftModel.from_pretrained(model, args.adapter)
    model = model.merge_and_unload()

    # Prefer the adapter's tokenizer (it may carry added special tokens from
    # fine-tuning), but many adapters ship no tokenizer files at all — fall
    # back to the base model's tokenizer instead of crashing.
    try:
        tokenizer = AutoTokenizer.from_pretrained(args.adapter, trust_remote_code=True)
    except OSError:
        tokenizer = AutoTokenizer.from_pretrained(args.base_model, trust_remote_code=True)

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    model.save_pretrained(args.output_dir, safe_serialization=True)
    tokenizer.save_pretrained(args.output_dir)

    if args.push_to_hub:
        # A raw filesystem path (e.g. "out/merged") is not a valid Hub repo id,
        # so fall back to the output directory's basename, not the full path.
        repo = args.hub_model_id or Path(args.output_dir).name
        model.push_to_hub(repo)
        tokenizer.push_to_hub(repo)
        print(f"Pushed merged model to https://huggingface.co/{repo}")


if __name__ == "__main__":
    main()