AbstractPhil committed on
Commit
7f9ee11
·
verified ·
1 Parent(s): 7f7daa9

Create colab_deep_analysis.py

Browse files
Files changed (1) hide show
  1. colab_deep_analysis.py +506 -0
colab_deep_analysis.py ADDED
@@ -0,0 +1,506 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ============================================================================
2
+ # INTERNAL ANALYZER: CaptionBERT-8192
3
+ #
4
+ # Sees inside the model, not just the output. Five diagnostic lenses:
5
+ # 1. Spectral trajectories — eigenvalue evolution per layer
6
+ # 2. Effective dimensionality — how deeply each input is understood
7
+ # 3. Cross-layer divergence — where computation actually happens
8
+ # 4. Token influence — which input tokens drive the output
9
+ # 5. Neighborhood structure — local geometry at each layer
10
+ #
11
+ # Usage:
12
+ # analyzer = InternalAnalyzer(model, tokenizer)
13
+ # report = analyzer.analyze(["girl", "woman", "subtraction", "multiplication"])
14
+ # analyzer.print_report(report)
15
+ # analyzer.compare(report, "girl", "subtraction")
16
+ # ============================================================================
17
+
18
+ import torch
19
+ import torch.nn.functional as F
20
+ import numpy as np
21
+ from collections import defaultdict
22
+
23
# Device string used for the model and for every tokenizer output in this
# file: "cuda" when torch reports an available CUDA device, otherwise "cpu".
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
24
+
25
+
26
class InternalAnalyzer:
    """Diagnostic probes over the hidden states of a text encoder.

    The analyzer runs a model on a batch of strings and derives four kinds of
    measurements from the result: per-layer singular-value spectra, local
    effective dimensionality of the output embeddings, layer-to-layer
    movement of the pooled representation, and gradient-based token
    influence. ``analyze`` bundles them into one report; ``print_report`` and
    ``compare`` render that report as text.

    NOTE(review): ``extract_layers`` expects the model output to expose
    ``hidden_states`` and ``last_hidden_state``; ``token_influence``
    additionally expects the attributes ``token_emb``, ``pos_emb``,
    ``emb_norm``, ``emb_drop``, ``encoder.layers`` and ``output_proj`` on the
    model. Confirm against the model class actually loaded.
    """

    def __init__(self, model, tokenizer, max_len=512):
        """Move ``model`` to ``DEVICE``, switch it to eval mode and keep the tokenizer.

        Args:
            model: Encoder to inspect (see the class note for the attributes used).
            tokenizer: Callable tokenizer returning ``input_ids`` and
                ``attention_mask`` tensors.
            max_len: Length every input is padded/truncated to.
        """
        self.model = model.to(DEVICE).eval()
        self.tokenizer = tokenizer
        self.max_len = max_len

    # ------------------------------------------------------------------
    # CORE: extract all layer representations
    # ------------------------------------------------------------------

    @torch.no_grad()
    def extract_layers(self, texts):
        """Run the model once and collect per-layer representations.

        Args:
            texts: A string or a list of strings.

        Returns:
            dict with keys
            ``texts`` (the inputs as a list),
            ``layer_pooled`` (list with one mask-weighted mean over the
            sequence axis per hidden state),
            ``layer_raw`` (tuple of the unpooled hidden states),
            ``final_embedding`` (``outputs.last_hidden_state``; downstream
            code treats it as one vector per input -- the original author
            notes (B, 768), TODO confirm),
            ``attention_mask`` and ``n_tokens`` (mask sum per input).
            All tensors are returned on the CPU.
        """
        if isinstance(texts, str):
            texts = [texts]

        inputs = self.tokenizer(
            texts, max_length=self.max_len, padding="max_length",
            truncation=True, return_tensors="pt").to(DEVICE)
        attention_mask = inputs["attention_mask"]

        outputs = self.model(
            input_ids=inputs["input_ids"],
            attention_mask=attention_mask,
            output_hidden_states=True)

        # Mean-pool each layer over the positions the mask marks as real;
        # the clamp keeps an all-padding row from dividing by zero.
        mask = attention_mask.unsqueeze(-1).float()
        denom = mask.sum(1).clamp(min=1)
        layer_pooled = [((h * mask).sum(1) / denom).cpu()
                        for h in outputs.hidden_states]

        return {
            "texts": texts,
            "layer_pooled": layer_pooled,
            # Moved to the CPU like every other entry so the returned dict
            # does not pin accelerator memory.
            "layer_raw": tuple(h.cpu() for h in outputs.hidden_states),
            "final_embedding": outputs.last_hidden_state.cpu(),
            "attention_mask": attention_mask.cpu(),
            "n_tokens": attention_mask.sum(-1).cpu(),
        }

    # ------------------------------------------------------------------
    # 1. SPECTRAL TRAJECTORIES
    # ------------------------------------------------------------------

    @staticmethod
    def _empty_spectral_step():
        """Return the placeholder step used when no spectrum can be computed."""
        return {"spectrum": [], "eff_dim": 0, "entropy": 0, "top1_ratio": 0}

    def spectral_trajectory(self, data):
        """Compute the singular-value spectrum of the token states at each layer.

        For every input and every entry of ``data["layer_raw"]`` the rows
        selected by the attention mask are centered and their singular
        values are taken.

        Args:
            data: dict with ``texts``, ``layer_raw`` and ``attention_mask``
                (as produced by ``extract_layers``).

        Returns:
            One dict per input: ``{"text", "trajectory"}`` where
            ``trajectory`` holds one step per layer with ``spectrum`` (top 20
            singular values), ``eff_dim`` (participation ratio
            ``(sum S)^2 / sum S^2``), ``entropy`` (Shannon entropy of the
            normalized spectrum) and ``top1_ratio`` (share of the largest
            singular value). Inputs with fewer than two real tokens, or whose
            SVD fails, get a zero-filled step with the same keys.
        """
        results = []
        for b, text in enumerate(data["texts"]):
            # Boolean indexing keeps exactly the masked-in tokens, so the
            # result does not depend on which side the tokenizer pads.
            keep = data["attention_mask"][b].bool()
            n_real = int(keep.sum().item())

            trajectory = []
            for layer_states in data["layer_raw"]:
                if n_real < 2:
                    trajectory.append(self._empty_spectral_step())
                    continue

                h = layer_states[b].cpu().float()[keep]
                h_centered = h - h.mean(0, keepdim=True)
                try:
                    S = torch.linalg.svdvals(h_centered)
                except RuntimeError:
                    # torch.linalg errors derive from RuntimeError; keep the
                    # analysis best-effort instead of aborting the batch.
                    trajectory.append(self._empty_spectral_step())
                    continue

                S_norm = S / (S.sum() + 1e-12)
                eff_dim = (S.sum() ** 2) / (S.pow(2).sum() + 1e-12)
                # Drop (near-)zero entries so log() stays finite.
                S_pos = S_norm[S_norm > 1e-12]
                entropy = -(S_pos * S_pos.log()).sum()

                trajectory.append({
                    "spectrum": S[:20].tolist(),
                    "eff_dim": eff_dim.item(),
                    "entropy": entropy.item(),
                    "top1_ratio": S_norm[0].item(),
                })

            results.append({"text": text, "trajectory": trajectory})

        return results

    # ------------------------------------------------------------------
    # 2. EFFECTIVE DIMENSIONALITY (output space)
    # ------------------------------------------------------------------

    def effective_dimensionality(self, data, k_neighbors=50):
        """Estimate local effective dimensionality around each output embedding.

        For each input the ``k`` most similar other inputs (by dot product
        of ``data["final_embedding"]``) are collected, centered, and their
        singular values summarized.

        Args:
            data: dict with ``texts`` and ``final_embedding``.
            k_neighbors: Requested neighborhood size; capped at ``B - 1``.

        Returns:
            One dict per input with ``text``, ``eff_dim`` (participation
            ratio), ``decay_rate`` (share of the spectrum in the top five
            singular values) and ``local_spread`` (mean distance of the
            neighbors from their centroid). When no neighbor exists (fewer
            than two inputs) or the SVD fails, all three numbers are 0.
        """
        embeddings = data["final_embedding"].float()
        texts = data["texts"]
        B = embeddings.shape[0]

        def degenerate(text):
            return {"text": text, "eff_dim": 0.0,
                    "decay_rate": 0.0, "local_spread": 0.0}

        # Never ask topk for more rows than there are *other* inputs.
        k = min(k_neighbors, B - 1)
        if k < 1:
            return [degenerate(text) for text in texts]

        # Pairwise dot-product similarity (cosine only if rows are unit norm).
        sim = embeddings @ embeddings.T
        results = []

        for b in range(B):
            sims = sim[b].clone()
            # -inf rather than -1: dot products of unnormalized embeddings
            # can be below -1, which would let the row select itself.
            sims[b] = float("-inf")
            _, topk_idx = sims.topk(k)
            neighbors = embeddings[topk_idx]

            # Local PCA via the singular values of the centered neighbors.
            centered = neighbors - neighbors.mean(0, keepdim=True)
            try:
                S = torch.linalg.svdvals(centered)
            except RuntimeError:
                results.append(degenerate(texts[b]))
                continue

            eff_dim = (S.sum() ** 2) / (S.pow(2).sum() + 1e-12)
            S_norm = S / (S.sum() + 1e-12)
            # Epsilon guards the all-zero spectrum (e.g. a single neighbor).
            decay_rate = (S_norm[:5].sum() / (S_norm.sum() + 1e-12)).item()

            results.append({
                "text": texts[b],
                "eff_dim": eff_dim.item(),
                "decay_rate": decay_rate,  # high = concentrated, low = spread
                "local_spread": centered.norm(dim=-1).mean().item(),
            })

        return results

    # ------------------------------------------------------------------
    # 3. CROSS-LAYER DIVERGENCE
    # ------------------------------------------------------------------

    def cross_layer_divergence(self, data):
        """Measure how far the pooled representation moves between layers.

        Args:
            data: dict with ``texts`` and ``layer_pooled``.

        Returns:
            One dict per input with ``text``, ``profile`` (per consecutive
            layer pair: ``layer`` label, ``cosine``, ``l2_shift``,
            ``angle_rad``), ``total_path`` (sum of the L2 shifts),
            ``max_shift_layer`` (index of the transition with the largest L2
            shift, or ``None`` when there is only one layer) and
            ``input_output_cos`` (cosine between first and last pooled state).
        """
        results = []
        for b, text in enumerate(data["texts"]):
            pooled = [layer[b].float() for layer in data["layer_pooled"]]

            profile = []
            for i, (h_curr, h_next) in enumerate(zip(pooled, pooled[1:])):
                cos = F.cosine_similarity(h_curr.unsqueeze(0),
                                          h_next.unsqueeze(0)).item()
                l2 = (h_next - h_curr).norm().item()

                # Angle between the unit directions; the clamp absorbs
                # rounding that would push the dot product outside [-1, 1].
                dot = (F.normalize(h_curr, dim=0)
                       * F.normalize(h_next, dim=0)).sum()
                angle = torch.acos(torch.clamp(dot, -1, 1)).item()

                profile.append({
                    "layer": f"{i}→{i+1}",
                    "cosine": cos,
                    "l2_shift": l2,
                    "angle_rad": angle,
                })

            shifts = [step["l2_shift"] for step in profile]
            max_shift_layer = (max(range(len(shifts)), key=shifts.__getitem__)
                               if shifts else None)

            results.append({
                "text": text,
                "profile": profile,
                "total_path": sum(shifts),
                "max_shift_layer": max_shift_layer,
                "input_output_cos": F.cosine_similarity(
                    pooled[0].unsqueeze(0), pooled[-1].unsqueeze(0)).item(),
            })

        return results

    # ------------------------------------------------------------------
    # 4. TOKEN INFLUENCE (gradient-based)
    # ------------------------------------------------------------------

    def token_influence(self, texts):
        """Rank input tokens by the gradient they receive from the output.

        The forward pass is rebuilt by hand from the model's submodules so
        the embedding-layer output can be differentiated. The scalar being
        differentiated is the sum of the components of the L2-normalized
        projected embedding; a token's influence is the norm of the gradient
        at its position, normalized to sum to one over the real tokens.

        Model parameter gradients are neither written nor cleared.

        Args:
            texts: A string or a list of strings (processed one at a time).

        Returns:
            One dict per input with ``text``, ``tokens``, ``influence``,
            ``top_tokens`` (up to ten ``(token, influence)`` pairs, largest
            first) and ``concentration`` (max influence / mean influence,
            0.0 when there are no real tokens).
        """
        if isinstance(texts, str):
            texts = [texts]

        results = []
        for text in texts:
            inputs = self.tokenizer(
                [text], max_length=self.max_len, padding="max_length",
                truncation=True, return_tensors="pt").to(DEVICE)
            input_ids = inputs["input_ids"]
            attention_mask = inputs["attention_mask"]

            # enable_grad makes this work even when the caller is inside a
            # no_grad block.
            with torch.enable_grad():
                positions = torch.arange(input_ids.shape[1],
                                         device=DEVICE).unsqueeze(0)
                emb = self.model.token_emb(input_ids) + self.model.pos_emb(positions)
                emb = self.model.emb_drop(self.model.emb_norm(emb))
                # A detached leaf works even if the model's parameters are
                # frozen, and keeps the graph from reaching the parameters.
                emb = emb.detach().requires_grad_(True)

                # Forward through the encoder; padding positions are masked.
                kpm = ~attention_mask.bool()
                x = emb
                for layer in self.model.encoder.layers:
                    x = layer(x, src_key_padding_mask=kpm)

                # Pool and project, as in extract_layers' pooling.
                mask = attention_mask.unsqueeze(-1).float()
                pooled = (x * mask).sum(1) / mask.sum(1).clamp(min=1)
                output = F.normalize(self.model.output_proj(pooled), dim=-1)

                (grad,) = torch.autograd.grad(output.sum(), emb)

            # Select real tokens by mask so left-padded input is handled too.
            keep = attention_mask[0].bool().cpu()
            influence = grad[0].cpu().norm(dim=-1)[keep]
            influence = influence / (influence.sum() + 1e-12)

            token_ids = input_ids[0].cpu()[keep].tolist()
            tokens = self.tokenizer.convert_ids_to_tokens(token_ids)
            scores = influence.tolist()

            if influence.numel() > 0:
                concentration = (influence.max()
                                 / influence.mean().clamp(min=1e-12)).item()
            else:
                concentration = 0.0

            results.append({
                "text": text,
                "tokens": tokens,
                "influence": scores,
                "top_tokens": sorted(zip(tokens, scores),
                                     key=lambda pair: -pair[1])[:10],
                "concentration": concentration,
            })

        return results

    # ------------------------------------------------------------------
    # 5. FULL ANALYSIS
    # ------------------------------------------------------------------

    def analyze(self, texts):
        """Run every analysis on ``texts`` and merge the results per input.

        Args:
            texts: A string or a list of strings.

        Returns:
            dict mapping each text to ``embedding``, ``n_tokens``,
            ``spectral``, ``eff_dim``, ``divergence`` and ``influence``.
            Because the report is keyed by the text itself, a repeated text
            keeps only its last occurrence. An empty input yields ``{}``.
        """
        if isinstance(texts, str):
            texts = [texts]

        print(f" Analyzing {len(texts)} inputs...")
        if not texts:
            return {}

        data = self.extract_layers(texts)
        spectral = self.spectral_trajectory(data)
        eff_dim = self.effective_dimensionality(data)
        divergence = self.cross_layer_divergence(data)
        influence = self.token_influence(texts)

        report = {}
        for i, text in enumerate(texts):
            report[text] = {
                "embedding": data["final_embedding"][i],
                "n_tokens": data["n_tokens"][i].item(),
                "spectral": spectral[i],
                "eff_dim": eff_dim[i] if i < len(eff_dim) else {},
                "divergence": divergence[i],
                "influence": influence[i],
            }

        return report

    # ------------------------------------------------------------------
    # PRINTING
    # ------------------------------------------------------------------

    def print_report(self, report):
        """Print the summary table and per-layer tables for a report.

        Args:
            report: Mapping produced by ``analyze``. An empty report prints
                only the banner and a short notice.
        """
        print(f"\n{'='*70}")
        print("INTERNAL ANALYSIS REPORT")
        print(f"{'='*70}")

        if not report:
            print(" (no inputs)")
            return

        # Summary table
        print(f"\n {'Text':<25} {'Tokens':>6} {'EffDim':>7} {'Path':>7} "
              f"{'MaxShift':>9} {'InOutCos':>8} {'Concentrate':>11}")
        print(f" {'-'*75}")

        for text, r in report.items():
            label = text[:24]
            ed = r["eff_dim"].get("eff_dim", 0)
            tp = r["divergence"]["total_path"]
            ms = r["divergence"]["max_shift_layer"]
            if ms is None:  # single-layer input: there is no transition
                ms = "-"
            ioc = r["divergence"]["input_output_cos"]
            conc = r["influence"]["concentration"]
            print(f" {label:<25} {r['n_tokens']:>6} {ed:>7.1f} {tp:>7.2f} "
                  f" layer {ms:>2} {ioc:>7.3f} {conc:>10.1f}")

        # Column count is taken from the first entry's trajectory.
        n_layers = len(next(iter(report.values()))["spectral"]["trajectory"])

        def layer_header():
            print(f" {'Text':<25}", end="")
            for i in range(n_layers):
                print(f" L{i:>2}", end="")
            print()
            print(f" {'-'*75}")

        # Spectral evolution
        print(f"\n SPECTRAL TRAJECTORY (effective dim per layer):")
        layer_header()
        for text, r in report.items():
            label = text[:24]
            print(f" {label:<25}", end="")
            for step in r["spectral"]["trajectory"]:
                ed = step.get("eff_dim", 0)
                print(f" {ed:>4.0f}", end="")
            print()

        # Spectral entropy per layer
        print(f"\n SPECTRAL ENTROPY (information content per layer):")
        layer_header()
        for text, r in report.items():
            label = text[:24]
            print(f" {label:<25}", end="")
            for step in r["spectral"]["trajectory"]:
                ent = step.get("entropy", 0)
                print(f" {ent:>4.1f}", end="")
            print()

        # Cross-layer divergence profiles
        print(f"\n COMPUTATION PROFILE (L2 shift between layers):")
        print(f" {'Text':<25}", end="")
        for i in range(n_layers - 1):
            print(f" {i}→{i+1:>2}", end="")
        print()
        print(f" {'-'*75}")

        for text, r in report.items():
            label = text[:24]
            print(f" {label:<25}", end="")
            for step in r["divergence"]["profile"]:
                print(f" {step['l2_shift']:>4.1f}", end="")
            print()

        # Token influence for each input
        print(f"\n TOKEN INFLUENCE (top contributing tokens):")
        for text, r in report.items():
            top = r["influence"]["top_tokens"][:5]
            tok_str = " ".join(f"{t}={v:.3f}" for t, v in top)
            print(f" {text[:40]:<42} {tok_str}")

    def compare(self, report, text_a, text_b):
        """Print a side-by-side comparison of two entries of a report.

        Args:
            report: Mapping produced by ``analyze``.
            text_a: Key of the first entry ("A" in the output).
            text_b: Key of the second entry ("B" in the output).

        Raises:
            KeyError: If either text is not a key of ``report``.
        """
        a = report[text_a]
        b = report[text_b]

        cos = F.cosine_similarity(
            a["embedding"].unsqueeze(0),
            b["embedding"].unsqueeze(0)).item()

        print(f"\n{'='*70}")
        print(f"COMPARISON: '{text_a}' vs '{text_b}'")
        print(f"{'='*70}")
        print(f" Output cosine: {cos:.4f}")
        print(f" Tokens: {a['n_tokens']} vs {b['n_tokens']}")

        # Effective dim comparison
        ed_a = a["eff_dim"].get("eff_dim", 0)
        ed_b = b["eff_dim"].get("eff_dim", 0)
        print(f" Effective dim: {ed_a:.1f} vs {ed_b:.1f} (Δ={abs(ed_a-ed_b):.1f})")

        # Path comparison
        pa = a["divergence"]["total_path"]
        pb = b["divergence"]["total_path"]
        print(f" Total path: {pa:.2f} vs {pb:.2f} (Δ={abs(pa-pb):.2f})")

        # Layer-by-layer spectral comparison
        print(f"\n Effective dim trajectory:")
        print(f" {'Layer':<8} {'A':>8} {'B':>8} {'Δ':>8}")
        steps = zip(a["spectral"]["trajectory"], b["spectral"]["trajectory"])
        for i, (step_a, step_b) in enumerate(steps):
            ea = step_a.get("eff_dim", 0)
            eb = step_b.get("eff_dim", 0)
            print(f" L{i:<6} {ea:>8.1f} {eb:>8.1f} {abs(ea-eb):>8.1f}")

        # Divergence profile comparison
        print(f"\n Computation profile (L2 shift):")
        print(f" {'Transition':<10} {'A':>8} {'B':>8} {'Δ':>8}")
        transitions = zip(a["divergence"]["profile"], b["divergence"]["profile"])
        for trans_a, trans_b in transitions:
            sa = trans_a["l2_shift"]
            sb = trans_b["l2_shift"]
            label = trans_a["layer"]
            print(f" {label:<10} {sa:>8.2f} {sb:>8.2f} {abs(sa-sb):>8.2f}")

        # Token influence comparison
        def top_five(entry):
            pairs = entry["influence"]["top_tokens"][:5]
            return " ".join(f"{t}={v:.3f}" for t, v in pairs)

        print(f"\n Top tokens:")
        print(f" A: {top_five(a)}")
        print(f" B: {top_five(b)}")
459
+
460
+
461
+ # ══════════════════════════════════════════════════════════════════
462
+ # RUN
463
+ # ══════════════════════════════════════════════════════════════════
464
+
465
if __name__ == "__main__":
    # Demo run: load the published checkpoint, analyze a fixed vocabulary of
    # in-domain words, out-of-domain words and short phrases, then print the
    # report and three pairwise comparisons. Names are left at module level
    # so they remain inspectable after the script has run.
    from transformers import AutoModel, AutoTokenizer

    REPO_ID = "AbstractPhil/geolip-captionbert-8192"
    print("Loading model...")
    model = AutoModel.from_pretrained(REPO_ID, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(REPO_ID)

    analyzer = InternalAnalyzer(model, tokenizer)

    # Test words spanning known-domain and unknown-domain
    known_domain = ["girl", "woman", "dog", "sunset", "painting"]  # captions
    unknown_domain = ["subtraction", "multiplication", "prophetic",
                      "differential", "adjacency"]  # abstract
    phrases = [
        "a girl sitting near a window",
        "a dog playing on the beach",
        "the differential equation of motion",
    ]
    test_words = [*known_domain, *unknown_domain, *phrases]

    report = analyzer.analyze(test_words)
    analyzer.print_report(report)

    # Direct comparisons
    comparison_pairs = (
        ("girl", "woman"),
        ("girl", "subtraction"),
        ("a girl sitting near a window",
         "the differential equation of motion"),
    )
    for left, right in comparison_pairs:
        analyzer.compare(report, left, right)

    print(f"\n{'='*70}")
    print("DONE")
    print(f"{'='*70}")