# python shield_norms.py "B:\12B\models--p-e-w--Mistral-Nemo-Instruct-2407-heretic-noslop" "C:\Quanter\model_cache\EldritchLabs__Nocturne-Nereid-12B-v1"
"""Revert every normalization-layer weight in a merged model to the base model's values.

Model merging can perturb the ``*norm.weight`` (RMSNorm/LayerNorm scale)
tensors, which shows up as output drift.  This script copies each such tensor
from a pristine base checkpoint into the merged checkpoint's shards in place,
using a rename-then-save backup strategy that is safe on Windows.
"""
import argparse
import gc
import json
import os
import re
from collections import defaultdict

from safetensors import safe_open
from safetensors.torch import load_file, save_file


def get_weight_map(model_path):
    """Return a ``{tensor_name: shard_filename}`` map for *model_path*.

    Prefers the sharded-index JSON when present; otherwise scans every
    ``*.safetensors`` file.  The fallback reads only each file's header via
    ``safe_open`` instead of loading full tensor data, so it is fast even for
    single-file models.
    """
    index_path = os.path.join(model_path, "model.safetensors.index.json")
    if os.path.exists(index_path):
        with open(index_path, "r", encoding="utf-8") as f:
            return json.load(f)["weight_map"]

    # Fallback for single file models: only the header is parsed, no tensor
    # data is materialized.
    res = {}
    for fname in os.listdir(model_path):
        if fname.endswith(".safetensors"):
            with safe_open(os.path.join(model_path, fname), framework="pt", device="cpu") as sf:
                for k in sf.keys():
                    res[k] = fname
    return res


def shield_norms(base_path, merged_path):
    """Overwrite all ``*norm.weight`` tensors in *merged_path* with the
    corresponding tensors from *base_path*, processing shard by shard.

    Each merged shard is saved via a rename-backup scheme: the original shard
    is renamed to ``<shard>.old`` before the new file is written, and restored
    if the save fails.
    """
    print("\n[1] Mapping tensors...")
    base_map = get_weight_map(base_path)
    merged_map = get_weight_map(merged_path)

    # Identify all normalization tensors.
    # Catching: model.norm.weight, input_layernorm.weight, post_attention_layernorm.weight
    norm_pattern = re.compile(r".*norm\.weight$")
    target_tensors = [t for t in merged_map if norm_pattern.match(t)]

    if not target_tensors:
        print(" [!] No normalization tensors found!")
        return
    print(f" -> Found {len(target_tensors)} normalization tensors to shield.")

    # Group targets by merged shard to minimize file opening.
    shards_to_process = defaultdict(list)
    for t in target_tensors:
        shards_to_process[merged_map[t]].append(t)

    print(f"\n[2] Processing {len(shards_to_process)} shards...")
    for shard_name, tensors in shards_to_process.items():
        merged_shard_path = os.path.join(merged_path, shard_name)
        backup_shard_path = merged_shard_path + ".old"
        print(f" -> Shard: {shard_name}")

        # 1. Load the merged shard.
        merged_tensors = load_file(merged_shard_path, device="cpu")

        # 2. Update each target tensor.  Group by base shard first so each
        #    base shard is read from disk exactly once (the original loaded
        #    the same base shard repeatedly, once per tensor).
        by_base_shard = defaultdict(list)
        for t_name in tensors:
            if t_name in base_map:
                by_base_shard[base_map[t_name]].append(t_name)
            else:
                print(f" [!] Warning: {t_name} not found in base model. Skipping.")

        for base_shard, names in by_base_shard.items():
            base_data = load_file(os.path.join(base_path, base_shard), device="cpu")
            for t_name in names:
                print(f" Injecting pristine: {t_name}")
                merged_tensors[t_name] = base_data[t_name].clone()
            # Release the base shard before loading the next one.
            del base_data

        # 3. Atomic rename strategy for Windows: keep the old shard around as
        #    a backup until the replacement is safely written.
        if os.path.exists(backup_shard_path):
            os.remove(backup_shard_path)
        os.rename(merged_shard_path, backup_shard_path)

        try:
            save_file(merged_tensors, merged_shard_path, metadata={"format": "pt"})
            print(" ✅ Shard saved successfully.")
        except Exception as e:
            print(f" ❌ Error saving shard: {e}")
            # Restore the backup, then re-raise with the original traceback
            # (bare `raise`, not `raise e`).
            os.rename(backup_shard_path, merged_shard_path)
            raise

        # 4. Cleanup and release handles.
        del merged_tensors
        gc.collect()
        try:
            os.remove(backup_shard_path)
        except OSError as e:
            # Windows may still hold a handle on the renamed file; best-effort.
            print(f" [!] Note: Could not delete .old file immediately (OS lock). It will be orphaned: {e}")


def main():
    """Parse CLI arguments and run the norm-shielding pass."""
    parser = argparse.ArgumentParser(description="Revert all normalization layers to base model values.")
    parser.add_argument("base_model", help="Path to the pristine base model")
    parser.add_argument("merged_model", help="Path to the merged model with artifacts")
    args = parser.parse_args()

    print("=" * 60)
    print("🛡️ NORM SHIELD: RE-CALIBRATING OUTPUT MANIFOLD")
    print("=" * 60)

    shield_norms(args.base_model, args.merged_model)

    print("\n" + "=" * 60)
    print("Done! All normalization layers have been reverted to Base.")
    print("This should eliminate the 'napad' / 'derrotó' vector drift.")
    print("=" * 60)


if __name__ == "__main__":
    main()