"""Restore pristine embedding / LM-head weights into a merged model.

Copies ``model.embed_tokens.weight`` and ``lm_head.weight`` from a base
model's safetensors checkpoint into a merged model's checkpoint, rewriting
the affected merged shard in place.
"""

import argparse
import gc
import json
import os
import shutil
import struct

from safetensors import safe_open
from safetensors.torch import load_file, save_file


def _read_safetensors_header(path):
    """Return the parsed JSON header of the safetensors file at *path*.

    A safetensors file starts with an 8-byte little-endian unsigned header
    length followed by that many bytes of UTF-8 JSON. The JSON maps each
    tensor name to its dtype/shape/offsets, plus an optional
    ``"__metadata__"`` dict of string pairs. Reading it through a plain
    file handle avoids memory-mapping the (possibly multi-GB) file.
    """
    with open(path, "rb") as fh:
        (header_len,) = struct.unpack("<Q", fh.read(8))
        return json.loads(fh.read(header_len))


def get_weight_map(model_path):
    """Map tensor names to the shard file (relative to *model_path*) holding them.

    Uses ``model.safetensors.index.json`` when it exists. Otherwise every
    ``*.safetensors`` file in the directory is scanned in sorted order, and
    each tensor it actually contains is mapped to it. Tensors a checkpoint
    does not store (e.g. ``lm_head.weight`` when embeddings are tied) are
    therefore absent from the result. Returns ``{}`` if no shard is found.
    """
    index_path = os.path.join(model_path, "model.safetensors.index.json")
    if os.path.exists(index_path):
        with open(index_path, "r", encoding="utf-8") as fh:
            return json.load(fh)["weight_map"]

    weight_map = {}
    for filename in sorted(os.listdir(model_path)):
        if not filename.endswith(".safetensors"):
            continue
        header = _read_safetensors_header(os.path.join(model_path, filename))
        for name in header:
            if name != "__metadata__":
                # Keep the first (sorted) shard that lists a given name.
                weight_map.setdefault(name, filename)
    return weight_map


def swap_tensor(base_path, merged_path, tensor_name):
    """Overwrite *tensor_name* in the merged checkpoint with the base model's copy.

    The merged shard holding the tensor is rewritten with the base tensor
    substituted. All other tensors in that shard are kept, and the shard's
    existing metadata is preserved, with ``"format"`` set to ``"pt"``. If
    the tensor is not present in both models, the swap is skipped with a
    message.

    Raises:
        ValueError: if the base and merged tensors differ in shape (e.g. the
            merged model's vocabulary was resized). Nothing is modified.
        Exception: any error raised while writing the new shard is
            re-raised after the original shard has been restored.
    """
    base_map = get_weight_map(base_path)
    merged_map = get_weight_map(merged_path)
    if tensor_name not in base_map or tensor_name not in merged_map:
        print(f" ⚠️ Skipping {tensor_name}: not present in both models.")
        return

    base_shard = os.path.join(base_path, base_map[tensor_name])
    merged_shard = os.path.join(merged_path, merged_map[tensor_name])

    # Read only the one tensor we need from the base shard (not the whole
    # shard), and clone it so it owns its memory independent of the file.
    with safe_open(base_shard, framework="pt", device="cpu") as fh:
        pristine_tensor = fh.get_tensor(tensor_name).clone()

    # Preserve whatever metadata the merged shard already carries.
    metadata = dict(_read_safetensors_header(merged_shard).get("__metadata__") or {})
    metadata["format"] = "pt"

    merged_tensors = load_file(merged_shard, device="cpu")
    if tensor_name not in merged_tensors:
        # The weight map disagrees with the shard contents; don't inject a
        # tensor the merged checkpoint never had.
        print(f" ⚠️ Skipping {tensor_name}: not found in {merged_shard}.")
        return
    merged_shape = tuple(merged_tensors[tensor_name].shape)
    pristine_shape = tuple(pristine_tensor.shape)
    if merged_shape != pristine_shape:
        raise ValueError(
            f"Shape mismatch for {tensor_name}: base {pristine_shape} "
            f"vs merged {merged_shape}"
        )
    merged_tensors[tensor_name] = pristine_tensor

    # ATOMIC RENAME STRATEGY (works around Windows error 1224, "user-mapped
    # section open"): the shard apparently stays mapped while the loaded
    # tensors are alive, so it cannot be overwritten in place. Move it
    # aside, then write a fresh file under the original name.
    backup_shard = merged_shard + ".old"
    os.rename(merged_shard, backup_shard)
    try:
        save_file(merged_tensors, merged_shard, metadata=metadata)
    except Exception:
        # os.replace (unlike os.rename on Windows) also overwrites any
        # partially written output left behind by the failed save.
        os.replace(backup_shard, merged_shard)
        raise

    # Drop every reference into the old shard before deleting it.
    del merged_tensors
    del pristine_tensor
    gc.collect()
    try:
        os.remove(backup_shard)
    except OSError as err:
        # The new shard is already in place; report the leftover backup
        # (to be deleted manually) instead of aborting the remaining swaps.
        print(f" ⚠️ Could not remove backup {backup_shard}: {err}")
    print(f" ✅ Successfully shielded {tensor_name}!")


def main():
    """Parse CLI arguments and shield the embedding and LM-head tensors."""
    parser = argparse.ArgumentParser(
        description="Copy embed_tokens / lm_head weights from a base model "
        "into a merged model's safetensors checkpoint."
    )
    parser.add_argument("base_model", help="directory of the base model")
    parser.add_argument(
        "merged_model", help="directory of the merged model (modified in place)"
    )
    args = parser.parse_args()
    for tensor_name in ("model.embed_tokens.weight", "lm_head.weight"):
        swap_tensor(args.base_model, args.merged_model, tensor_name)


if __name__ == "__main__":
    main()