# Usage:
#   python shield_norms.py "B:\12B\models--p-e-w--Mistral-Nemo-Instruct-2407-heretic-noslop" "C:\Quanter\model_cache\EldritchLabs__Nocturne-Nereid-12B-v1"
import argparse
import gc
import json
import os
import re

from safetensors import safe_open
from safetensors.torch import load_file, save_file
def get_weight_map(model_path):
    """Return a mapping from tensor name to the shard file that stores it.

    Prefers the sharded-model index file (model.safetensors.index.json);
    for single-file models it falls back to scanning each *.safetensors
    file in the directory.

    Args:
        model_path: Directory containing the model's safetensors file(s).

    Returns:
        dict mapping tensor name -> shard filename (relative to model_path).
    """
    index_path = os.path.join(model_path, "model.safetensors.index.json")
    if os.path.exists(index_path):
        with open(index_path, "r", encoding="utf-8") as f:
            return json.load(f)["weight_map"]
    # Fallback for single-file models: safe_open reads only the safetensors
    # header to enumerate keys, instead of materializing every tensor in
    # memory the way load_file() would.
    weight_map = {}
    for fname in os.listdir(model_path):
        if fname.endswith(".safetensors"):
            with safe_open(os.path.join(model_path, fname), framework="pt", device="cpu") as st:
                for key in st.keys():
                    weight_map[key] = fname
    return weight_map
def shield_norms(base_path, merged_path):
    """Overwrite every normalization weight in the merged model with the
    pristine value from the base model, rewriting shards in place.

    Args:
        base_path: Directory of the untouched base model.
        merged_path: Directory of the merged model to repair (modified on disk).

    Raises:
        Exception: re-raised if saving a rewritten shard fails; the original
            shard file is restored from its .old backup first.
    """
    print("\n[1] Mapping tensors...")
    base_map = get_weight_map(base_path)
    merged_map = get_weight_map(merged_path)

    # Identify all normalization tensors.
    # Catching: model.norm.weight, input_layernorm.weight, post_attention_layernorm.weight
    norm_pattern = re.compile(r".*norm\.weight$")
    target_tensors = [t for t in merged_map.keys() if norm_pattern.match(t)]
    if not target_tensors:
        print(" [!] No normalization tensors found!")
        return
    print(f" -> Found {len(target_tensors)} normalization tensors to shield.")

    # Group targets by the merged shard that stores them, so each shard
    # file is opened and rewritten exactly once.
    shards_to_process = {}
    for t in target_tensors:
        shards_to_process.setdefault(merged_map[t], []).append(t)

    print(f"\n[2] Processing {len(shards_to_process)} shards...")
    for shard_name, tensors in shards_to_process.items():
        merged_shard_path = os.path.join(merged_path, shard_name)
        backup_shard_path = merged_shard_path + ".old"
        print(f" -> Shard: {shard_name}")

        # 1. Load the merged shard.
        merged_tensors = load_file(merged_shard_path, device="cpu")

        # 2. Inject each pristine base tensor. Base shards are cached per
        #    merged shard: one base file typically holds many norm layers,
        #    and the original code re-read it from disk for every tensor.
        base_cache = {}
        for t_name in tensors:
            if t_name in base_map:
                base_shard = base_map[t_name]
                if base_shard not in base_cache:
                    base_cache[base_shard] = load_file(
                        os.path.join(base_path, base_shard), device="cpu")
                print(f" Injecting pristine: {t_name}")
                merged_tensors[t_name] = base_cache[base_shard][t_name].clone()
            else:
                print(f" [!] Warning: {t_name} not found in base model. Skipping.")
        del base_cache

        # 3. Atomic-rename strategy for Windows: keep the old shard as a
        #    .old backup until the new file is written, restore it on failure.
        if os.path.exists(backup_shard_path):
            os.remove(backup_shard_path)
        os.rename(merged_shard_path, backup_shard_path)
        try:
            save_file(merged_tensors, merged_shard_path, metadata={"format": "pt"})
            print(" ✅ Shard saved successfully.")
        except Exception as e:
            print(f" ❌ Error saving shard: {e}")
            os.rename(backup_shard_path, merged_shard_path)
            raise  # bare raise preserves the original traceback

        # 4. Cleanup and release handles before deleting the backup;
        #    Windows may still hold a lock on the old file.
        del merged_tensors
        gc.collect()
        try:
            os.remove(backup_shard_path)
        except Exception as e:
            print(f" [!] Note: Could not delete .old file immediately (OS lock). It will be orphaned: {e}")
def main():
    """CLI entry point: parse the two model paths and run the shield pass."""
    parser = argparse.ArgumentParser(
        description="Revert all normalization layers to base model values.")
    parser.add_argument("base_model", help="Path to the pristine base model")
    parser.add_argument("merged_model", help="Path to the merged model with artifacts")
    args = parser.parse_args()

    banner = "=" * 60
    print(banner)
    print("🛡️ NORM SHIELD: RE-CALIBRATING OUTPUT MANIFOLD")
    print(banner)

    shield_norms(args.base_model, args.merged_model)

    print("\n" + banner)
    print("Done! All normalization layers have been reverted to Base.")
    print("This should eliminate the 'napad' / 'derrotó' vector drift.")
    print(banner)


if __name__ == "__main__":
    main()