"""Merge a LoRA adapter into its base model and save an HF checkpoint.

Bypasses Unsloth and PEFT's module matching to avoid both:
- Unsloth 2026.4.2 dropping the `gemma-4-E4B-it` model name
- PEFT's ValueError on Gemma4ClippableLinear wrappers

Manual merge: W += (B @ A) * scale, with scale = lora_alpha / r for standard
LoRA, added directly to the base weights. The base model is loaded in
bfloat16, so the saved weights are bfloat16 despite the "merged_fp16" output
directory name.

The output can then be converted to GGUF with llama.cpp's
convert_hf_to_gguf.py.
"""
import json
import os
import sys

# These environment variables are set BEFORE torch / transformers are imported
# below so the libraries see them when they initialise.
os.environ["TORCH_COMPILE_DISABLE"] = "1"
os.environ["TORCHDYNAMO_DISABLE"] = "1"
# PYTHONIOENCODING is only read at interpreter start-up, so setting it here
# affects child processes only; reconfigure this process's own streams so the
# log output below is UTF-8 as intended.
os.environ["PYTHONIOENCODING"] = "utf-8"
for _stream in (sys.stdout, sys.stderr):
    if hasattr(_stream, "reconfigure"):
        _stream.reconfigure(encoding="utf-8")
# Never touch the network: the base model must already be in the local cache.
os.environ["HF_HUB_OFFLINE"] = "1"
os.environ["TRANSFORMERS_OFFLINE"] = "1"

import torch
from safetensors.torch import load_file
from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer
# Fixed locations for this run (paths are relative to the working directory).
BASE_MODEL = "google/gemma-4-E4B-it"
ADAPTER_DIR = "./models/checkpoints/final"
MERGED_DIR = "./models/merged_fp16"

# Run banner: one item per line, framed by 60-character rules.
print(
    "=" * 60,
    "Manual LoRA merge: base + adapter -> merged FP16",
    f"Base: {BASE_MODEL}",
    f"Adapter: {ADAPTER_DIR}",
    f"Output: {MERGED_DIR}",
    "=" * 60,
    sep="\n",
)
# Fail fast, before the slow model load, if either adapter file is missing;
# a missing adapter_config.json would otherwise surface as a raw traceback.
for _required in ("adapter_model.safetensors", "adapter_config.json"):
    if not os.path.isfile(os.path.join(ADAPTER_DIR, _required)):
        print(f"ABORT: No adapter at {ADAPTER_DIR} (missing {_required})")
        sys.exit(1)
os.makedirs(MERGED_DIR, exist_ok=True)
# Read LoRA hyper-parameters from the PEFT adapter config.
with open(os.path.join(ADAPTER_DIR, "adapter_config.json"), "r", encoding="utf-8") as f:
    ac = json.load(f)
r = ac["r"]
alpha = ac["lora_alpha"]
# PEFT scales the LoRA update by alpha / sqrt(r) when the adapter was trained
# with rsLoRA, and by alpha / r otherwise. Merging with the wrong factor
# silently changes the strength of the fine-tune.
scale = alpha / (r ** 0.5) if ac.get("use_rslora", False) else alpha / r
print(f"\nAdapter: r={r}, alpha={alpha}, scale={scale:.3f}")
if ac.get("use_rslora", False):
    print(" rsLoRA: scale = alpha / sqrt(r)")
print(f"Target modules: {ac.get('target_modules')}")
# The merge applies one global scale and the plain (B @ A) product to the
# stored weight layout, so it cannot reproduce these PEFT options. Refuse
# rather than write a silently wrong model.
_unsupported = [
    opt
    for opt in ("use_dora", "rank_pattern", "alpha_pattern", "fan_in_fan_out")
    if ac.get(opt)
]
if _unsupported:
    print(f"ABORT: adapter uses options this manual merge does not support: {_unsupported}")
    sys.exit(1)
# [1/4] Load the base model onto the GPU in bfloat16. The LoRA deltas are
# later added into these weights in place.
print("\n[1/4] Loading base model in bfloat16 on GPU...")
base = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    local_files_only=True,  # use the local cache only
    device_map="cuda",
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16,
)
print(" Loaded: " + type(base).__name__)
print("\n[2/4] Loading adapter weights...")
adapter_sd = load_file(os.path.join(ADAPTER_DIR, "adapter_model.safetensors"))
print(f" Tensors: {len(adapter_sd)}")
# Group lora_A / lora_B tensors under their shared module path. Only the
# trailing suffix is stripped; str.replace would also remove a match elsewhere
# in the key.
LORA_A_SUFFIX = ".lora_A.weight"
LORA_B_SUFFIX = ".lora_B.weight"
pairs = {}
unpaired_keys = []
for k, tensor in adapter_sd.items():
    if k.endswith(LORA_A_SUFFIX):
        pairs.setdefault(k[: -len(LORA_A_SUFFIX)], {})["A"] = tensor
    elif k.endswith(LORA_B_SUFFIX):
        pairs.setdefault(k[: -len(LORA_B_SUFFIX)], {})["B"] = tensor
    else:
        # Presumably modules_to_save copies, lora_embedding_A/B, biases or DoRA
        # magnitude vectors. The merge below handles none of these, so they
        # are reported instead of silently dropped.
        unpaired_keys.append(k)
print(f" LoRA pairs: {len(pairs)}")
if unpaired_keys:
    print(f" WARNING: {len(unpaired_keys)} adapter tensor(s) are not lora_A/lora_B weights and will NOT be merged:")
    for k in unpaired_keys[:10]:
        print(f"   {k}")
    if len(unpaired_keys) > 10:
        print(f"   ... and {len(unpaired_keys) - 10} more")
# Build a name -> module map for the base model. Adapter keys look like
# "base_model.model.model.layers.0.self_attn.q_proj"; after stripping PEFT's
# leading "base_model.model." they should match a name from named_modules().
# Gemma 4 wraps Linear in Gemma4ClippableLinear, so the weight may live on
# module.linear rather than on the module itself.
name_to_module = dict(base.named_modules())
PEFT_PREFIX = "base_model.model."
print("\n[3/4] Merging delta into base weights...")
merged = 0
skipped = 0
for key, ab in pairs.items():
    if "A" not in ab or "B" not in ab:
        # A lone lora_A or lora_B cannot form a delta; say so instead of
        # dropping it silently.
        print(f" INCOMPLETE: {key} (has {sorted(ab)})")
        skipped += 1
        continue
    # Strip only the LEADING prefix PEFT adds (str.replace would also remove
    # any later occurrence inside the path).
    target_path = key[len(PEFT_PREFIX):] if key.startswith(PEFT_PREFIX) else key
    module = name_to_module.get(target_path)
    if module is None:
        print(f" MISS: {target_path}")
        skipped += 1
        continue
    # Find the actual weight tensor: module.weight, or module.linear.weight
    # for wrapper modules.
    if hasattr(module, "weight") and isinstance(module.weight, torch.nn.Parameter):
        weight = module.weight
    elif hasattr(module, "linear") and hasattr(module.linear, "weight"):
        weight = module.linear.weight
    else:
        print(f" NO_WEIGHT: {target_path} ({type(module).__name__})")
        skipped += 1
        continue
    # Compute the delta in float32. Expected shapes: A is (r, in) and B is
    # (out, r), so B @ A matches the weight's (out, in) layout. This is
    # verified below rather than assumed.
    A = ab["A"].to(weight.device, dtype=torch.float32)
    B = ab["B"].to(weight.device, dtype=torch.float32)
    delta = (B @ A) * scale
    if delta.shape != weight.shape:
        print(f" SHAPE: {target_path} delta {tuple(delta.shape)} vs weight {tuple(weight.shape)}")
        skipped += 1
        continue
    with torch.no_grad():
        # Add in float32 and round to the storage dtype once, rather than
        # rounding the delta to bf16 first and then rounding the sum again.
        weight.copy_((weight.float() + delta).to(weight.dtype))
    merged += 1
print(f" Merged: {merged}, Skipped: {skipped}")
if merged == 0:
    print("ABORT: No LoRA pairs were merged")
    sys.exit(1)
if skipped:
    print(f" WARNING: {skipped} LoRA pair(s) were NOT merged; the saved model only partially includes the fine-tune.")
print(f"\n[4/4] Saving merged model to {MERGED_DIR}...")
base.save_pretrained(MERGED_DIR, safe_serialization=True, max_shard_size="5GB")
# Save the preprocessing files next to the weights. Per its "(fallback)" log
# label, the tokenizer is written separately only when the processor could not
# be loaded or saved.
# NOTE(review): this assumes processor.save_pretrained already writes the
# tokenizer files that convert_hf_to_gguf.py needs. Confirm for Gemma 4.
try:
    processor = AutoProcessor.from_pretrained(BASE_MODEL, local_files_only=True)
    processor.save_pretrained(MERGED_DIR)
except Exception as e:  # best-effort boundary: report, then fall back
    print(f" Processor save skipped: {e}")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, local_files_only=True)
    tokenizer.save_pretrained(MERGED_DIR)
    print(" Tokenizer saved (fallback)")
else:
    print(" Processor saved")
print(f"\nDone. Merged model ready at: {MERGED_DIR}")
print("Next: convert_hf_to_gguf.py to produce GGUF")