# model_tools/shield_embeddings.py
# Uploaded by Naphula in commit f43fd2b ("Upload 10 files", verified).
import os
import json
import gc
import shutil
from safetensors.torch import load_file, save_file
import argparse
def _safetensors_tensor_names(path):
    """Return the tensor names listed in the header of the safetensors file *path*.

    A safetensors file starts with an 8-byte little-endian header length
    followed by a UTF-8 JSON header keyed by tensor name, so the names can be
    read without loading any tensor data.
    """
    with open(path, "rb") as fh:
        header_len = int.from_bytes(fh.read(8), "little")
        header = json.loads(fh.read(header_len))
    # "__metadata__" is a reserved header entry, not a tensor.
    return [name for name in header if name != "__metadata__"]


def get_weight_map(model_path):
    """Map tensor names to the safetensors file (relative to *model_path*) holding them.

    Uses ``model.safetensors.index.json`` when present (sharded checkpoints).
    Otherwise it reads the header of every ``*.safetensors`` file in the
    directory, so only tensors that actually exist are reported. This matters
    when, for example, a checkpoint has no separate ``lm_head.weight``.

    When several files contain the same name, the alphabetically first file
    wins. Returns ``{}`` when the directory holds no safetensors files.
    """
    index_path = os.path.join(model_path, "model.safetensors.index.json")
    if os.path.exists(index_path):
        with open(index_path, "r", encoding="utf-8") as fh:
            return json.load(fh)["weight_map"]

    weight_map = {}
    # sorted() makes the result independent of the filesystem's listing order.
    for filename in sorted(os.listdir(model_path)):
        if not filename.endswith(".safetensors"):
            continue
        shard_path = os.path.join(model_path, filename)
        for tensor_name in _safetensors_tensor_names(shard_path):
            weight_map.setdefault(tensor_name, filename)
    return weight_map
def swap_tensor(base_path, merged_path, tensor_name):
    """Replace one tensor in the merged model with the base model's copy.

    Looks up which shard holds ``tensor_name`` in each model (via
    ``get_weight_map``), loads the base ("pristine") tensor, substitutes it into
    the merged shard's tensors, and rewrites that shard. The shard is skipped
    with a printed message when either model lacks the tensor or the shapes
    differ.

    Args:
        base_path: Directory of the base model that supplies the tensor.
        merged_path: Directory of the merged model; one shard is rewritten.
        tensor_name: Tensor key, e.g. ``"lm_head.weight"``.

    Raises:
        Whatever ``load_file``, ``save_file`` or the file renames raise. If the
        save fails, the original merged shard is moved back before re-raising.
    """
    base_map = get_weight_map(base_path)
    merged_map = get_weight_map(merged_path)
    if tensor_name not in base_map or tensor_name not in merged_map:
        print(f" ⏭️ Skipping {tensor_name}: not found in both models.")
        return
    base_shard = os.path.join(base_path, base_map[tensor_name])
    merged_shard = os.path.join(merged_path, merged_map[tensor_name])

    # Clone the pristine tensor so it owns its storage, then drop the rest of
    # the base shard before the merged shard is loaded.
    base_tensors = load_file(base_shard, device="cpu")
    pristine_tensor = base_tensors[tensor_name].clone()
    del base_tensors
    gc.collect()

    merged_tensors = load_file(merged_shard, device="cpu")
    merged_shape = merged_tensors[tensor_name].shape
    if pristine_tensor.shape != merged_shape:
        # A differently shaped tensor would not fit the merged model (e.g. if
        # its vocabulary was resized), so leave the shard untouched.
        print(
            f" ⚠️ Skipping {tensor_name}: base shape {tuple(pristine_tensor.shape)} "
            f"!= merged shape {tuple(merged_shape)}."
        )
        return
    merged_tensors[tensor_name] = pristine_tensor

    # Move the current shard aside and write a fresh file at the original
    # path. Overwriting in place can fail on Windows with error 1224 because
    # the shard loaded above may still be memory-mapped.
    backup_shard = merged_shard + ".old"
    os.rename(merged_shard, backup_shard)
    try:
        save_file(merged_tensors, merged_shard, metadata={"format": "pt"})
    except BaseException:
        # os.replace also overwrites a partially written shard on Windows
        # (os.rename would raise FileExistsError there), so restoring the
        # original cannot mask the real error.
        os.replace(backup_shard, merged_shard)
        raise

    # Drop every reference into the old shard before deleting it.
    del merged_tensors
    del pristine_tensor
    gc.collect()
    try:
        os.remove(backup_shard)
    except OSError as err:
        # The new shard is already written; a leftover backup is harmless, so
        # warn rather than abort the remaining tensors.
        print(f" ⚠️ Could not delete backup {backup_shard}: {err}")
    print(f" ✅ Successfully shielded {tensor_name}!")
def main():
    """Command-line entry point: copy base-model tensors into a merged model.

    By default it restores ``model.embed_tokens.weight`` and
    ``lm_head.weight``, exactly as before. ``--tensors`` lets callers name
    other tensor keys, e.g. for checkpoints that use a different prefix.
    """
    parser = argparse.ArgumentParser(
        description="Replace selected tensors in a merged model's safetensors "
        "shards with the base model's originals."
    )
    parser.add_argument(
        "base_model", help="Directory of the base model (source of the tensors)."
    )
    parser.add_argument(
        "merged_model", help="Directory of the merged model (rewritten in place)."
    )
    parser.add_argument(
        "--tensors",
        nargs="+",
        default=("model.embed_tokens.weight", "lm_head.weight"),
        metavar="NAME",
        help="Tensor keys to restore (default: %(default)s).",
    )
    args = parser.parse_args()
    for tensor_name in args.tensors:
        swap_tensor(args.base_model, args.merged_model, tensor_name)


if __name__ == "__main__":
    main()