# model_tools/shield_norms.py
# (Hub page residue: uploaded by Naphula, "Upload 10 files", commit f43fd2b, verified)
# python shield_norms.py "B:\12B\models--p-e-w--Mistral-Nemo-Instruct-2407-heretic-noslop" "C:\Quanter\model_cache\EldritchLabs__Nocturne-Nereid-12B-v1"
import argparse
import gc
import json
import os
import re

from safetensors import safe_open
from safetensors.torch import load_file, save_file
def get_weight_map(model_path):
    """Return a mapping of tensor name -> shard filename for a model directory.

    Prefers the sharded index file (``model.safetensors.index.json``); for
    single-file models it falls back to scanning ``*.safetensors`` files.

    Args:
        model_path: Directory containing the model's safetensors files.

    Returns:
        dict mapping each tensor name to the shard filename that stores it.
    """
    index_path = os.path.join(model_path, "model.safetensors.index.json")
    if os.path.exists(index_path):
        with open(index_path, 'r') as f:
            return json.load(f)["weight_map"]
    # Fallback for single-file models. Read only the safetensors header via
    # safe_open to enumerate tensor names, instead of materializing every
    # tensor in memory with load_file (the original approach was very slow
    # and memory-heavy for multi-GB shards).
    weight_map = {}
    for fname in os.listdir(model_path):
        if fname.endswith(".safetensors"):
            with safe_open(os.path.join(model_path, fname), framework="pt", device="cpu") as handle:
                for key in handle.keys():
                    weight_map[key] = fname
    return weight_map
def shield_norms(base_path, merged_path):
    """Overwrite every normalization weight in the merged model with the
    corresponding pristine tensor from the base model.

    Matches tensors named ``*norm.weight`` (e.g. ``model.norm.weight``,
    ``input_layernorm.weight``, ``post_attention_layernorm.weight``) and
    rewrites each affected merged shard in place using a rename-then-save
    strategy so a failed save can be rolled back from the ``.old`` backup.

    Args:
        base_path: Directory of the pristine base model.
        merged_path: Directory of the merged model to patch (modified on disk).
    """
    print("\n[1] Mapping tensors...")
    base_map = get_weight_map(base_path)
    merged_map = get_weight_map(merged_path)
    # Identify all normalization tensors.
    norm_pattern = re.compile(r".*norm\.weight$")
    target_tensors = [t for t in merged_map.keys() if norm_pattern.match(t)]
    if not target_tensors:
        print(" [!] No normalization tensors found!")
        return
    print(f" -> Found {len(target_tensors)} normalization tensors to shield.")
    # Group targets by merged shard to minimize file opening.
    shards_to_process = {}
    for t in target_tensors:
        shards_to_process.setdefault(merged_map[t], []).append(t)
    print(f"\n[2] Processing {len(shards_to_process)} shards...")
    for shard_name, tensors in shards_to_process.items():
        merged_shard_path = os.path.join(merged_path, shard_name)
        backup_shard_path = merged_shard_path + ".old"
        print(f" -> Shard: {shard_name}")
        # 1. Load the merged shard.
        merged_tensors = load_file(merged_shard_path, device="cpu")
        # 2. Group the targets by the BASE shard that holds them, so each
        #    base shard is read from disk once — the original re-loaded the
        #    same base shard for every tensor it contained.
        by_base_shard = {}
        for t_name in tensors:
            if t_name in base_map:
                by_base_shard.setdefault(base_map[t_name], []).append(t_name)
            else:
                print(f" [!] Warning: {t_name} not found in base model. Skipping.")
        for base_shard, names in by_base_shard.items():
            base_data = load_file(os.path.join(base_path, base_shard), device="cpu")
            for t_name in names:
                print(f" Injecting pristine: {t_name}")
                merged_tensors[t_name] = base_data[t_name].clone()
            # Release the base shard before loading the next one.
            del base_data
        # 3. Rename-then-save: Windows will not overwrite an open mapping,
        #    and the .old copy lets us roll back if save_file fails.
        if os.path.exists(backup_shard_path):
            os.remove(backup_shard_path)
        os.rename(merged_shard_path, backup_shard_path)
        try:
            save_file(merged_tensors, merged_shard_path, metadata={"format": "pt"})
            print(f" ✅ Shard saved successfully.")
        except Exception as e:
            print(f" ❌ Error saving shard: {e}")
            # Restore the original shard, then re-raise preserving traceback.
            os.rename(backup_shard_path, merged_shard_path)
            raise
        # 4. Cleanup and release handles.
        del merged_tensors
        gc.collect()
        try:
            os.remove(backup_shard_path)
        except OSError as e:
            # Best-effort delete: an OS-level lock may keep the backup alive.
            print(f" [!] Note: Could not delete .old file immediately (OS lock). It will be orphaned: {e}")
def main():
    """CLI entry point: parse the two model paths and run the norm shield."""
    arg_parser = argparse.ArgumentParser(description="Revert all normalization layers to base model values.")
    arg_parser.add_argument("base_model", help="Path to the pristine base model")
    arg_parser.add_argument("merged_model", help="Path to the merged model with artifacts")
    opts = arg_parser.parse_args()

    banner = "=" * 60
    print(banner)
    print("🛡️ NORM SHIELD: RE-CALIBRATING OUTPUT MANIFOLD")
    print(banner)
    shield_norms(opts.base_model, opts.merged_model)
    print("\n" + banner)
    print("Done! All normalization layers have been reverted to Base.")
    print("This should eliminate the 'napad' / 'derrotó' vector drift.")
    print(banner)


if __name__ == "__main__":
    main()