| import torch |
| import os |
| import json |
| from transformers import AutoModelForCausalLM |
|
|
|
|
def _load_state_dict_fp32(model_dir):
    """Load a causal-LM checkpoint on CPU in float32 and return its state dict."""
    # NOTE: trust_remote_code=True executes Python code shipped with the
    # checkpoint; only point this at model directories you trust.
    model = AutoModelForCausalLM.from_pretrained(
        model_dir,
        torch_dtype=torch.float32,
        device_map="cpu",
        trust_remote_code=True,
    )
    return model.state_dict()


def _compute_residual(base_state_dict, instruction_state_dict):
    """Return ``{key: instruction - base}`` (float32) for keys present in both dicts."""
    residual_state_dict = {}
    for key, base_tensor in base_state_dict.items():
        if key not in instruction_state_dict:
            print(f"Warning: Key {key} not found in instruction model state dict")
            continue
        instruction_tensor = instruction_state_dict[key]
        if instruction_tensor.shape != base_tensor.shape:
            # Without this check torch would either raise an opaque error or
            # silently broadcast (e.g. (1, n) vs (n,)), giving a wrong residual.
            raise ValueError(
                f"Shape mismatch for {key}: instruction "
                f"{tuple(instruction_tensor.shape)} vs base {tuple(base_tensor.shape)}"
            )
        # .to() returns the tensor unchanged when it is already float32; it only
        # matters for entries of another dtype, keeping the residual uniformly float32.
        residual_state_dict[key] = (instruction_tensor - base_tensor).to(torch.float32)

    # Keys only the instruction model has cannot form a residual; report them
    # instead of dropping them silently.
    for key in instruction_state_dict:
        if key not in base_state_dict:
            print(f"Warning: Key {key} not found in base model state dict; skipped")
    return residual_state_dict


def _save_adapter(residual_state_dict, adapter_path, base_model_dir):
    """Write the residual weights and an ``adapter_config.json`` into *adapter_path*."""
    # makedirs creates any missing parents, so this also creates output_dir.
    os.makedirs(adapter_path, exist_ok=True)
    torch.save(residual_state_dict, os.path.join(adapter_path, "adapter_model.bin"))

    # NOTE(review): the LoRA-style fields below mimic a PEFT config, but the
    # saved tensors are full-size deltas rather than LoRA A/B matrices; this is
    # presumably not loadable by PEFT as a LoRA adapter — confirm with consumers.
    adapter_config = {
        "adapter_type": "instruction_residual",
        "base_model_name_or_path": base_model_dir,
        "target_modules": ["all"],
        "lora_alpha": 1.0,
        "lora_dropout": 0.0,
        "task_type": "CAUSAL_LM",
    }
    config_path = os.path.join(adapter_path, "adapter_config.json")
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(adapter_config, f, indent=4)


def extract_and_merge_instruction_residual(
    instruction_model_dir,
    base_model_dir,
    output_dir,
):
    """Extract the instruction residual (instruct - base) and save it as an adapter.

    Both checkpoints are loaded on CPU in float32. The element-wise difference
    ``instruction_param - base_param`` is computed for every state-dict entry
    the two models share. Despite the function name, nothing is merged here:
    the residual is only written to disk.

    Output layout::

        <output_dir>/instruction_residual_adapter/adapter_model.bin    (torch.save'd dict)
        <output_dir>/instruction_residual_adapter/adapter_config.json

    Args:
        instruction_model_dir: Path (or hub id) of the instruction-tuned model.
        base_model_dir: Path (or hub id) of the base model. It is also recorded
            as ``base_model_name_or_path`` in the adapter config.
        output_dir: Directory under which the adapter folder is created.

    Raises:
        ValueError: If an entry present in both models has different shapes,
            since the residual would then be meaningless.
    """
    base_state_dict = _load_state_dict_fp32(base_model_dir)
    instruction_state_dict = _load_state_dict_fp32(instruction_model_dir)

    residual_state_dict = _compute_residual(base_state_dict, instruction_state_dict)

    adapter_path = os.path.join(output_dir, "instruction_residual_adapter")
    _save_adapter(residual_state_dict, adapter_path, base_model_dir)

    print(f"✅ Full-precision (float32) instruction residual adapter saved to {adapter_path}")
|
|
|
|
if __name__ == "__main__":
    import argparse

    # Defaults reproduce the original hard-coded paths, so running the script
    # with no arguments behaves exactly as before. All three are directories.
    parser = argparse.ArgumentParser(
        description="Extract the instruction residual (instruct - base) as a float32 adapter."
    )
    parser.add_argument(
        "--instruction-model-dir",
        default="/workspace/meta-llama/Llama-3.2-3B-Instruct",
        help="Directory of the instruction-tuned model.",
    )
    parser.add_argument(
        "--base-model-dir",
        default="/workspace/meta-llama/Llama-3.2-3B",
        help="Directory of the base model.",
    )
    parser.add_argument(
        "--output-dir",
        default="/workspace/Llama-3.2-3B-Lr",
        help="Directory under which the adapter folder is written.",
    )
    args = parser.parse_args()

    extract_and_merge_instruction_residual(
        args.instruction_model_dir,
        args.base_model_dir,
        args.output_dir,
    )
|
|