File size: 2,512 Bytes
6c999c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# Imports
from transformers import AutoModelForCausalLM  # loads the pretrained base language model
from peft import LoraConfig, get_peft_model, PeftModel  # handles LoRA adapter weights
import argparse  # command-line argument parsing
import shutil  # file operations, e.g. copying
import os  # filesystem path operations
import torch  # torch dtypes for model loading

def main():
    """Merge LoRA adapter weights into a base model and save the result.

    Command-line driven: loads the base model on CPU in the requested
    dtype, applies the adapter via PEFT, folds it into the base weights,
    saves the merged model, then copies the base model's non-weight
    files (tokenizer, config, ...) into the output directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--base_model_path", type=str, required=True, 
                        help="Path to pretrained model or model identifier from huggingface.co/models")
    parser.add_argument("--adapter_model_path", type=str, required=True, help="Path to adapter model")
    parser.add_argument("--output_path", type=str, required=True, help="Path to save the output model")
    parser.add_argument("--save_dtype", type=str, choices=['bf16', 'fp32', 'fp16'], 
                        default='fp32', help="In which dtype to save, fp32, bf16 or fp16.")
    args = parser.parse_args()

    # Map the CLI dtype flag to the corresponding torch dtype.
    dtype_by_name = {'bf16': torch.bfloat16, 'fp32': torch.float32, 'fp16': torch.float16}

    # Load the base model entirely on CPU so merging needs no GPU.
    base = AutoModelForCausalLM.from_pretrained(
        args.base_model_path, device_map='cpu', 
        trust_remote_code=True, torch_dtype=dtype_by_name[args.save_dtype]
    )
    # Attach the LoRA adapter on top of the base model, then fold
    # (merge) its deltas into the base weights and drop the adapter.
    adapted = PeftModel.from_pretrained(base, args.adapter_model_path, trust_remote_code=True)
    merged = adapted.merge_and_unload()
    # Persist the merged weights (legacy .bin format, not safetensors).
    merged.save_pretrained(args.output_path, safe_serialization=False)

    # Copy tokenizer, config and other non-weight files from the base
    # model directory. Weight shards (.safetensors/.bin/.pt/.pth) and
    # their index files are skipped; existing files are never overwritten.
    weight_suffixes = ('.safetensors', '.bin', '.pt', '.pth')
    index_names = {'model.safetensors.index.json', 'pytorch_model.bin.index.json'}

    for entry in os.listdir(args.base_model_path):
        source = os.path.join(args.base_model_path, entry)
        if not os.path.isfile(source):
            continue
        if entry.endswith(weight_suffixes) or entry in index_names:
            continue
        target = os.path.join(args.output_path, entry)
        if os.path.exists(target):
            continue
        shutil.copy(source, target)
        print(f'Copied {entry}')

    print(f'Merged model weight is saved to {args.output_path}')
    
# Script entry point: run the merge only when executed directly.
if __name__ == "__main__":
    main()