PEFT
qlora
sft
trl
qwen3
tmf921
intent-based-networking
network-slicing
rtx-6000-ada
ml-intern
tmf921-intent-training/scripts/merge_adapter.py
nraptisss's picture
Add RTX 6000 Ada QLoRA training and evaluation repo
d9ba941 verified
#!/usr/bin/env python3
"""Merge a PEFT LoRA adapter into the base model for faster inference/deployment."""
import argparse
from pathlib import Path
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
def main():
    """Merge a LoRA adapter into its base model, save the result, and optionally push it to the Hub."""
    parser = argparse.ArgumentParser()
    # Required positional-style flags: where the base weights, adapter, and output live.
    for flag in ("--base_model", "--adapter", "--output_dir"):
        parser.add_argument(flag, required=True)
    parser.add_argument("--push_to_hub", action="store_true")
    parser.add_argument("--hub_model_id", default=None)
    opts = parser.parse_args()

    # Load base weights in bf16, spread across available devices; trust_remote_code
    # is needed for models that ship custom modeling files (e.g. Qwen variants).
    base = AutoModelForCausalLM.from_pretrained(
        opts.base_model,
        dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    # Attach the LoRA adapter and fold its deltas into the base weights so the
    # merged model needs no PEFT wrapper at inference time.
    merged = PeftModel.from_pretrained(base, opts.adapter).merge_and_unload()
    # Tokenizer is loaded from the adapter dir, which carries any training-time changes.
    tok = AutoTokenizer.from_pretrained(opts.adapter, trust_remote_code=True)

    Path(opts.output_dir).mkdir(parents=True, exist_ok=True)
    merged.save_pretrained(opts.output_dir, safe_serialization=True)
    tok.save_pretrained(opts.output_dir)

    if opts.push_to_hub:
        # Fall back to the output dir name if no explicit Hub repo id was given.
        target = opts.hub_model_id or opts.output_dir
        merged.push_to_hub(target)
        tok.push_to_hub(target)
        print(f"Pushed merged model to https://huggingface.co/{target}")
# Standard script entry guard: run only when executed directly, not on import.
if __name__ == "__main__":
    main()