| |
| """Merge a PEFT LoRA adapter into the base model for faster inference/deployment.""" |
| import argparse |
| from pathlib import Path |
|
|
| import torch |
| from peft import PeftModel |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
|
|
def main():
    """CLI entry point: merge a LoRA adapter into its base model and save the result.

    Loads the base model in bfloat16, applies the PEFT adapter, folds the
    adapter weights into the base weights (``merge_and_unload``), and writes
    the merged model plus tokenizer to ``--output_dir`` as safetensors.
    Optionally pushes the merged artifacts to the Hugging Face Hub.
    """
    p = argparse.ArgumentParser(description=__doc__)
    p.add_argument("--base_model", required=True, help="Base model name or local path")
    p.add_argument("--adapter", required=True, help="LoRA adapter repo id or local path")
    p.add_argument("--output_dir", required=True, help="Directory for the merged model")
    p.add_argument("--push_to_hub", action="store_true",
                   help="Also upload the merged model to the Hugging Face Hub")
    p.add_argument("--hub_model_id", default=None,
                   help="Hub repo id to push to (defaults to --output_dir)")
    args = p.parse_args()

    model = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        # NOTE(review): the `dtype` kwarg requires a recent transformers
        # release; older versions spell this `torch_dtype` — confirm the
        # pinned transformers version supports it.
        dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    model = PeftModel.from_pretrained(model, args.adapter)
    # Fold the LoRA deltas into the base weights so inference needs no PEFT runtime.
    model = model.merge_and_unload()

    # Prefer the tokenizer shipped with the adapter (it may carry added special
    # tokens or a chat template from fine-tuning); many LoRA adapters ship no
    # tokenizer files at all, so fall back to the base model's tokenizer
    # instead of crashing.
    try:
        tokenizer = AutoTokenizer.from_pretrained(args.adapter, trust_remote_code=True)
    except (OSError, ValueError):
        tokenizer = AutoTokenizer.from_pretrained(args.base_model, trust_remote_code=True)

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    model.save_pretrained(args.output_dir, safe_serialization=True)
    tokenizer.save_pretrained(args.output_dir)
    print(f"Merged model saved to {args.output_dir}")

    if args.push_to_hub:
        # NOTE(review): a nested --output_dir path (e.g. "runs/merged") is not
        # a valid Hub repo id; pass --hub_model_id explicitly in that case.
        repo = args.hub_model_id or args.output_dir
        model.push_to_hub(repo)
        tokenizer.push_to_hub(repo)
        print(f"Pushed merged model to https://huggingface.co/{repo}")
|
|
|
|
# Script entry point: run the merge CLI only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
|
|