# model_tools / arcee_fusion_salience_scanner_v3.py
# Author: Naphula (Hugging Face upload "Upload 5 files", commit 6a2122d)
import torch
from safetensors import safe_open
import os
import re
from collections import defaultdict
# === User configuration =====================================================
# Directory containing the base model's *.safetensors shards.
base_model_path = 'B:\\12B\\models--SicariusSicariiStuff--Impish_Bloodmoon_12B'
# Directory containing the merged model's *.safetensors shards to compare.
merged_model_path = 'B:\\12B\\21-Della'
# ============================================================================
def get_tensor_map(path):
    """Map every tensor name in a sharded safetensors checkpoint to its shard file.

    Args:
        path: Directory containing one or more ``*.safetensors`` shard files.

    Returns:
        dict mapping tensor name -> full path of the shard file that stores it.
        Shards are scanned in sorted filename order, so if a tensor name appears
        in more than one shard, the result is deterministic (the last shard in
        that order wins) instead of depending on ``os.listdir`` ordering.
    """
    tensor_map = {}
    shard_paths = sorted(
        os.path.join(path, name)
        for name in os.listdir(path)
        if name.endswith('.safetensors')
        # Skip e.g. a directory that happens to be named "*.safetensors".
        and os.path.isfile(os.path.join(path, name))
    )
    for full_path in shard_paths:
        # Only the header (tensor names) is read here; tensor data stays on disk.
        with safe_open(full_path, framework="pt") as st:
            for k in st.keys():
                tensor_map[k] = full_path
    return tensor_map
# Build name -> shard-file indexes for both checkpoints (headers only).
print("🔍 Indexing model shards...")
base_map = get_tensor_map(base_model_path)
merged_map = get_tensor_map(merged_model_path)
# Per-layer tallies: "changed" = number of salient elements, "total" = number of
# elements compared. Keys are integer layer indices or the string "Non-Layer".
layer_stats = defaultdict(lambda: {"changed": 0, "total": 0})
print("📊 Scanning tensors and calculating saliency density...")

# Only tensors present in BOTH checkpoints can be compared.
common_tensors = base_map.keys() & merged_map.keys()
only_in_base = base_map.keys() - merged_map.keys()
only_in_merged = merged_map.keys() - base_map.keys()
if only_in_base or only_in_merged:
    print(f" [!] Skipping {len(only_in_base)} tensor(s) only in base and "
          f"{len(only_in_merged)} tensor(s) only in merged model.")
if not common_tensors:
    raise SystemExit("No tensor names in common between the two models; nothing to compare.")
# --- Salience thresholds ------------------------------------------------------
# Arcee Fusion logic: weights identical to Base came from Base; weights that
# differ are "New Info" from the fusion. To filter out small (e.g. DELLA-style)
# perturbations, an element only counts as changed when
#     |merged - base| > SALIENCE_ATOL + SALIENCE_RTOL * |base|
# Balanced (default): ATOL=1e-3, RTOL=0.05 (~5% relative change).
# Strict (major changes only): ATOL=0.0, RTOL=0.1 (~10% relative change).
SALIENCE_ATOL = 1e-3
SALIENCE_RTOL = 0.05

# Layer index in names such as 'model.layers.5.self_attn.q_proj.weight'.
_LAYER_RE = re.compile(r'\.layers\.(\d+)\.')


def _layer_id_of(name):
    """Return the integer layer index parsed from *name*, or "Non-Layer"."""
    match = _LAYER_RE.search(name)
    return int(match.group(1)) if match else "Non-Layer"


def _crop_to_common_shape(a, b):
    """Crop two same-rank tensors to their per-dimension minimum shape.

    Handles e.g. embeddings / LM heads whose vocab size differs between models.
    Raises ValueError when the ranks differ (no meaningful alignment exists).
    """
    if a.dim() != b.dim():
        raise ValueError(f"rank mismatch {list(a.shape)} vs {list(b.shape)}")
    window = tuple(slice(0, min(x, y)) for x, y in zip(a.shape, b.shape))
    return a[window], b[window]


def _salient_change_mask(base, merged, atol=SALIENCE_ATOL, rtol=SALIENCE_RTOL):
    """Boolean mask of elements whose change from *base* exceeds atol + rtol*|base|."""
    return torch.abs(base - merged) > (atol + rtol * torch.abs(base))


for k in sorted(common_tensors):
    layer_id = _layer_id_of(k)

    with safe_open(base_map[k], framework="pt") as b_st:
        base_t = b_st.get_tensor(k)
    with safe_open(merged_map[k], framework="pt") as m_st:
        merged_t = m_st.get_tensor(k)

    # --- VOCAB SIZE ROBUSTNESS PATCH ---
    if base_t.shape != merged_t.shape:
        # Capture the original shapes BEFORE cropping so the message is accurate.
        base_shape, merged_shape = list(base_t.shape), list(merged_t.shape)
        try:
            base_t, merged_t = _crop_to_common_shape(base_t, merged_t)
        except ValueError as err:
            print(f" [!] Skipping {k}: {err}.")
            continue
        print(f" [!] Resized {k} from base {base_shape} / merged {merged_shape} "
              f"to {list(merged_t.shape)} for comparison.")

    changed_mask = _salient_change_mask(base_t, merged_t)
    # Count each salient element exactly once per comparison.
    layer_stats[layer_id]["changed"] += torch.count_nonzero(changed_mask).item()
    layer_stats[layer_id]["total"] += merged_t.numel()
def _density_bar(percentage, width=30):
    """Render *percentage* (0-100) as a *width*-char bar: '█' = new, '░' = base."""
    # Clamp so out-of-range percentages can never overflow or underflow the bar.
    filled = max(0, min(width, int((percentage / 100) * width)))
    return "█" * filled + "░" * (width - filled)


print("\n" + "=" * 60)
print(f"{'LAYER':<12} | {'NEW INFO %':<12} | {'VISUAL DENSITY (█ = New, ░ = Base)'}")
print("=" * 60)

# Sort layers: Non-Layer first, then 0, 1, 2...
sorted_keys = sorted(k for k in layer_stats if isinstance(k, int))
if "Non-Layer" in layer_stats:
    sorted_keys = ["Non-Layer"] + sorted_keys

for lid in sorted_keys:
    stats = layer_stats[lid]
    # A group made only of zero-element tensors has total == 0; report 0%.
    percentage = (stats["changed"] / stats["total"]) * 100 if stats["total"] else 0.0
    label = f"Layer {lid}" if isinstance(lid, int) else lid
    print(f"{label:<12} | {percentage:>10.2f}% | {_density_bar(percentage)}")

print("=" * 60)
print("Analysis Complete.")