| import os |
| import shutil |
| import torch |
| from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer |
|
|
|
|
def _load_model_fp32(model_dir: str):
    """Load a causal-LM checkpoint on CPU with float32 weights.

    Recent transformers releases take the dtype via ``dtype=``; older ones
    only accept ``torch_dtype=``.  Try the modern keyword first and fall
    back when the installed version raises ``TypeError`` for it.
    """
    shared_kwargs = dict(device_map="cpu", trust_remote_code=True)
    try:
        return AutoModelForCausalLM.from_pretrained(
            model_dir, dtype=torch.float32, **shared_kwargs
        )
    except TypeError:
        # Installed transformers predates the ``dtype`` keyword.
        return AutoModelForCausalLM.from_pretrained(
            model_dir, torch_dtype=torch.float32, **shared_kwargs
        )
|
|
|
|
def _load_residual(lr_dir):
    """Load the residual state dict from ``lr_dir``, preferring safe unpickling."""
    adapter_file = os.path.join(lr_dir, "adapter_model.bin")
    if not os.path.exists(adapter_file):
        raise FileNotFoundError(f"Adapter checkpoint not found at {adapter_file}")

    print("Loading residual adapter...")
    try:
        # weights_only=True blocks arbitrary-code execution via pickle payloads.
        return torch.load(adapter_file, map_location="cpu", weights_only=True)
    except TypeError:
        # Older torch without the weights_only kwarg.
        return torch.load(adapter_file, map_location="cpu")


def _merge_state_dicts(base_state_dict, residual_state_dict):
    """Add residual tensors onto base tensors in fp32.

    Returns ``(merged, mismatched)`` where ``mismatched`` records keys whose
    leading (vocab) dimension differed and were only partially added.

    Raises:
        RuntimeError: On a shape mismatch that is not a simple vocab resize.
    """
    merged = {}
    mismatched = []

    for key, base_tensor in base_state_dict.items():
        if key not in residual_state_dict:
            # No residual for this parameter — keep the base weights.
            merged[key] = base_tensor
            continue

        res_tensor = residual_state_dict[key]

        # Exact shape match: plain elementwise add in fp32.
        if base_tensor.shape == res_tensor.shape:
            merged[key] = (base_tensor + res_tensor).to(torch.float32)
            continue

        # Vocab-resize case: only dim 0 differs (embedding / lm_head rows).
        # Add the overlapping rows; rows for newly-added tokens stay untouched.
        if (
            base_tensor.ndim == res_tensor.ndim
            and base_tensor.ndim >= 1
            and base_tensor.shape[1:] == res_tensor.shape[1:]
            and base_tensor.shape[0] != res_tensor.shape[0]
        ):
            n = min(base_tensor.shape[0], res_tensor.shape[0])
            out = base_tensor.clone().to(torch.float32)
            out[:n] += res_tensor[:n].to(torch.float32)
            merged[key] = out
            mismatched.append((key, tuple(base_tensor.shape), tuple(res_tensor.shape), n))
            continue

        raise RuntimeError(
            f"Shape mismatch for key '{key}': base={tuple(base_tensor.shape)} "
            f"residual={tuple(res_tensor.shape)}. Not a simple vocab-resize mismatch."
        )

    # Previously these were silently dropped; surface them so a naming/prefix
    # mismatch between adapter and model is not an invisible no-op merge.
    unused = [k for k in residual_state_dict if k not in base_state_dict]
    if unused:
        print(
            f"\n⚠️ {len(unused)} residual key(s) not present in base model were "
            f"skipped, e.g. {unused[:5]}"
        )

    return merged, mismatched


def _export_tokenizer(base_model_dir, output_dir):
    """Best-effort export of tokenizer assets alongside the merged weights."""
    try:
        tok = AutoTokenizer.from_pretrained(base_model_dir, trust_remote_code=True)
        tok.save_pretrained(output_dir)
    except Exception:
        # Deliberate best-effort fallback: copy the raw tokenizer files if present.
        for file_name in ["tokenizer.json", "tokenizer_config.json", "special_tokens_map.json"]:
            src_path = os.path.join(base_model_dir, file_name)
            dst_path = os.path.join(output_dir, file_name)
            if os.path.exists(src_path):
                shutil.copyfile(src_path, dst_path)


def merge_instruction_residual(lr_dir, base_model_dir, output_dir):
    """
    Merge instruction residual into a (possibly vocab-resized) CPT model.

    If vocab was resized after the residual was computed, we add residual only
    for the overlapping token rows and keep extra rows (new tokens) unchanged.

    Args:
        lr_dir: Directory containing the residual checkpoint ``adapter_model.bin``.
        base_model_dir: Directory of the base (CPT) model to merge into.
        output_dir: Destination for the merged model (created if missing).

    Raises:
        FileNotFoundError: If the residual checkpoint is missing.
        RuntimeError: On a shape mismatch that is not a simple vocab resize.
    """
    residual_state_dict = _load_residual(lr_dir)

    print(f"\nMerging residual into base model: {base_model_dir}")
    base_model = _load_model_fp32(base_model_dir)

    merged_state_dict, mismatched = _merge_state_dicts(
        base_model.state_dict(), residual_state_dict
    )

    if mismatched:
        print("\nHandled vocab-resize mismatches by partial add:")
        for k, bs, rs, n in mismatched[:20]:
            print(f" - {k}: base{bs} vs res{rs} → added first {n} rows, kept the rest unchanged")
        if len(mismatched) > 20:
            print(f" ... and {len(mismatched) - 20} more")

    # strict=True guarantees we produced a value for every parameter key.
    base_model.load_state_dict(merged_state_dict, strict=True)

    # All arithmetic was done in fp32; store the result compactly in bf16.
    base_model = base_model.to(torch.bfloat16)
    os.makedirs(output_dir, exist_ok=True)
    base_model.save_pretrained(output_dir, safe_serialization=True)

    # base_model_dir is the (possibly resized) CPT model, so its config already
    # reflects the merged model's dimensions.
    base_config = AutoConfig.from_pretrained(base_model_dir)
    base_config.save_pretrained(output_dir)

    _export_tokenizer(base_model_dir, output_dir)

    print("\n✅ Merge complete.")
    print(f"🧠 fp32 math → saved bf16 at: {output_dir}")
|
|
|
| if __name__ == "__main__": |
| lr_file = "/workspace/Llama-3.2-3B-Lr/instruction_residual_adapter" |
| base_model_file = "/workspace/v126rc_exp3/F_r10000/checkpoint-31" |
| output_root = "/workspace/v126rc_exp3/F_r10000/checkpoint-31/residued" |
|
|
| merge_instruction_residual(lr_file, base_model_file, output_root) |
|
|