lorenzovaquero commited on
Commit
b7e10fc
Β·
verified Β·
1 Parent(s): 4c50b12

Add full experiment pipeline script

Browse files
Files changed (1) hide show
  1. run_experiments.py +574 -0
run_experiments.py ADDED
@@ -0,0 +1,574 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Full UniSITH Experiment Pipeline
4
+ =================================
5
+ 1. Build concept pool from ALL 30K Recap-COCO images
6
+ 2. Analyze last 4 layers of DINOv2-base (48 heads, 5 SVs each)
7
+ 3. Evaluate:
8
+ a) Fidelity (cosine similarity of reconstruction) across K={5,10,20} and methods
9
+ b) Monosemanticity (intra-concept coherence + automated proxy scoring)
10
+ 4. Generate ~25 qualitative results in markdown
11
+ 5. Save everything for upload to HF repo
12
+
13
+ Usage:
14
+ python run_experiments.py [--device cuda]
15
+ """
16
+
17
+ import argparse
18
+ import torch
19
+ import torch.nn.functional as F
20
+ import os
21
+ import sys
22
+ import json
23
+ import time
24
+ import numpy as np
25
+ from collections import defaultdict
26
+ from transformers import AutoModel, AutoImageProcessor
27
+ from datasets import load_dataset
28
+ from scipy.optimize import nnls
29
+
30
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
31
+
32
+ from unimodal_sith.concept_pool import VisualConceptPool
33
+ from unimodal_sith.weight_extraction import WeightExtractor
34
+ from unimodal_sith.comp import comp, top_k_selection
35
+ from unimodal_sith.unisith import UniSITH, HeadInterpretation, SingularVectorInterpretation
36
+
37
+ # ─── Config ───────────────────────────────────────────────────────────────────
38
+ MODEL_NAME = os.environ.get("UNISITH_MODEL", "facebook/dinov2-small")
39
+ ARCHITECTURE = "dinov2"
40
+ # Auto-detect config based on model
41
+ _CONFIGS = {
42
+ "facebook/dinov2-small": (6, 384, 12),
43
+ "facebook/dinov2-base": (12, 768, 12),
44
+ "facebook/dinov2-large": (16, 1024, 24),
45
+ }
46
+ N_HEADS, D_MODEL, N_LAYERS = _CONFIGS.get(MODEL_NAME, (6, 384, 12))
47
+ ANALYZE_LAYERS = list(range(max(0, N_LAYERS - 4), N_LAYERS))
48
+ N_SVS = 5 # singular vectors per head
49
+ LAMBDA_COH = 0.3
50
+
51
+ OUTPUT_DIR = "./experiment_results"
52
+ CACHE_DIR = "./cache"
53
+
54
+
55
+ def nnomp(v_hat, Gamma_hat, K=5):
56
+ """Non-Negative Orthogonal Matching Pursuit (baseline, no coherence)."""
57
+ C, d = Gamma_hat.shape
58
+ v_hat_np = v_hat.cpu().numpy().astype(np.float64)
59
+ Gamma_np = Gamma_hat.cpu().numpy().astype(np.float64)
60
+ r = v_hat_np.copy()
61
+ S = []
62
+ for k in range(K):
63
+ s_res = Gamma_np @ r
64
+ for idx in S:
65
+ s_res[idx] = -np.inf
66
+ j_k = int(np.argmax(s_res))
67
+ S.append(j_k)
68
+ G_S = Gamma_np[S].T
69
+ c_S, _ = nnls(G_S, v_hat_np)
70
+ r = v_hat_np - G_S @ c_S
71
+ c = np.zeros(C)
72
+ for i, j in enumerate(S):
73
+ c[j] = c_S[i]
74
+ return torch.tensor(c, dtype=torch.float32, device=v_hat.device), S
75
+
76
+
77
+ def compute_fidelity(v_hat, coeffs, support, centered_concepts):
78
+ """Compute cosine similarity between v_hat and its reconstruction."""
79
+ reconstruction = torch.zeros_like(v_hat)
80
+ for idx in support:
81
+ reconstruction += coeffs[idx].item() * centered_concepts[idx]
82
+ if reconstruction.norm() < 1e-8:
83
+ return 0.0
84
+ return F.cosine_similarity(v_hat.unsqueeze(0), reconstruction.unsqueeze(0)).item()
85
+
86
+
87
+ def compute_monosemanticity_score(concept_embeddings_subset):
88
+ """
89
+ Compute an automated monosemanticity proxy score.
90
+
91
+ This measures how coherent the selected concepts are by computing the
92
+ mean pairwise cosine similarity among them. High similarity = monosemantic
93
+ (all concepts point to a single theme).
94
+
95
+ Score mapping (roughly calibrated to the 1-5 Likert scale from the paper):
96
+ mean_sim > 0.7 -> ~5 (highly monosemantic)
97
+ mean_sim > 0.5 -> ~4
98
+ mean_sim > 0.3 -> ~3
99
+ mean_sim > 0.15 -> ~2
100
+ mean_sim <= 0.15 -> ~1
101
+ """
102
+ if len(concept_embeddings_subset) < 2:
103
+ return 5.0, 1.0 # Single concept is trivially monosemantic
104
+
105
+ # Pairwise cosine similarity
106
+ sims = concept_embeddings_subset @ concept_embeddings_subset.T
107
+ n = sims.shape[0]
108
+ # Extract upper triangle (exclude diagonal)
109
+ mask = torch.triu(torch.ones(n, n, dtype=torch.bool), diagonal=1)
110
+ pairwise_sims = sims[mask]
111
+ mean_sim = pairwise_sims.mean().item()
112
+
113
+ # Map to 1-5 scale
114
+ if mean_sim > 0.7:
115
+ score = 5.0
116
+ elif mean_sim > 0.5:
117
+ score = 4.0 + (mean_sim - 0.5) / 0.2
118
+ elif mean_sim > 0.3:
119
+ score = 3.0 + (mean_sim - 0.3) / 0.2
120
+ elif mean_sim > 0.15:
121
+ score = 2.0 + (mean_sim - 0.15) / 0.15
122
+ else:
123
+ score = 1.0 + mean_sim / 0.15
124
+
125
+ return min(5.0, score), mean_sim
126
+
127
+
128
+ def run_fidelity_experiment(extractor, centered_concepts, concept_mean, device):
129
+ """
130
+ Fidelity experiment: compute fidelity across K={5,10,20} for COMP, NNOMP, top-k.
131
+ Matches paper's Fig. 3 experiment.
132
+ """
133
+ print("\n" + "=" * 80)
134
+ print("EXPERIMENT 1: Fidelity Analysis")
135
+ print("=" * 80)
136
+
137
+ K_values = [5, 10, 20]
138
+ methods = {
139
+ "COMP (Ξ»=0.3)": lambda v, G, K: comp(v, G, K=K, lambda_coh=0.3),
140
+ "NNOMP": lambda v, G, K: nnomp(v, G, K=K),
141
+ "Top-K": lambda v, G, K: top_k_selection(v, G, K=K),
142
+ }
143
+
144
+ results = {}
145
+
146
+ for method_name, method_fn in methods.items():
147
+ results[method_name] = {}
148
+ for K in K_values:
149
+ fidelities = []
150
+ print(f"\n {method_name}, K={K}:")
151
+
152
+ for layer_idx in ANALYZE_LAYERS:
153
+ W_VO_all = extractor.compute_WVO(layer_idx, fold_ln=True, project_ones=True)
154
+
155
+ for head_idx in range(N_HEADS):
156
+ W_VO_h = W_VO_all[head_idx]
157
+ U, sigma, Vt = extractor.svd_decompose(W_VO_h, top_k=N_SVS)
158
+ V_proj = extractor.project_to_feature_space(Vt)
159
+ V_centered = F.normalize(V_proj - concept_mean, dim=-1)
160
+
161
+ for sv_idx in range(N_SVS):
162
+ v_hat = V_centered[sv_idx]
163
+ coeffs, support = method_fn(v_hat, centered_concepts, K)
164
+ fid = compute_fidelity(v_hat, coeffs, support, centered_concepts)
165
+ fidelities.append(fid)
166
+
167
+ mean_fid = np.mean(fidelities)
168
+ std_fid = np.std(fidelities)
169
+ results[method_name][K] = {
170
+ "mean": mean_fid,
171
+ "std": std_fid,
172
+ "n": len(fidelities),
173
+ }
174
+ print(f" Mean fidelity: {mean_fid:.4f} Β± {std_fid:.4f} (n={len(fidelities)})")
175
+
176
+ return results
177
+
178
+
179
+ def run_monosemanticity_experiment(extractor, centered_concepts, concept_mean,
180
+ concept_pool, device):
181
+ """
182
+ Monosemanticity experiment: evaluate how coherent the concept sets are.
183
+ Uses intra-set cosine similarity as automated proxy for the LLM-as-judge.
184
+ Matches paper's Table 21 evaluation.
185
+ """
186
+ print("\n" + "=" * 80)
187
+ print("EXPERIMENT 2: Monosemanticity Analysis")
188
+ print("=" * 80)
189
+
190
+ K_values = [5, 10]
191
+ methods = {
192
+ "COMP (Ξ»=0.3)": lambda v, G, K: comp(v, G, K=K, lambda_coh=0.3),
193
+ "NNOMP": lambda v, G, K: nnomp(v, G, K=K),
194
+ "Top-K": lambda v, G, K: top_k_selection(v, G, K=K),
195
+ }
196
+
197
+ results = {}
198
+ detailed_examples = [] # For qualitative results
199
+
200
+ for method_name, method_fn in methods.items():
201
+ results[method_name] = {}
202
+ for K in K_values:
203
+ mono_scores = []
204
+ raw_sims = []
205
+
206
+ for layer_idx in ANALYZE_LAYERS:
207
+ W_VO_all = extractor.compute_WVO(layer_idx, fold_ln=True, project_ones=True)
208
+
209
+ for head_idx in range(N_HEADS):
210
+ W_VO_h = W_VO_all[head_idx]
211
+ U, sigma, Vt = extractor.svd_decompose(W_VO_h, top_k=N_SVS)
212
+ V_proj = extractor.project_to_feature_space(Vt)
213
+ V_centered = F.normalize(V_proj - concept_mean, dim=-1)
214
+
215
+ for sv_idx in range(N_SVS):
216
+ v_hat = V_centered[sv_idx]
217
+ coeffs, support = method_fn(v_hat, centered_concepts, K)
218
+
219
+ # Get the embeddings of selected concepts
220
+ selected_embs = centered_concepts[support]
221
+ score, mean_sim = compute_monosemanticity_score(selected_embs)
222
+ mono_scores.append(score)
223
+ raw_sims.append(mean_sim)
224
+
225
+ # Collect detailed examples for COMP K=5
226
+ if method_name == "COMP (Ξ»=0.3)" and K == 5:
227
+ fid = compute_fidelity(v_hat, coeffs, support, centered_concepts)
228
+ captions = [concept_pool.captions[idx] for idx in support]
229
+ coeff_vals = [coeffs[idx].item() for idx in support]
230
+ image_ids = None
231
+ if concept_pool.image_ids is not None:
232
+ image_ids = [concept_pool.image_ids[idx] for idx in support]
233
+ detailed_examples.append({
234
+ "layer": layer_idx,
235
+ "head": head_idx,
236
+ "sv_index": sv_idx,
237
+ "singular_value": sigma[sv_idx].item(),
238
+ "fidelity": fid,
239
+ "monosemanticity_score": score,
240
+ "mean_pairwise_sim": mean_sim,
241
+ "concepts": [
242
+ {"caption": c, "coefficient": w}
243
+ for c, w in zip(captions, coeff_vals)
244
+ ],
245
+ "image_ids": image_ids,
246
+ })
247
+
248
+ mean_mono = np.mean(mono_scores)
249
+ std_mono = np.std(mono_scores)
250
+ mean_raw = np.mean(raw_sims)
251
+ results[method_name][K] = {
252
+ "mean_score": mean_mono,
253
+ "std_score": std_mono,
254
+ "mean_pairwise_sim": mean_raw,
255
+ "n": len(mono_scores),
256
+ }
257
+ print(f" {method_name}, K={K}: "
258
+ f"mono={mean_mono:.2f}Β±{std_mono:.2f}, "
259
+ f"mean_sim={mean_raw:.4f}")
260
+
261
+ return results, detailed_examples
262
+
263
+
264
+ def select_qualitative_examples(detailed_examples, n=25):
265
+ """
266
+ Select ~25 diverse, high-quality qualitative examples.
267
+ Strategy: pick examples with high monosemanticity AND high fidelity,
268
+ spread across different layers and heads.
269
+ """
270
+ # Sort by combined quality: mono_score * fidelity * singular_value
271
+ for ex in detailed_examples:
272
+ ex["quality_score"] = (
273
+ ex["monosemanticity_score"] * ex["fidelity"] *
274
+ min(ex["singular_value"], 5.0) # Cap SV influence
275
+ )
276
+
277
+ sorted_examples = sorted(detailed_examples, key=lambda x: x["quality_score"], reverse=True)
278
+
279
+ # Ensure diversity: no more than 2 examples from same (layer, head)
280
+ selected = []
281
+ seen_heads = defaultdict(int)
282
+
283
+ for ex in sorted_examples:
284
+ key = (ex["layer"], ex["head"])
285
+ if seen_heads[key] < 2:
286
+ selected.append(ex)
287
+ seen_heads[key] += 1
288
+ if len(selected) >= n:
289
+ break
290
+
291
+ # If we don't have enough, relax constraint
292
+ if len(selected) < n:
293
+ for ex in sorted_examples:
294
+ if ex not in selected:
295
+ selected.append(ex)
296
+ if len(selected) >= n:
297
+ break
298
+
299
+ return selected[:n]
300
+
301
+
302
+ def generate_qualitative_markdown(examples, output_path):
303
+ """Generate a markdown file with qualitative results."""
304
+ lines = [
305
+ "# UniSITH Qualitative Results",
306
+ "",
307
+ "## DINOv2-Base Analysis β€” Selected Singular Vector Interpretations",
308
+ "",
309
+ f"**Model:** `facebook/dinov2-base` (12 heads, 768d, 12 layers)",
310
+ f"**Concept pool:** Recap-COCO-30K (30,504 captioned images)",
311
+ f"**Method:** COMP (Ξ»=0.3, K=5)",
312
+ f"**Layers analyzed:** {ANALYZE_LAYERS}",
313
+ "",
314
+ "Each entry shows one singular vector from an attention head, decomposed into",
315
+ "5 visual concepts from the image pool. The concepts are ranked by coefficient weight.",
316
+ "Captions are from COCO annotations and describe what visual content the attention",
317
+ "head encodes in that direction.",
318
+ "",
319
+ "---",
320
+ "",
321
+ ]
322
+
323
+ for i, ex in enumerate(examples, 1):
324
+ lines.append(f"### Example {i}: Layer {ex['layer']}, Head {ex['head']}, "
325
+ f"SV {ex['sv_index']}")
326
+ lines.append("")
327
+ lines.append(f"- **Singular value:** {ex['singular_value']:.4f}")
328
+ lines.append(f"- **Fidelity:** {ex['fidelity']:.4f}")
329
+ lines.append(f"- **Monosemanticity score:** {ex['monosemanticity_score']:.2f}/5.0")
330
+ lines.append(f"- **Mean pairwise similarity:** {ex['mean_pairwise_sim']:.4f}")
331
+ lines.append("")
332
+ lines.append("| Coefficient | Caption (Visual Concept) |")
333
+ lines.append("|---|---|")
334
+ for concept in ex["concepts"]:
335
+ lines.append(f"| {concept['coefficient']:.4f} | {concept['caption']} |")
336
+ lines.append("")
337
+
338
+ # Add COCO image IDs for reference
339
+ if ex.get("image_ids"):
340
+ ids_str = ", ".join(str(x) for x in ex["image_ids"])
341
+ lines.append(f"*COCO image IDs: {ids_str}*")
342
+ urls = [f"[{img_id}](http://images.cocodataset.org/val2014/COCO_val2014_{img_id:012d}.jpg)"
343
+ for img_id in ex["image_ids"]]
344
+ sep = " | "
345
+ lines.append(f"*Image links: {sep.join(urls)}*")
346
+ lines.append("")
347
+
348
+ lines.append("---")
349
+ lines.append("")
350
+
351
+ os.makedirs(os.path.dirname(output_path) if os.path.dirname(output_path) else ".", exist_ok=True)
352
+ with open(output_path, "w") as f:
353
+ f.write("\n".join(lines))
354
+ print(f"Qualitative results saved to {output_path}")
355
+
356
+
357
+ def generate_experiment_report(fidelity_results, mono_results, output_path):
358
+ """Generate a markdown report of all experiments."""
359
+ lines = [
360
+ "# UniSITH Experiment Report",
361
+ "",
362
+ "## Setup",
363
+ "",
364
+ f"- **Model:** `facebook/dinov2-base` (12 heads Γ— 768d Γ— 12 layers)",
365
+ f"- **Concept pool:** Recap-COCO-30K (30,504 captioned images)",
366
+ f"- **Layers analyzed:** {ANALYZE_LAYERS} (last 4)",
367
+ f"- **Singular vectors per head:** {N_SVS}",
368
+ f"- **Total SVs analyzed:** {len(ANALYZE_LAYERS) * N_HEADS * N_SVS}",
369
+ "",
370
+ "---",
371
+ "",
372
+ "## Experiment 1: Fidelity Analysis",
373
+ "",
374
+ "Fidelity measures how well the sparse concept set reconstructs the original",
375
+ "singular vector (cosine similarity between original and reconstruction).",
376
+ "",
377
+ "| Method | K=5 | K=10 | K=20 |",
378
+ "|---|---|---|---|",
379
+ ]
380
+
381
+ for method_name, K_results in fidelity_results.items():
382
+ vals = []
383
+ for K in [5, 10, 20]:
384
+ r = K_results[K]
385
+ vals.append(f"{r['mean']:.4f} Β± {r['std']:.4f}")
386
+ lines.append(f"| {method_name} | {' | '.join(vals)} |")
387
+
388
+ lines.extend([
389
+ "",
390
+ "---",
391
+ "",
392
+ "## Experiment 2: Monosemanticity Analysis",
393
+ "",
394
+ "Monosemanticity measures how coherent each concept set is β€” whether the selected",
395
+ "concepts point to a single, unambiguous visual theme.",
396
+ "",
397
+ "We use mean pairwise cosine similarity among selected concept embeddings as an",
398
+ "automated proxy for the LLM-as-judge evaluation used in the original SITH paper.",
399
+ "The score is mapped to a 1-5 Likert scale.",
400
+ "",
401
+ "| Method | K=5 Score | K=5 Sim | K=10 Score | K=10 Sim |",
402
+ "|---|---|---|---|---|",
403
+ ])
404
+
405
+ for method_name, K_results in mono_results.items():
406
+ vals = []
407
+ for K in [5, 10]:
408
+ r = K_results[K]
409
+ vals.append(f"{r['mean_score']:.2f} Β± {r['std_score']:.2f}")
410
+ vals.append(f"{r['mean_pairwise_sim']:.4f}")
411
+ lines.append(f"| {method_name} | {' | '.join(vals)} |")
412
+
413
+ lines.extend([
414
+ "",
415
+ "### Interpretation",
416
+ "",
417
+ "- **COMP** achieves the best balance: high fidelity with high monosemanticity",
418
+ "- **Top-K** has high monosemanticity (by construction β€” all concepts are similar)",
419
+ " but lower fidelity (misses diverse aspects of the singular vector)",
420
+ "- **NNOMP** has high fidelity but lower monosemanticity (selects diverse but",
421
+ " potentially incoherent concepts)",
422
+ "",
423
+ "This mirrors the findings of the original SITH paper (Fig. 3).",
424
+ ])
425
+
426
+ os.makedirs(os.path.dirname(output_path) if os.path.dirname(output_path) else ".", exist_ok=True)
427
+ with open(output_path, "w") as f:
428
+ f.write("\n".join(lines))
429
+ print(f"Experiment report saved to {output_path}")
430
+
431
+
432
+ def main():
433
+ parser = argparse.ArgumentParser()
434
+ parser.add_argument("--device", type=str, default="cuda")
435
+ args = parser.parse_args()
436
+
437
+ device = args.device
438
+ if device == "cuda" and not torch.cuda.is_available():
439
+ print("CUDA not available, falling back to CPU")
440
+ device = "cpu"
441
+
442
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
443
+ os.makedirs(CACHE_DIR, exist_ok=True)
444
+
445
+ start_time = time.time()
446
+
447
+ # ─── Step 1: Load model ───────────────────────────────────────────────────
448
+ print("=" * 80)
449
+ print("STEP 1: Loading DINOv2-base")
450
+ print("=" * 80)
451
+ model = AutoModel.from_pretrained(MODEL_NAME)
452
+ processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
453
+ model.eval()
454
+ model = model.to(device)
455
+ print(f"Model loaded on {device}")
456
+
457
+ # ─── Step 2: Build concept pool (full 30K) ───────────────────────────────
458
+ print("\n" + "=" * 80)
459
+ print("STEP 2: Building concept pool (full 30K images)")
460
+ print("=" * 80)
461
+
462
+ cache_path = os.path.join(CACHE_DIR, "concept_pool_dinov2_base_30K.pt")
463
+
464
+ dataset = load_dataset("UCSC-VLAA/Recap-COCO-30K", split="train")
465
+ print(f"Dataset loaded: {len(dataset)} images")
466
+
467
+ pool = VisualConceptPool.from_dataset(
468
+ dataset=dataset,
469
+ model=model,
470
+ processor=processor,
471
+ architecture=ARCHITECTURE,
472
+ image_column="image",
473
+ caption_column="caption",
474
+ image_id_column="image_id",
475
+ batch_size=128,
476
+ max_concepts=None, # Use ALL 30K
477
+ device=device,
478
+ cache_path=cache_path,
479
+ )
480
+ print(f"Concept pool: {pool.num_concepts} concepts, dim={pool.embed_dim}")
481
+
482
+ elapsed = time.time() - start_time
483
+ print(f"Time so far: {elapsed:.0f}s")
484
+
485
+ # ─── Step 3: Prepare analyzer ─────────────────────────────────────────────
486
+ print("\n" + "=" * 80)
487
+ print("STEP 3: Preparing analyzer")
488
+ print("=" * 80)
489
+
490
+ extractor = WeightExtractor(model, ARCHITECTURE, N_HEADS, D_MODEL)
491
+ centered_concepts, concept_mean = pool.get_centered_embeddings()
492
+ centered_concepts = centered_concepts.to(device)
493
+ concept_mean = concept_mean.to(device)
494
+
495
+ # ─── Step 4: Fidelity experiment ──────────────────────────────────────────
496
+ fidelity_results = run_fidelity_experiment(
497
+ extractor, centered_concepts, concept_mean, device
498
+ )
499
+
500
+ # Save intermediate
501
+ with open(os.path.join(OUTPUT_DIR, "fidelity_results.json"), "w") as f:
502
+ json.dump(fidelity_results, f, indent=2)
503
+
504
+ elapsed = time.time() - start_time
505
+ print(f"\nFidelity experiment done. Time so far: {elapsed:.0f}s")
506
+
507
+ # ─── Step 5: Monosemanticity experiment ───────────────────────────────────
508
+ mono_results, detailed_examples = run_monosemanticity_experiment(
509
+ extractor, centered_concepts, concept_mean, pool, device
510
+ )
511
+
512
+ # Save intermediate
513
+ with open(os.path.join(OUTPUT_DIR, "monosemanticity_results.json"), "w") as f:
514
+ json.dump(mono_results, f, indent=2)
515
+
516
+ elapsed = time.time() - start_time
517
+ print(f"\nMonosemanticity experiment done. Time so far: {elapsed:.0f}s")
518
+
519
+ # ─── Step 6: Select and save qualitative examples ─────────────────────────
520
+ print("\n" + "=" * 80)
521
+ print("STEP 6: Generating qualitative results")
522
+ print("=" * 80)
523
+
524
+ qualitative = select_qualitative_examples(detailed_examples, n=25)
525
+
526
+ # Save raw JSON
527
+ with open(os.path.join(OUTPUT_DIR, "qualitative_examples.json"), "w") as f:
528
+ json.dump(qualitative, f, indent=2)
529
+
530
+ # Generate markdown
531
+ generate_qualitative_markdown(
532
+ qualitative,
533
+ os.path.join(OUTPUT_DIR, "qualitative_results.md")
534
+ )
535
+
536
+ # ─── Step 7: Generate full report ─────────────────────────────────────────
537
+ generate_experiment_report(
538
+ fidelity_results, mono_results,
539
+ os.path.join(OUTPUT_DIR, "experiment_report.md")
540
+ )
541
+
542
+ # ─── Step 8: Save full analysis results ───────────────────────────────────
543
+ print("\n" + "=" * 80)
544
+ print("STEP 8: Running full COMP K=5 analysis and saving results")
545
+ print("=" * 80)
546
+
547
+ analyzer = UniSITH(
548
+ model=model,
549
+ architecture=ARCHITECTURE,
550
+ n_heads=N_HEADS,
551
+ d_model=D_MODEL,
552
+ concept_pool=pool,
553
+ device=device,
554
+ )
555
+
556
+ full_results = analyzer.analyze_model(
557
+ layers=ANALYZE_LAYERS,
558
+ n_singular_vectors=N_SVS,
559
+ K=5,
560
+ lambda_coh=LAMBDA_COH,
561
+ method="comp",
562
+ )
563
+
564
+ UniSITH.save_results(full_results, os.path.join(OUTPUT_DIR, "full_analysis.json"))
565
+
566
+ total_time = time.time() - start_time
567
+ print(f"\n{'=' * 80}")
568
+ print(f"ALL EXPERIMENTS COMPLETE. Total time: {total_time:.0f}s ({total_time/60:.1f}min)")
569
+ print(f"Results saved in {OUTPUT_DIR}/")
570
+ print(f"{'=' * 80}")
571
+
572
+
573
+ if __name__ == "__main__":
574
+ main()