| """ |
| Step 1: Merge DFlash-LoRA adapter into base model. |
| Usage: |
| conda activate sglang |
| python3 merge_lora.py |
| python3 merge_lora.py --ckpt epoch_2_step_15000 # test a different checkpoint |
| """ |
| import argparse |
| import os |
|
|
| import torch |
| from peft import PeftModel |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
# Base model location: main() loads both the model weights and the tokenizer from here.
| BASE_MODEL = "/workspace/models/Qwen3-8B" |
# Root directory of the adapter checkpoints; main() joins it with the --ckpt folder name.
| OUTPUT_ROOT = "/workspace/hanrui/syxin_old/Specforge/outputs/qwen3-8b-sft-32gpu-v2" |
# Default value of --merged-path, i.e. where the merged model is saved unless overridden.
| MERGE_ROOT = "/workspace/hanrui/syxin_old/Specforge/outputs/qwen3-8b-sft-32gpu-v2-merged" |
|
|
def parse_args(argv=None):
    """Parse the command-line options of the merge script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, so calling
            ``parse_args()`` with no arguments behaves exactly as before.

    Returns:
        argparse.Namespace with two attributes:
        ``ckpt`` (checkpoint folder name under OUTPUT_ROOT) and
        ``merged_path`` (directory where the merged model is saved).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--ckpt",
        default="epoch_0_step_3000",
        help="Checkpoint folder name under OUTPUT_ROOT",
    )
    parser.add_argument(
        "--merged-path",
        default=MERGE_ROOT,
        help="Where to save the merged model",
    )
    return parser.parse_args(argv)
|
|
|
|
def main():
    """Merge the LoRA adapter of the selected checkpoint into the base model.

    Steps:
      1. Load ``BASE_MODEL`` on CPU in bfloat16.
      2. Attach the adapter found at ``OUTPUT_ROOT/<--ckpt>``.
      3. Fold the adapter weights into the base weights.
      4. Save the merged model and the base model's tokenizer to
         ``--merged-path``.

    If ``--merged-path`` already exists, a notice is printed and nothing else
    is done.

    Raises:
        FileNotFoundError: if the adapter directory does not exist.
    """
    # Local import: only needed here, to clean up a stale staging directory.
    import shutil

    args = parse_args()
    adapter_path = os.path.join(OUTPUT_ROOT, args.ckpt)
    merged_path = args.merged_path

    # NOTE: the skip decision depends only on the output path, not on --ckpt.
    # Merging a second checkpoint therefore needs a different --merged-path
    # (or removal of the existing directory); the hint below says so.
    if os.path.exists(merged_path):
        print(f"[skip] Merged model already exists: {merged_path}")
        print("       Delete it or pass a different --merged-path to merge another checkpoint.")
        return

    # Explicit raise instead of `assert`: asserts are removed under `python -O`,
    # which would let a missing adapter slip through to the expensive load.
    if not os.path.isdir(adapter_path):
        raise FileNotFoundError(f"Adapter not found: {adapter_path}")

    print(f"Base model : {BASE_MODEL}")
    print(f"Adapter    : {adapter_path}")
    print(f"Output     : {merged_path}")
    print()

    print("[1/4] Loading base model to CPU ...")
    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
    )

    print("[2/4] Loading LoRA adapter ...")
    model = PeftModel.from_pretrained(model, adapter_path)

    print("[3/4] Merging weights ...")
    model = model.merge_and_unload()

    print("[4/4] Saving merged model ...")
    # Write into a sibling staging directory and rename it only after every
    # file has been saved. Otherwise a crash mid-save would leave a partial
    # directory at merged_path that later runs would mistake for a finished
    # merge and skip.
    final_path = os.path.normpath(merged_path)  # drops any trailing slash
    staging_path = final_path + ".partial"
    if os.path.isdir(staging_path):
        # Leftover from an earlier interrupted run; start from a clean slate.
        shutil.rmtree(staging_path)
    os.makedirs(staging_path, exist_ok=True)
    model.save_pretrained(staging_path, safe_serialization=True)
    AutoTokenizer.from_pretrained(BASE_MODEL).save_pretrained(staging_path)
    os.rename(staging_path, final_path)

    print(f"\nDone. Merged model saved to: {merged_path}")
|
|
|
|
# Script entry point: run the merge only when executed directly, not on import.
| if __name__ == "__main__": |
| main() |
|
|