Spaces:
Sleeping
Sleeping
| """Merge LoRA adapter into base model and save as HF checkpoint. | |
| Bypasses Unsloth and PEFT's module-matching to avoid both: | |
| - Unsloth 2026.4.2 dropping `gemma-4-E4B-it` model name | |
| - PEFT's ValueError on Gemma4ClippableLinear wrappers | |
| Manual merge: delta_W = (B @ A) * (alpha/r), added to base weights. | |
| Output can then be converted to GGUF via llama.cpp's convert_hf_to_gguf.py. | |
| """ | |
| import json | |
| import os | |
| import sys | |
| os.environ["TORCH_COMPILE_DISABLE"] = "1" | |
| os.environ["TORCHDYNAMO_DISABLE"] = "1" | |
| os.environ["PYTHONIOENCODING"] = "utf-8" | |
| os.environ["HF_HUB_OFFLINE"] = "1" | |
| os.environ["TRANSFORMERS_OFFLINE"] = "1" | |
| import torch | |
| from safetensors.torch import load_file | |
| from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer | |
# Model identifier handed to from_pretrained() for the base checkpoint.
BASE_MODEL = "google/gemma-4-E4B-it"
# Directory holding the trained LoRA adapter (weights + adapter_config.json).
ADAPTER_DIR = "./models/checkpoints/final"
# Destination directory for the merged checkpoint.
# NOTE(review): the name says fp16, but the base model is loaded with
# torch.bfloat16 later in this script, so the saved weights are presumably
# bf16 -- confirm before relying on the name.
MERGED_DIR = "./models/merged_fp16"

print("=" * 60)
print("Manual LoRA merge: base + adapter -> merged FP16")
print(f"Base: {BASE_MODEL}")
print(f"Adapter: {ADAPTER_DIR}")
print(f"Output: {MERGED_DIR}")
print("=" * 60)

# Both adapter files are read later in the script. Verify them up front so a
# missing config is reported cleanly instead of surfacing as a traceback after
# the output directory has already been created.
required_files = ("adapter_model.safetensors", "adapter_config.json")
missing_files = [
    name
    for name in required_files
    if not os.path.isfile(os.path.join(ADAPTER_DIR, name))
]
if missing_files:
    print(f"ABORT: No adapter at {ADAPTER_DIR}")
    print(f"Missing: {', '.join(missing_files)}")
    sys.exit(1)

os.makedirs(MERGED_DIR, exist_ok=True)
# Read the adapter config: LoRA rank, alpha, scaling variant, target modules.
with open(
    os.path.join(ADAPTER_DIR, "adapter_config.json"), "r", encoding="utf-8"
) as f:
    ac = json.load(f)

r = ac["r"]
alpha = ac["lora_alpha"]
if r <= 0:
    # A non-positive rank would divide by zero or flip the sign of the delta.
    print(f"ABORT: Invalid LoRA rank r={r} in adapter config")
    sys.exit(1)

# Rank-stabilized LoRA scales the update by alpha / sqrt(r) instead of
# alpha / r; applying the wrong formula merges a delta of the wrong magnitude.
if ac.get("use_rslora", False):
    scale = alpha / (r ** 0.5)
else:
    scale = alpha / r

# This script applies one global scale and a plain B @ A update to every pair.
# Surface config options that would need different handling rather than
# silently producing a checkpoint that differs from the trained adapter.
unsupported = {
    name: ac[name] for name in ("rank_pattern", "alpha_pattern") if ac.get(name)
}
if unsupported:
    print(f"WARNING: per-module overrides are ignored by this merge: {unsupported}")
if ac.get("use_dora", False):
    print("WARNING: use_dora is set; a plain B @ A merge ignores DoRA magnitudes")

print(f"\nAdapter: r={r}, alpha={alpha}, scale={scale:.3f}")
print(f"Target modules: {ac['target_modules']}")
# Step 1: load the base checkpoint from local files only, in bfloat16, on CUDA.
print("\n[1/4] Loading base model in bfloat16 on GPU...")
load_options = {
    "torch_dtype": torch.bfloat16,
    "low_cpu_mem_usage": True,
    "device_map": "cuda",
    "local_files_only": True,
}
base = AutoModelForCausalLM.from_pretrained(BASE_MODEL, **load_options)
# Report the concrete class that the Auto* factory resolved to.
loaded_class_name = type(base).__name__
print(f" Loaded: {loaded_class_name}")
# Step 2: load the adapter tensors and group them per target module.
print("\n[2/4] Loading adapter weights...")
adapter_sd = load_file(os.path.join(ADAPTER_DIR, "adapter_model.safetensors"))
print(f" Tensors: {len(adapter_sd)}")

# Pair up lora_A / lora_B under the module path obtained by stripping the
# suffix. The suffix is matched at the END of the key only, so a module path
# that merely contains the text cannot be mis-parsed.
lora_suffixes = {".lora_A.weight": "A", ".lora_B.weight": "B"}
pairs = {}
ignored_keys = []
for k, tensor in adapter_sd.items():
    for suffix, slot in lora_suffixes.items():
        if k.endswith(suffix):
            pairs.setdefault(k[: -len(suffix)], {})[slot] = tensor
            break
    else:
        # Not a lora_A / lora_B weight; remember it so it can be reported.
        ignored_keys.append(k)
print(f" LoRA pairs: {len(pairs)}")

# Tensors that this script does not know how to merge would otherwise vanish
# without a trace, leaving a checkpoint that differs from the trained adapter.
if ignored_keys:
    print(
        f" WARNING: {len(ignored_keys)} adapter tensors are not "
        "lora_A/lora_B weights and will not be merged"
    )
    for k in ignored_keys[:10]:
        print(f"   {k}")
# Build a name -> module map for the base model. Adapter keys such as
# "base_model.model.model.layers.0.self_attn.q_proj" must be matched to the
# module that owns the Linear weight. Gemma 4 wraps Linear in
# Gemma4ClippableLinear, hence the `.linear` fallback below.
name_to_module = dict(base.named_modules())

print("\n[3/4] Merging delta into base weights...")
peft_prefix = "base_model.model."
merged = 0
skipped = 0
# Nothing in this loop needs autograd; disable it for the whole merge.
with torch.no_grad():
    for key, ab in pairs.items():
        if "A" not in ab or "B" not in ab:
            print(f" UNPAIRED: {key}")
            skipped += 1
            continue

        # Strip the prefix PEFT adds, but only at the start of the key so an
        # identical substring deeper in the path is left intact.
        if key.startswith(peft_prefix):
            target_path = key[len(peft_prefix):]
        else:
            target_path = key
        module = name_to_module.get(target_path)
        if module is None:
            print(f" MISS: {target_path}")
            skipped += 1
            continue

        # Find the actual weight tensor: module.weight when it is a Parameter,
        # otherwise the weight of a wrapped `linear` submodule.
        weight = getattr(module, "weight", None)
        if not isinstance(weight, torch.nn.Parameter):
            weight = getattr(getattr(module, "linear", None), "weight", None)
        if weight is None:
            print(f" NO_WEIGHT: {target_path} ({type(module).__name__})")
            skipped += 1
            continue

        lora_a = ab["A"]
        lora_b = ab["B"]
        # B @ A must produce exactly the weight's shape; check before touching
        # the weight so a mismatch is reported instead of raising mid-merge.
        shapes_ok = (
            lora_a.dim() == 2
            and lora_b.dim() == 2
            and lora_b.shape[1] == lora_a.shape[0]
            and tuple(weight.shape) == (lora_b.shape[0], lora_a.shape[1])
        )
        if not shapes_ok:
            print(
                f" SHAPE_MISMATCH: {target_path} weight={tuple(weight.shape)} "
                f"A={tuple(lora_a.shape)} B={tuple(lora_b.shape)}"
            )
            skipped += 1
            continue

        # Compute the delta in float32 for accuracy, then cast to the weight's
        # dtype for the in-place update.
        A = lora_a.to(weight.device, dtype=torch.float32)
        B = lora_b.to(weight.device, dtype=torch.float32)
        delta = (B @ A) * scale
        weight.add_(delta.to(weight.dtype))
        merged += 1

print(f" Merged: {merged}, Skipped: {skipped}")
if merged == 0:
    print("ABORT: No LoRA pairs were merged")
    sys.exit(1)
if skipped:
    # Best-effort: keep going, but make it obvious the result is partial.
    print(f" WARNING: {skipped} LoRA pairs were skipped; the merge is partial")
# Step 4: write the merged weights as sharded safetensors files.
print(f"\n[4/4] Saving merged model to {MERGED_DIR}...")
base.save_pretrained(
    MERGED_DIR,
    max_shard_size="5GB",
    safe_serialization=True,
)

# Prefer saving the full processor; if loading or saving it fails for any
# reason, report why and save just the tokenizer instead.
try:
    AutoProcessor.from_pretrained(
        BASE_MODEL, local_files_only=True
    ).save_pretrained(MERGED_DIR)
    print(" Processor saved")
except Exception as err:
    print(f" Processor save skipped: {err}")
    fallback_tokenizer = AutoTokenizer.from_pretrained(
        BASE_MODEL, local_files_only=True
    )
    fallback_tokenizer.save_pretrained(MERGED_DIR)
    print(" Tokenizer saved (fallback)")

print(f"\nDone. Merged model ready at: {MERGED_DIR}")
print("Next: convert_hf_to_gguf.py to produce GGUF")