# Restores pristine embed_tokens / lm_head tensors from a base model into a merged model's shards.
| import os | |
| import json | |
| import gc | |
| import shutil | |
| from safetensors.torch import load_file, save_file | |
| import argparse | |
def get_weight_map(model_path):
    """Return the tensor-name -> shard-filename mapping for a model directory.

    Prefers the ``model.safetensors.index.json`` shard index when present.
    Otherwise falls back to the first ``.safetensors`` file found (sorted so
    the choice is deterministic) and maps the embedding and lm_head tensors
    to it — assumes a single-file checkpoint holds both (TODO confirm for
    exotic layouts). Returns an empty dict when nothing matches.
    """
    index_path = os.path.join(model_path, "model.safetensors.index.json")
    if os.path.exists(index_path):
        with open(index_path, "r", encoding="utf-8") as f:
            return json.load(f)["weight_map"]
    # No index: single-file checkpoint. Sorted listing keeps the fallback
    # deterministic across filesystems (os.listdir order is arbitrary).
    for fname in sorted(os.listdir(model_path)):
        if fname.endswith(".safetensors"):
            return {"model.embed_tokens.weight": fname, "lm_head.weight": fname}
    return {}
def swap_tensor(base_path, merged_path, tensor_name):
    """Copy the pristine ``tensor_name`` from the base model's shard into the
    merged model's shard, rewriting the merged shard in place.

    The merged shard is rewritten with a rename-then-save strategy so the
    old file is never overwritten while it may still be memory-mapped
    (Windows error 1224).

    Raises whatever ``save_file`` raises on write failure, after restoring
    the original shard.
    """
    base_map = get_weight_map(base_path)
    merged_map = get_weight_map(merged_path)
    if tensor_name not in base_map or tensor_name not in merged_map:
        # Surface the skip instead of failing silently — a missing tensor
        # usually means a wrong path or an unexpected checkpoint layout.
        print(f" ⚠️ {tensor_name} not present in both weight maps; skipping.")
        return
    base_shard = os.path.join(base_path, base_map[tensor_name])
    merged_shard = os.path.join(merged_path, merged_map[tensor_name])
    # Load the base shard only long enough to clone the pristine tensor,
    # then drop it to keep peak memory down.
    base_tensors = load_file(base_shard, device="cpu")
    pristine_tensor = base_tensors[tensor_name].clone()
    del base_tensors
    gc.collect()
    # Load merged shard and substitute the pristine tensor.
    merged_tensors = load_file(merged_shard, device="cpu")
    merged_tensors[tensor_name] = pristine_tensor
    # ATOMIC RENAME STRATEGY (the only way to beat Windows 1224): move the
    # live file aside, write the replacement, then delete the old copy.
    # os.replace (unlike os.rename) also overwrites a stale ".old" left
    # behind by an earlier crashed run instead of raising on Windows.
    backup_shard = merged_shard + ".old"
    os.replace(merged_shard, backup_shard)
    try:
        save_file(merged_tensors, merged_shard, metadata={"format": "pt"})
    except Exception:
        os.replace(backup_shard, merged_shard)  # roll back on failure
        raise  # bare raise preserves the original traceback
    # Release references before removing the old (possibly mapped) file.
    del merged_tensors
    del pristine_tensor
    gc.collect()
    os.remove(backup_shard)
    print(f" ✅ Successfully shielded {tensor_name}!")
def main():
    """CLI entry point: shield the embedding and lm_head tensors of a
    merged model with their pristine base-model counterparts."""
    parser = argparse.ArgumentParser()
    parser.add_argument("base_model")
    parser.add_argument("merged_model")
    args = parser.parse_args()
    # Both vocabulary-facing tensors get the same treatment.
    for tensor_name in ("model.embed_tokens.weight", "lm_head.weight"):
        swap_tensor(args.base_model, args.merged_model, tensor_name)


if __name__ == "__main__":
    main()