#!/usr/bin/env python3
"""Merge a PEFT LoRA adapter into the base model for faster inference/deployment."""
import argparse
from pathlib import Path

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer


def main():
    p = argparse.ArgumentParser()
    p.add_argument("--base_model", required=True)
    p.add_argument("--adapter", required=True)
    p.add_argument("--output_dir", required=True)
    p.add_argument("--push_to_hub", action="store_true")
    p.add_argument("--hub_model_id", default=None)
    args = p.parse_args()

    # Load the base model in bfloat16 and let Accelerate place layers across
    # the available GPU(s)/CPU automatically.
    model = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    # Attach the LoRA adapter, then fold its low-rank updates into the base
    # weights so the merged model no longer needs the peft runtime.
    model = PeftModel.from_pretrained(model, args.adapter)
    model = model.merge_and_unload()
    # The adapter directory is expected to contain the tokenizer saved during
    # fine-tuning (e.g. with any added tokens or chat template).
    tokenizer = AutoTokenizer.from_pretrained(args.adapter, trust_remote_code=True)
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    # Write the merged weights as safetensors alongside the tokenizer files.
    model.save_pretrained(args.output_dir, safe_serialization=True)
    tokenizer.save_pretrained(args.output_dir)
    if args.push_to_hub:
        repo = args.hub_model_id or args.output_dir
        model.push_to_hub(repo)
        tokenizer.push_to_hub(repo)
        print(f"Pushed merged model to https://huggingface.co/{repo}")


if __name__ == "__main__":
    main()
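
# After merging, the output directory loads like any standalone checkpoint.
# A sketch, assuming "merged_model" matches the --output_dir used above:
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#   model = AutoModelForCausalLM.from_pretrained(
#       "merged_model", dtype=torch.bfloat16, device_map="auto"
#   )
#   tokenizer = AutoTokenizer.from_pretrained("merged_model")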