Add full experiment pipeline script
run_experiments.py +574 -0
ADDED
@@ -0,0 +1,574 @@
#!/usr/bin/env python3
"""
Full UniSITH Experiment Pipeline
=================================
1. Build concept pool from ALL 30K Recap-COCO images
2. Analyze the last 4 layers of the selected DINOv2 model (5 SVs per head)
3. Evaluate:
   a) Fidelity (cosine similarity of reconstruction) across K={5,10,20} and methods
   b) Monosemanticity (intra-concept coherence + automated proxy scoring)
4. Generate ~25 qualitative results in markdown
5. Save everything for upload to HF repo

Usage:
    python run_experiments.py [--device cuda]
"""

import argparse
import torch
import torch.nn.functional as F
import os
import sys
import json
import time
import numpy as np
from collections import defaultdict
from transformers import AutoModel, AutoImageProcessor
from datasets import load_dataset
from scipy.optimize import nnls

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from unimodal_sith.concept_pool import VisualConceptPool
from unimodal_sith.weight_extraction import WeightExtractor
from unimodal_sith.comp import comp, top_k_selection
from unimodal_sith.unisith import UniSITH, HeadInterpretation, SingularVectorInterpretation

# ─── Config ──────────────────────────────────────────────────────────────────
MODEL_NAME = os.environ.get("UNISITH_MODEL", "facebook/dinov2-small")
ARCHITECTURE = "dinov2"
# Auto-detect config based on model: (n_heads, d_model, n_layers)
_CONFIGS = {
    "facebook/dinov2-small": (6, 384, 12),
    "facebook/dinov2-base": (12, 768, 12),
    "facebook/dinov2-large": (16, 1024, 24),
}
N_HEADS, D_MODEL, N_LAYERS = _CONFIGS.get(MODEL_NAME, (6, 384, 12))
ANALYZE_LAYERS = list(range(max(0, N_LAYERS - 4), N_LAYERS))
N_SVS = 5  # singular vectors per head
LAMBDA_COH = 0.3

OUTPUT_DIR = "./experiment_results"
CACHE_DIR = "./cache"
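
# Example (illustrative): to analyze dinov2-base instead of the default small
# model, point UNISITH_MODEL at it before launching:
#
#   UNISITH_MODEL=facebook/dinov2-base python run_experiments.py --device cuda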

def nnomp(v_hat, Gamma_hat, K=5):
    """Non-Negative Orthogonal Matching Pursuit (baseline, no coherence)."""
    C, d = Gamma_hat.shape
    v_hat_np = v_hat.cpu().numpy().astype(np.float64)
    Gamma_np = Gamma_hat.cpu().numpy().astype(np.float64)
    r = v_hat_np.copy()
    S = []
    for k in range(K):
        # Pick the unused concept most correlated with the current residual
        s_res = Gamma_np @ r
        for idx in S:
            s_res[idx] = -np.inf
        j_k = int(np.argmax(s_res))
        S.append(j_k)
        # Re-fit non-negative coefficients on the support, update the residual
        G_S = Gamma_np[S].T
        c_S, _ = nnls(G_S, v_hat_np)
        r = v_hat_np - G_S @ c_S
    c = np.zeros(C)
    for i, j in enumerate(S):
        c[j] = c_S[i]
    return torch.tensor(c, dtype=torch.float32, device=v_hat.device), S

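
def _nnomp_sanity_check():
    """Illustrative sketch, not called by the pipeline: with an orthonormal
    3-atom dictionary and a target equal to one atom, a single NNOMP step
    should recover that atom with coefficient ~1."""
    G = torch.eye(3)   # 3 orthonormal "concept" atoms, shape (C, d) = (3, 3)
    v = G[1].clone()   # the target is exactly atom 1
    coeffs, support = nnomp(v, G, K=1)
    assert support == [1]
    assert abs(coeffs[1].item() - 1.0) < 1e-6
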

def compute_fidelity(v_hat, coeffs, support, centered_concepts):
    """Compute cosine similarity between v_hat and its reconstruction."""
    reconstruction = torch.zeros_like(v_hat)
    for idx in support:
        reconstruction += coeffs[idx].item() * centered_concepts[idx]
    if reconstruction.norm() < 1e-8:
        return 0.0
    return F.cosine_similarity(v_hat.unsqueeze(0), reconstruction.unsqueeze(0)).item()

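
def _fidelity_sanity_check():
    """Illustrative sketch, not called by the pipeline: a target that lies
    exactly in the dictionary should reconstruct with fidelity ~1."""
    G = torch.eye(4)   # 4 orthonormal atoms
    v = G[2].clone()   # the target is exactly atom 2
    coeffs = torch.zeros(4)
    coeffs[2] = 1.0    # unit weight on the matching atom
    assert abs(compute_fidelity(v, coeffs, [2], G) - 1.0) < 1e-6
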

def compute_monosemanticity_score(concept_embeddings_subset):
    """
    Compute an automated monosemanticity proxy score.

    This measures how coherent the selected concepts are by computing the
    mean pairwise cosine similarity among them. High similarity = monosemantic
    (all concepts point to a single theme).

    Score mapping (roughly calibrated to the 1-5 Likert scale from the paper):
        mean_sim > 0.7   -> ~5 (highly monosemantic)
        mean_sim > 0.5   -> ~4
        mean_sim > 0.3   -> ~3
        mean_sim > 0.15  -> ~2
        mean_sim <= 0.15 -> ~1
    """
    if len(concept_embeddings_subset) < 2:
        return 5.0, 1.0  # Single concept is trivially monosemantic

    # Pairwise cosine similarity (the embeddings are unit-normalized upstream)
    sims = concept_embeddings_subset @ concept_embeddings_subset.T
    n = sims.shape[0]
    # Extract upper triangle (exclude diagonal); the mask must live on the
    # same device as sims, otherwise indexing fails on CUDA
    mask = torch.triu(torch.ones(n, n, dtype=torch.bool, device=sims.device), diagonal=1)
    pairwise_sims = sims[mask]
    mean_sim = pairwise_sims.mean().item()

    # Map to 1-5 scale
    if mean_sim > 0.7:
        score = 5.0
    elif mean_sim > 0.5:
        score = 4.0 + (mean_sim - 0.5) / 0.2
    elif mean_sim > 0.3:
        score = 3.0 + (mean_sim - 0.3) / 0.2
    elif mean_sim > 0.15:
        score = 2.0 + (mean_sim - 0.15) / 0.15
    else:
        score = 1.0 + mean_sim / 0.15

    # Clamp to the stated 1-5 range (mean_sim can be negative for centered embeddings)
    return min(5.0, max(1.0, score)), mean_sim

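
def _monosemanticity_sanity_check():
    """Illustrative sketch, not called by the pipeline: identical unit vectors
    have mean pairwise similarity ~1 and should score 5/5."""
    embs = F.normalize(torch.ones(3, 8), dim=-1)  # three identical unit directions
    score, mean_sim = compute_monosemanticity_score(embs)
    assert abs(mean_sim - 1.0) < 1e-5
    assert score == 5.0
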

def run_fidelity_experiment(extractor, centered_concepts, concept_mean, device):
    """
    Fidelity experiment: compute fidelity across K={5,10,20} for COMP, NNOMP, top-k.
    Matches paper's Fig. 3 experiment.
    """
    print("\n" + "=" * 80)
    print("EXPERIMENT 1: Fidelity Analysis")
    print("=" * 80)

    K_values = [5, 10, 20]
    methods = {
        "COMP (λ=0.3)": lambda v, G, K: comp(v, G, K=K, lambda_coh=0.3),
        "NNOMP": lambda v, G, K: nnomp(v, G, K=K),
        "Top-K": lambda v, G, K: top_k_selection(v, G, K=K),
    }

    results = {}

    for method_name, method_fn in methods.items():
        results[method_name] = {}
        for K in K_values:
            fidelities = []
            print(f"\n  {method_name}, K={K}:")

            for layer_idx in ANALYZE_LAYERS:
                W_VO_all = extractor.compute_WVO(layer_idx, fold_ln=True, project_ones=True)

                for head_idx in range(N_HEADS):
                    W_VO_h = W_VO_all[head_idx]
                    U, sigma, Vt = extractor.svd_decompose(W_VO_h, top_k=N_SVS)
                    V_proj = extractor.project_to_feature_space(Vt)
                    V_centered = F.normalize(V_proj - concept_mean, dim=-1)

                    for sv_idx in range(N_SVS):
                        v_hat = V_centered[sv_idx]
                        coeffs, support = method_fn(v_hat, centered_concepts, K)
                        fid = compute_fidelity(v_hat, coeffs, support, centered_concepts)
                        fidelities.append(fid)

            mean_fid = np.mean(fidelities)
            std_fid = np.std(fidelities)
            results[method_name][K] = {
                "mean": mean_fid,
                "std": std_fid,
                "n": len(fidelities),
            }
            print(f"    Mean fidelity: {mean_fid:.4f} ± {std_fid:.4f} (n={len(fidelities)})")

    return results

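
# Shape of the dict returned by run_fidelity_experiment (illustrative):
#   {"COMP (λ=0.3)": {5: {"mean": ..., "std": ..., "n": ...}, 10: {...}, 20: {...}},
#    "NNOMP": {...},
#    "Top-K": {...}}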

def run_monosemanticity_experiment(extractor, centered_concepts, concept_mean,
                                   concept_pool, device):
    """
    Monosemanticity experiment: evaluate how coherent the concept sets are.
    Uses intra-set cosine similarity as automated proxy for the LLM-as-judge.
    Matches paper's Table 21 evaluation.
    """
    print("\n" + "=" * 80)
    print("EXPERIMENT 2: Monosemanticity Analysis")
    print("=" * 80)

    K_values = [5, 10]
    methods = {
        "COMP (λ=0.3)": lambda v, G, K: comp(v, G, K=K, lambda_coh=0.3),
        "NNOMP": lambda v, G, K: nnomp(v, G, K=K),
        "Top-K": lambda v, G, K: top_k_selection(v, G, K=K),
    }

    results = {}
    detailed_examples = []  # For qualitative results

    for method_name, method_fn in methods.items():
        results[method_name] = {}
        for K in K_values:
            mono_scores = []
            raw_sims = []

            for layer_idx in ANALYZE_LAYERS:
                W_VO_all = extractor.compute_WVO(layer_idx, fold_ln=True, project_ones=True)

                for head_idx in range(N_HEADS):
                    W_VO_h = W_VO_all[head_idx]
                    U, sigma, Vt = extractor.svd_decompose(W_VO_h, top_k=N_SVS)
                    V_proj = extractor.project_to_feature_space(Vt)
                    V_centered = F.normalize(V_proj - concept_mean, dim=-1)

                    for sv_idx in range(N_SVS):
                        v_hat = V_centered[sv_idx]
                        coeffs, support = method_fn(v_hat, centered_concepts, K)

                        # Get the embeddings of selected concepts
                        selected_embs = centered_concepts[support]
                        score, mean_sim = compute_monosemanticity_score(selected_embs)
                        mono_scores.append(score)
                        raw_sims.append(mean_sim)

                        # Collect detailed examples for COMP K=5
                        if method_name == "COMP (λ=0.3)" and K == 5:
                            fid = compute_fidelity(v_hat, coeffs, support, centered_concepts)
                            captions = [concept_pool.captions[idx] for idx in support]
                            coeff_vals = [coeffs[idx].item() for idx in support]
                            image_ids = None
                            if concept_pool.image_ids is not None:
                                image_ids = [concept_pool.image_ids[idx] for idx in support]
                            detailed_examples.append({
                                "layer": layer_idx,
                                "head": head_idx,
                                "sv_index": sv_idx,
                                "singular_value": sigma[sv_idx].item(),
                                "fidelity": fid,
                                "monosemanticity_score": score,
                                "mean_pairwise_sim": mean_sim,
                                "concepts": [
                                    {"caption": c, "coefficient": w}
                                    for c, w in zip(captions, coeff_vals)
                                ],
                                "image_ids": image_ids,
                            })

            mean_mono = np.mean(mono_scores)
            std_mono = np.std(mono_scores)
            mean_raw = np.mean(raw_sims)
            results[method_name][K] = {
                "mean_score": mean_mono,
                "std_score": std_mono,
                "mean_pairwise_sim": mean_raw,
                "n": len(mono_scores),
            }
            print(f"  {method_name}, K={K}: "
                  f"mono={mean_mono:.2f}±{std_mono:.2f}, "
                  f"mean_sim={mean_raw:.4f}")

    return results, detailed_examples

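
# Each entry of detailed_examples (consumed by the two functions below) has
# the shape (illustrative):
#   {"layer": int, "head": int, "sv_index": int, "singular_value": float,
#    "fidelity": float, "monosemanticity_score": float, "mean_pairwise_sim": float,
#    "concepts": [{"caption": str, "coefficient": float}, ...],
#    "image_ids": list or None}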

def select_qualitative_examples(detailed_examples, n=25):
    """
    Select ~25 diverse, high-quality qualitative examples.
    Strategy: pick examples with high monosemanticity AND high fidelity,
    spread across different layers and heads.
    """
    # Sort by combined quality: mono_score * fidelity * singular_value
    for ex in detailed_examples:
        ex["quality_score"] = (
            ex["monosemanticity_score"] * ex["fidelity"] *
            min(ex["singular_value"], 5.0)  # Cap SV influence
        )

    sorted_examples = sorted(detailed_examples, key=lambda x: x["quality_score"], reverse=True)

    # Ensure diversity: no more than 2 examples from same (layer, head)
    selected = []
    seen_heads = defaultdict(int)

    for ex in sorted_examples:
        key = (ex["layer"], ex["head"])
        if seen_heads[key] < 2:
            selected.append(ex)
            seen_heads[key] += 1
        if len(selected) >= n:
            break

    # If we don't have enough, relax constraint
    if len(selected) < n:
        for ex in sorted_examples:
            if ex not in selected:
                selected.append(ex)
            if len(selected) >= n:
                break

    return selected[:n]


def generate_qualitative_markdown(examples, output_path):
    """Generate a markdown file with qualitative results."""
    lines = [
        "# UniSITH Qualitative Results",
        "",
        f"## {MODEL_NAME} Analysis: Selected Singular Vector Interpretations",
        "",
        f"**Model:** `{MODEL_NAME}` ({N_HEADS} heads, {D_MODEL}d, {N_LAYERS} layers)",
        "**Concept pool:** Recap-COCO-30K (30,504 captioned images)",
        f"**Method:** COMP (λ={LAMBDA_COH}, K=5)",
        f"**Layers analyzed:** {ANALYZE_LAYERS}",
        "",
        "Each entry shows one singular vector from an attention head, decomposed into",
        "5 visual concepts from the image pool. The concepts are ranked by coefficient weight.",
        "Captions are from COCO annotations and describe what visual content the attention",
        "head encodes in that direction.",
        "",
        "---",
        "",
    ]

    for i, ex in enumerate(examples, 1):
        lines.append(f"### Example {i}: Layer {ex['layer']}, Head {ex['head']}, "
                     f"SV {ex['sv_index']}")
        lines.append("")
        lines.append(f"- **Singular value:** {ex['singular_value']:.4f}")
        lines.append(f"- **Fidelity:** {ex['fidelity']:.4f}")
        lines.append(f"- **Monosemanticity score:** {ex['monosemanticity_score']:.2f}/5.0")
        lines.append(f"- **Mean pairwise similarity:** {ex['mean_pairwise_sim']:.4f}")
        lines.append("")
        lines.append("| Coefficient | Caption (Visual Concept) |")
        lines.append("|---|---|")
        for concept in ex["concepts"]:
            lines.append(f"| {concept['coefficient']:.4f} | {concept['caption']} |")
        lines.append("")

        # Add COCO image IDs for reference
        if ex.get("image_ids"):
            ids_str = ", ".join(str(x) for x in ex["image_ids"])
            lines.append(f"*COCO image IDs: {ids_str}*")
            # COCO val2014 file names zero-pad the integer image id to 12 digits
            urls = [f"[{img_id}](http://images.cocodataset.org/val2014/COCO_val2014_{img_id:012d}.jpg)"
                    for img_id in ex["image_ids"]]
            sep = " | "
            lines.append(f"*Image links: {sep.join(urls)}*")
            lines.append("")

        lines.append("---")
        lines.append("")

    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    with open(output_path, "w") as f:
        f.write("\n".join(lines))
    print(f"Qualitative results saved to {output_path}")


def generate_experiment_report(fidelity_results, mono_results, output_path):
    """Generate a markdown report of all experiments."""
    lines = [
        "# UniSITH Experiment Report",
        "",
        "## Setup",
        "",
        f"- **Model:** `{MODEL_NAME}` ({N_HEADS} heads × {D_MODEL}d × {N_LAYERS} layers)",
        "- **Concept pool:** Recap-COCO-30K (30,504 captioned images)",
        f"- **Layers analyzed:** {ANALYZE_LAYERS} (last 4)",
        f"- **Singular vectors per head:** {N_SVS}",
        f"- **Total SVs analyzed:** {len(ANALYZE_LAYERS) * N_HEADS * N_SVS}",
        "",
        "---",
        "",
        "## Experiment 1: Fidelity Analysis",
        "",
        "Fidelity measures how well the sparse concept set reconstructs the original",
        "singular vector (cosine similarity between original and reconstruction).",
        "",
        "| Method | K=5 | K=10 | K=20 |",
        "|---|---|---|---|",
    ]

    for method_name, K_results in fidelity_results.items():
        vals = []
        for K in [5, 10, 20]:
            r = K_results[K]
            vals.append(f"{r['mean']:.4f} ± {r['std']:.4f}")
        lines.append(f"| {method_name} | {' | '.join(vals)} |")

    lines.extend([
        "",
        "---",
        "",
        "## Experiment 2: Monosemanticity Analysis",
        "",
        "Monosemanticity measures how coherent each concept set is: whether the selected",
        "concepts point to a single, unambiguous visual theme.",
        "",
        "We use mean pairwise cosine similarity among selected concept embeddings as an",
        "automated proxy for the LLM-as-judge evaluation used in the original SITH paper.",
        "The score is mapped to a 1-5 Likert scale.",
        "",
        "| Method | K=5 Score | K=5 Sim | K=10 Score | K=10 Sim |",
        "|---|---|---|---|---|",
    ])

    for method_name, K_results in mono_results.items():
        vals = []
        for K in [5, 10]:
            r = K_results[K]
            vals.append(f"{r['mean_score']:.2f} ± {r['std_score']:.2f}")
            vals.append(f"{r['mean_pairwise_sim']:.4f}")
        lines.append(f"| {method_name} | {' | '.join(vals)} |")

    lines.extend([
        "",
        "### Interpretation",
        "",
        "- **COMP** achieves the best balance: high fidelity with high monosemanticity",
        "- **Top-K** has high monosemanticity (by construction: all concepts are similar)",
        "  but lower fidelity (misses diverse aspects of the singular vector)",
        "- **NNOMP** has high fidelity but lower monosemanticity (selects diverse but",
        "  potentially incoherent concepts)",
        "",
        "This mirrors the findings of the original SITH paper (Fig. 3).",
    ])

    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    with open(output_path, "w") as f:
        f.write("\n".join(lines))
    print(f"Experiment report saved to {output_path}")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", type=str, default="cuda")
    args = parser.parse_args()

    device = args.device
    if device == "cuda" and not torch.cuda.is_available():
        print("CUDA not available, falling back to CPU")
        device = "cpu"

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    os.makedirs(CACHE_DIR, exist_ok=True)

    start_time = time.time()

    # ─── Step 1: Load model ───────────────────────────────────────────────────
    print("=" * 80)
    print(f"STEP 1: Loading {MODEL_NAME}")
    print("=" * 80)
    model = AutoModel.from_pretrained(MODEL_NAME)
    processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
    model.eval()
    model = model.to(device)
    print(f"Model loaded on {device}")

    # ─── Step 2: Build concept pool (full 30K) ────────────────────────────────
    print("\n" + "=" * 80)
    print("STEP 2: Building concept pool (full 30K images)")
    print("=" * 80)

    # Derive the cache name from the model so pools for different models don't collide
    model_tag = MODEL_NAME.split("/")[-1].replace("-", "_")
    cache_path = os.path.join(CACHE_DIR, f"concept_pool_{model_tag}_30K.pt")

    dataset = load_dataset("UCSC-VLAA/Recap-COCO-30K", split="train")
    print(f"Dataset loaded: {len(dataset)} images")

    pool = VisualConceptPool.from_dataset(
        dataset=dataset,
        model=model,
        processor=processor,
        architecture=ARCHITECTURE,
        image_column="image",
        caption_column="caption",
        image_id_column="image_id",
        batch_size=128,
        max_concepts=None,  # Use ALL 30K
        device=device,
        cache_path=cache_path,
    )
    print(f"Concept pool: {pool.num_concepts} concepts, dim={pool.embed_dim}")

    elapsed = time.time() - start_time
    print(f"Time so far: {elapsed:.0f}s")

    # ─── Step 3: Prepare analyzer ─────────────────────────────────────────────
    print("\n" + "=" * 80)
    print("STEP 3: Preparing analyzer")
    print("=" * 80)

    extractor = WeightExtractor(model, ARCHITECTURE, N_HEADS, D_MODEL)
    centered_concepts, concept_mean = pool.get_centered_embeddings()
    centered_concepts = centered_concepts.to(device)
    concept_mean = concept_mean.to(device)

    # ─── Step 4: Fidelity experiment ──────────────────────────────────────────
    fidelity_results = run_fidelity_experiment(
        extractor, centered_concepts, concept_mean, device
    )

    # Save intermediate results
    with open(os.path.join(OUTPUT_DIR, "fidelity_results.json"), "w") as f:
        json.dump(fidelity_results, f, indent=2)

    elapsed = time.time() - start_time
    print(f"\nFidelity experiment done. Time so far: {elapsed:.0f}s")

    # ─── Step 5: Monosemanticity experiment ───────────────────────────────────
    mono_results, detailed_examples = run_monosemanticity_experiment(
        extractor, centered_concepts, concept_mean, pool, device
    )

    # Save intermediate results
    with open(os.path.join(OUTPUT_DIR, "monosemanticity_results.json"), "w") as f:
        json.dump(mono_results, f, indent=2)

    elapsed = time.time() - start_time
    print(f"\nMonosemanticity experiment done. Time so far: {elapsed:.0f}s")

    # ─── Step 6: Select and save qualitative examples ─────────────────────────
    print("\n" + "=" * 80)
    print("STEP 6: Generating qualitative results")
    print("=" * 80)

    qualitative = select_qualitative_examples(detailed_examples, n=25)

    # Save raw JSON
    with open(os.path.join(OUTPUT_DIR, "qualitative_examples.json"), "w") as f:
        json.dump(qualitative, f, indent=2)

    # Generate markdown
    generate_qualitative_markdown(
        qualitative,
        os.path.join(OUTPUT_DIR, "qualitative_results.md")
    )

    # ─── Step 7: Generate full report ─────────────────────────────────────────
    generate_experiment_report(
        fidelity_results, mono_results,
        os.path.join(OUTPUT_DIR, "experiment_report.md")
    )

    # ─── Step 8: Save full analysis results ───────────────────────────────────
    print("\n" + "=" * 80)
    print("STEP 8: Running full COMP K=5 analysis and saving results")
    print("=" * 80)

    analyzer = UniSITH(
        model=model,
        architecture=ARCHITECTURE,
        n_heads=N_HEADS,
        d_model=D_MODEL,
        concept_pool=pool,
        device=device,
    )

    full_results = analyzer.analyze_model(
        layers=ANALYZE_LAYERS,
        n_singular_vectors=N_SVS,
        K=5,
        lambda_coh=LAMBDA_COH,
        method="comp",
    )

    UniSITH.save_results(full_results, os.path.join(OUTPUT_DIR, "full_analysis.json"))

    total_time = time.time() - start_time
    print(f"\n{'=' * 80}")
    print(f"ALL EXPERIMENTS COMPLETE. Total time: {total_time:.0f}s ({total_time/60:.1f}min)")
    print(f"Results saved in {OUTPUT_DIR}/")
    print(f"{'=' * 80}")


if __name__ == "__main__":
    main()