lorenzovaquero committed on
Commit 2f74fc4 · verified · 1 Parent(s): 96cd4af

Add UniSITH source code: weight_extraction, concept_pool, comp, unisith

unimodal_sith/__pycache__/comp.cpython-312.pyc ADDED
Binary file (5.88 kB)
 
unimodal_sith/__pycache__/concept_pool.cpython-312.pyc ADDED
Binary file (8.44 kB)
 
unimodal_sith/__pycache__/unisith.cpython-312.pyc ADDED
Binary file (18.6 kB)
 
unimodal_sith/__pycache__/weight_extraction.cpython-312.pyc ADDED
Binary file (16.6 kB)
 
unimodal_sith/comp.py ADDED
@@ -0,0 +1,158 @@
"""
COMP: Coherent Orthogonal Matching Pursuit

Adapted from SITH (Vaquero et al., 2025), Algorithm 1.

Given a singular vector v_hat and a concept dictionary Gamma_hat, COMP finds
a sparse, semantically coherent combination of K concepts that best
approximates v_hat.

This implementation works with both text concept embeddings (original SITH)
and image concept embeddings (UniSITH).
"""

import torch
import numpy as np
from scipy.optimize import nnls
from typing import List, Tuple, Optional


def comp(
    v_hat: torch.Tensor,
    Gamma_hat: torch.Tensor,
    K: int = 5,
    lambda_coh: float = 0.3,
) -> Tuple[torch.Tensor, List[int]]:
    """
    Coherent Orthogonal Matching Pursuit (COMP).

    Extends Non-Negative Orthogonal Matching Pursuit (NNOMP) by incorporating
    a coherence term that encourages semantically coherent concept selections.

    Args:
        v_hat: [d] projected singular vector (L2-normalized)
        Gamma_hat: [C, d] concept embedding matrix (L2-normalized rows)
        K: Sparsity level (number of concepts to select)
        lambda_coh: Coherence weight (λ in the paper, default 0.3)

    Returns:
        c: [C] sparse coefficient vector (non-negative)
        support: List of K selected concept indices
    """
    C, d = Gamma_hat.shape
    device = v_hat.device

    # Move to CPU for scipy nnls
    v_hat_np = v_hat.cpu().numpy().astype(np.float64)
    Gamma_np = Gamma_hat.cpu().numpy().astype(np.float64)

    # Initialize
    r = v_hat_np.copy()  # Residual
    S = []               # Support set (selected concept indices)

    # The concept-concept similarities needed for the coherence term are
    # computed on the fly, since C can be very large.
    for k in range(K):
        # Step 1: Correlation of every concept with the current residual
        s_res = Gamma_np @ r  # [C]

        # Step 2: Coherence score = average similarity of each candidate
        # to the already-selected concepts
        s_coh = np.zeros(C)
        if len(S) > 0:
            S_embeddings = Gamma_np[S]                    # [|S|, d]
            sim_to_selected = Gamma_np @ S_embeddings.T   # [C, |S|]
            s_coh = sim_to_selected.mean(axis=1)          # [C]

        # Step 3: Combined score, masking already-selected concepts
        s_final = s_res + lambda_coh * s_coh
        for idx in S:
            s_final[idx] = -np.inf

        # Step 4: Greedy selection
        j_k = int(np.argmax(s_final))
        S.append(j_k)

        # Step 5: Non-negative least squares on the current support
        G_S = Gamma_np[S].T           # [d, |S|] - columns are selected concept embeddings
        c_S, _ = nnls(G_S, v_hat_np)  # min ||v_hat - G_S @ c_S||^2, c_S >= 0

        # Step 6: Update residual
        r = v_hat_np - G_S @ c_S

    # Construct final coefficient vector
    c = np.zeros(C)
    for i, j in enumerate(S):
        c[j] = c_S[i]

    return torch.tensor(c, dtype=torch.float32, device=device), S


def comp_batch(
    V_hat: torch.Tensor,
    Gamma_hat: torch.Tensor,
    K: int = 5,
    lambda_coh: float = 0.3,
) -> Tuple[torch.Tensor, List[List[int]]]:
    """
    Apply COMP to multiple singular vectors.

    Args:
        V_hat: [n, d] batch of projected singular vectors
        Gamma_hat: [C, d] concept embedding matrix
        K: Sparsity level
        lambda_coh: Coherence weight

    Returns:
        C_mat: [n, C] coefficient matrix
        supports: List of n support sets
    """
    n = V_hat.shape[0]
    C = Gamma_hat.shape[0]

    C_mat = torch.zeros(n, C, device=V_hat.device)
    supports = []

    for i in range(n):
        c_i, support_i = comp(V_hat[i], Gamma_hat, K=K, lambda_coh=lambda_coh)
        C_mat[i] = c_i
        supports.append(support_i)

    return C_mat, supports


def top_k_selection(
    v_hat: torch.Tensor,
    Gamma_hat: torch.Tensor,
    K: int = 5,
) -> Tuple[torch.Tensor, List[int]]:
    """
    Simple top-K selection baseline: pick the K most similar concepts.

    Args:
        v_hat: [d] projected singular vector
        Gamma_hat: [C, d] concept embedding matrix
        K: Number of concepts to select

    Returns:
        c: [C] coefficient vector (similarity scores for top-K, 0 elsewhere)
        support: List of K selected concept indices
    """
    similarities = Gamma_hat @ v_hat  # [C]
    top_k_vals, top_k_idx = torch.topk(similarities, K)

    c = torch.zeros(Gamma_hat.shape[0], device=v_hat.device)
    support = top_k_idx.tolist()
    for i, idx in enumerate(support):
        c[idx] = max(0.0, top_k_vals[i].item())  # Non-negative

    return c, support
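A minimal usage sketch for `comp` (hypothetical, not part of this commit), assuming an L2-normalized concept matrix and singular vector; the shapes and random inputs are illustrative only:

    import torch
    from unimodal_sith.comp import comp

    C, d = 1000, 512
    Gamma_hat = torch.nn.functional.normalize(torch.randn(C, d), dim=-1)  # concept dictionary
    v_hat = torch.nn.functional.normalize(torch.randn(d), dim=-1)         # projected singular vector

    coeffs, support = comp(v_hat, Gamma_hat, K=5, lambda_coh=0.3)
    # coeffs is a [C] non-negative vector with at most 5 non-zero entries;
    # support lists the indices of the selected concepts in selection order.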
unimodal_sith/concept_pool.py ADDED
@@ -0,0 +1,194 @@
"""
Visual Concept Pool for UniSITH.

Instead of text concepts (ConceptNet strings + CLIP text encoder),
we use captioned images as the concept pool.

Each concept is an image from a captioned dataset, and the corresponding
caption provides human-interpretable meaning.

The concept embeddings are computed by encoding each image through the
same unimodal vision model being analyzed.
"""

import torch
import torch.nn.functional as F
import numpy as np
from typing import Dict, List, Optional, Tuple
from PIL import Image
from tqdm import tqdm
import os
import json


class VisualConceptPool:
    """
    A pool of visual concepts, each represented by:
    - An image embedding (computed by the model being analyzed)
    - A caption (for human interpretability)
    - Optionally, the original image

    Analogous to Γ = {γ_1, ..., γ_C} in SITH, but each γ_i is an image
    embedding rather than a text embedding.
    """

    def __init__(
        self,
        embeddings: torch.Tensor,
        captions: List[str],
        image_ids: Optional[List[int]] = None,
        metadata: Optional[Dict] = None,
    ):
        """
        Args:
            embeddings: [C, d] tensor of L2-normalized concept embeddings
            captions: List of C caption strings
            image_ids: Optional list of C image IDs for retrieval
            metadata: Optional metadata dict
        """
        assert embeddings.shape[0] == len(captions), \
            f"Embeddings ({embeddings.shape[0]}) and captions ({len(captions)}) must match"

        self.embeddings = embeddings  # [C, d]
        self.captions = captions
        self.image_ids = image_ids
        self.metadata = metadata or {}
        self.num_concepts = len(captions)
        self.embed_dim = embeddings.shape[1]

    @classmethod
    def from_dataset(
        cls,
        dataset,
        model,
        processor,
        architecture: str,
        image_column: str = "image",
        caption_column: str = "caption",
        image_id_column: str = "image_id",
        batch_size: int = 64,
        max_concepts: Optional[int] = None,
        device: str = "cpu",
        cache_path: Optional[str] = None,
    ) -> "VisualConceptPool":
        """
        Build a concept pool from a HuggingFace dataset.

        Args:
            dataset: HF dataset with image and caption columns
            model: Vision model (HuggingFace transformers)
            processor: Image processor/transform
            architecture: Model architecture type ("dinov2", "clip", or "vit")
            image_column: Column name for images
            caption_column: Column name for captions
            image_id_column: Column name for image IDs
            batch_size: Batch size for encoding
            max_concepts: Max number of concepts to use
            device: Device for computation
            cache_path: If set, cache embeddings to/from this path
        """
        # Check for cached embeddings
        if cache_path and os.path.exists(cache_path):
            print(f"Loading cached concept pool from {cache_path}")
            return cls.load(cache_path)

        if max_concepts is not None:
            dataset = dataset.select(range(min(max_concepts, len(dataset))))

        captions = dataset[caption_column]
        image_ids = None
        if image_id_column in dataset.column_names:
            image_ids = dataset[image_id_column]

        # Encode all images
        model = model.to(device)
        model.eval()

        all_embeddings = []

        print(f"Encoding {len(dataset)} concept images...")
        for i in tqdm(range(0, len(dataset), batch_size)):
            batch_end = min(i + batch_size, len(dataset))
            batch_images = [dataset[j][image_column] for j in range(i, batch_end)]

            # Ensure images are RGB
            batch_images = [img.convert("RGB") if img.mode != "RGB" else img for img in batch_images]

            # Process images
            inputs = processor(images=batch_images, return_tensors="pt").to(device)

            with torch.no_grad():
                if architecture == "dinov2":
                    outputs = model(**inputs)
                    embeddings = outputs.last_hidden_state[:, 0, :]  # CLS token
                elif architecture == "clip":
                    # For CLIP, get the vision features
                    outputs = model.vision_model(**inputs)
                    # Pooled CLS token, already post-layernorm
                    pooled = outputs.pooler_output
                    # Apply the visual projection into the output embedding space
                    embeddings = model.visual_projection(pooled)
                elif architecture == "vit":
                    outputs = model(**inputs)
                    embeddings = outputs.last_hidden_state[:, 0, :]  # CLS token
                else:
                    raise ValueError(f"Unknown architecture: {architecture}")

            # L2 normalize
            embeddings = F.normalize(embeddings, dim=-1)
            all_embeddings.append(embeddings.cpu())

        embeddings = torch.cat(all_embeddings, dim=0)

        pool = cls(
            embeddings=embeddings,
            captions=captions,
            image_ids=image_ids,
            metadata={
                "architecture": architecture,
                "num_concepts": len(captions),
                "embed_dim": embeddings.shape[1],
            },
        )

        # Cache if requested
        if cache_path:
            pool.save(cache_path)

        return pool

    def get_centered_embeddings(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Return the embeddings after mean-centering and re-normalization,
        together with the mean that was removed.

        This is analogous to the modality gap correction in SITH (Eq. 18-19),
        but for unimodal models we center within the image embedding distribution
        so that the concept embeddings are centered around the origin.

        This is important for matching with singular vectors, which themselves
        are zero-centered directions.
        """
        mu = self.embeddings.mean(dim=0, keepdim=True)  # [1, d]
        centered = self.embeddings - mu                 # [C, d]
        centered = F.normalize(centered, dim=-1)        # Re-normalize
        return centered, mu

    def save(self, path: str):
        """Save concept pool to disk."""
        os.makedirs(os.path.dirname(path) if os.path.dirname(path) else ".", exist_ok=True)
        torch.save({
            "embeddings": self.embeddings,
            "captions": self.captions,
            "image_ids": self.image_ids,
            "metadata": self.metadata,
        }, path)
        print(f"Saved concept pool to {path}")

    @classmethod
    def load(cls, path: str) -> "VisualConceptPool":
        """Load concept pool from disk."""
        data = torch.load(path, weights_only=False)
        return cls(
            embeddings=data["embeddings"],
            captions=data["captions"],
            image_ids=data.get("image_ids"),
            metadata=data.get("metadata", {}),
        )
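A hypothetical build-and-cache sketch for `VisualConceptPool.from_dataset` (not part of this commit); the dataset call, cache path, and column names are placeholder assumptions:

    import torch
    from datasets import load_dataset
    from transformers import AutoModel, AutoImageProcessor
    from unimodal_sith.concept_pool import VisualConceptPool

    model = AutoModel.from_pretrained("facebook/dinov2-base")
    processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
    dataset = load_dataset(...)  # any captioned dataset with "image" and "caption" columns

    pool = VisualConceptPool.from_dataset(
        dataset, model, processor,
        architecture="dinov2",
        max_concepts=5000,
        device="cuda" if torch.cuda.is_available() else "cpu",
        cache_path="cache/concept_pool_dinov2.pt",  # assumed path
    )
    centered, mu = pool.get_centered_embeddings()  # [C, d] centered concepts, [1, d] mean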
unimodal_sith/unisith.py ADDED
@@ -0,0 +1,450 @@
"""
UniSITH: Unimodal Semantic Inspection of Transformer Heads

Main analysis class that orchestrates:
1. Weight extraction (W_VO matrices from attention heads)
2. SVD decomposition (finding principal directions)
3. Projection to feature space
4. Concept attribution via COMP (matching to visual concepts)
5. Model editing (amplifying/suppressing concepts)

Key differences from the original SITH:
- Works with ANY ViT (not just CLIP)
- Uses captioned images as the concept pool (not text from ConceptNet)
- Captions provide human interpretability
- No cross-modal projection is needed (the same model encodes both the weights and the concepts)
"""

import torch
import torch.nn.functional as F
import numpy as np
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass, field
import json
import os

from .weight_extraction import WeightExtractor
from .concept_pool import VisualConceptPool
from .comp import comp, comp_batch, top_k_selection


@dataclass
class SingularVectorInterpretation:
    """Interpretation of a single singular vector."""
    layer_idx: int
    head_idx: int
    sv_idx: int
    singular_value: float
    concepts: List[str]        # Captions of matched concepts
    concept_indices: List[int] # Indices into concept pool
    coefficients: List[float]  # COMP coefficients
    fidelity: float            # Cosine similarity between original and reconstruction
    image_ids: Optional[List[int]] = None  # IDs for retrieving original images

    def to_dict(self) -> Dict:
        return {
            "layer": self.layer_idx,
            "head": self.head_idx,
            "sv_index": self.sv_idx,
            "singular_value": self.singular_value,
            "concepts": [
                {"caption": c, "coefficient": w, "concept_idx": idx}
                for c, w, idx in zip(self.concepts, self.coefficients, self.concept_indices)
            ],
            "fidelity": self.fidelity,
            "image_ids": self.image_ids,
        }

    def __repr__(self) -> str:
        lines = [f"Layer {self.layer_idx}, Head {self.head_idx}, SV {self.sv_idx} "
                 f"(σ={self.singular_value:.4f}, fidelity={self.fidelity:.4f})"]
        for c, w in zip(self.concepts, self.coefficients):
            lines.append(f"  [{w:.4f}] {c}")
        return "\n".join(lines)


@dataclass
class HeadInterpretation:
    """Full interpretation of an attention head."""
    layer_idx: int
    head_idx: int
    singular_vectors: List[SingularVectorInterpretation]

    def to_dict(self) -> Dict:
        return {
            "layer": self.layer_idx,
            "head": self.head_idx,
            "singular_vectors": [sv.to_dict() for sv in self.singular_vectors],
        }

    def __repr__(self) -> str:
        lines = [f"=== Layer {self.layer_idx}, Head {self.head_idx} ==="]
        for sv in self.singular_vectors:
            lines.append(str(sv))
            lines.append("")
        return "\n".join(lines)


class UniSITH:
    """
    Unimodal Semantic Inspection of Transformer Heads.

    Analyzes the internal representations of ViT attention heads by:
    1. Decomposing W_VO matrices via SVD
    2. Projecting singular vectors to the model's feature space
    3. Attributing visual concepts from a captioned image pool
    """

    def __init__(
        self,
        model: torch.nn.Module,
        architecture: str,
        n_heads: int,
        d_model: int,
        concept_pool: VisualConceptPool,
        device: str = "cpu",
    ):
        """
        Args:
            model: Vision transformer model
            architecture: One of "dinov2", "clip", "vit"
            n_heads: Number of attention heads
            d_model: Hidden dimension
            concept_pool: Visual concept pool with embeddings and captions
            device: Computation device
        """
        self.model = model
        self.architecture = architecture
        self.device = device
        self.concept_pool = concept_pool

        self.extractor = WeightExtractor(model, architecture, n_heads, d_model)
        self.n_heads = n_heads
        self.d_model = d_model

        # Precompute centered concept embeddings
        self.centered_concepts, self.concept_mean = concept_pool.get_centered_embeddings()
        self.centered_concepts = self.centered_concepts.to(device)
        self.concept_mean = self.concept_mean.to(device)

    def analyze_head(
        self,
        layer_idx: int,
        head_idx: int,
        n_singular_vectors: int = 5,
        K: int = 5,
        lambda_coh: float = 0.3,
        method: str = "comp",
    ) -> HeadInterpretation:
        """
        Analyze a single attention head: decompose its W_VO matrix and
        interpret the top singular vectors.

        Args:
            layer_idx: Transformer layer index
            head_idx: Attention head index
            n_singular_vectors: Number of top singular vectors to interpret
            K: Number of concepts per singular vector
            lambda_coh: COMP coherence weight
            method: "comp" or "top_k"

        Returns:
            HeadInterpretation with concept attributions for each singular vector
        """
        # Step 1: Extract W_VO and decompose via SVD
        W_VO_all = self.extractor.compute_WVO(layer_idx, fold_ln=True, project_ones=True)
        W_VO_h = W_VO_all[head_idx]  # [D, D]

        U, sigma, Vt = self.extractor.svd_decompose(W_VO_h, top_k=n_singular_vectors)
        # U: [D, n_sv], sigma: [n_sv], Vt: [n_sv, D]

        # Step 2: Project right singular vectors to feature space
        V_projected = self.extractor.project_to_feature_space(Vt)  # [n_sv, d_out]

        # Step 3: Center the projected vectors (analogous to modality gap correction)
        V_centered = V_projected - self.concept_mean
        V_centered = F.normalize(V_centered, dim=-1)

        # Step 4: Attribute concepts via COMP (or top-k)
        sv_interpretations = []
        for i in range(n_singular_vectors):
            v_hat = V_centered[i]  # [d_out]

            if method == "comp":
                coeffs, support = comp(
                    v_hat, self.centered_concepts, K=K, lambda_coh=lambda_coh
                )
            elif method == "top_k":
                coeffs, support = top_k_selection(
                    v_hat, self.centered_concepts, K=K
                )
            else:
                raise ValueError(f"Unknown method: {method}")

            # Extract concept captions and coefficients
            concept_captions = [self.concept_pool.captions[idx] for idx in support]
            concept_coeffs = [coeffs[idx].item() for idx in support]
            concept_image_ids = None
            if self.concept_pool.image_ids is not None:
                concept_image_ids = [self.concept_pool.image_ids[idx] for idx in support]

            # Compute fidelity: cosine similarity between original and reconstruction
            reconstruction = torch.zeros_like(v_hat)
            for idx, coeff in zip(support, concept_coeffs):
                reconstruction += coeff * self.centered_concepts[idx]
            fidelity = F.cosine_similarity(
                v_hat.unsqueeze(0), reconstruction.unsqueeze(0)
            ).item()

            sv_interpretations.append(SingularVectorInterpretation(
                layer_idx=layer_idx,
                head_idx=head_idx,
                sv_idx=i,
                singular_value=sigma[i].item(),
                concepts=concept_captions,
                concept_indices=support,
                coefficients=concept_coeffs,
                fidelity=fidelity,
                image_ids=concept_image_ids,
            ))

        return HeadInterpretation(
            layer_idx=layer_idx,
            head_idx=head_idx,
            singular_vectors=sv_interpretations,
        )

    def analyze_layer(
        self,
        layer_idx: int,
        n_singular_vectors: int = 5,
        K: int = 5,
        lambda_coh: float = 0.3,
        method: str = "comp",
    ) -> List[HeadInterpretation]:
        """Analyze all heads in a layer."""
        results = []
        for h in range(self.n_heads):
            print(f"  Analyzing head {h}/{self.n_heads}...")
            result = self.analyze_head(
                layer_idx, h, n_singular_vectors, K, lambda_coh, method
            )
            results.append(result)
        return results

    def analyze_model(
        self,
        layers: Optional[List[int]] = None,
        n_singular_vectors: int = 5,
        K: int = 5,
        lambda_coh: float = 0.3,
        method: str = "comp",
    ) -> Dict[int, List[HeadInterpretation]]:
        """
        Analyze multiple layers of the model.

        Args:
            layers: List of layer indices. If None, analyzes the last 4 layers.
            n_singular_vectors: Number of top singular vectors per head
            K: Concepts per singular vector
            lambda_coh: COMP coherence weight
            method: "comp" or "top_k"

        Returns:
            Dict mapping layer_idx -> list of HeadInterpretations
        """
        if layers is None:
            n_layers = self.extractor._get_num_layers()
            layers = list(range(max(0, n_layers - 4), n_layers))

        results = {}
        for layer_idx in layers:
            print(f"Analyzing layer {layer_idx}...")
            results[layer_idx] = self.analyze_layer(
                layer_idx, n_singular_vectors, K, lambda_coh, method
            )

        return results

    def edit_model(
        self,
        layer_idx: int,
        head_idx: int,
        sv_indices: List[int],
        scale_factors: List[float],
    ) -> None:
        """
        Edit the model by scaling specific singular values.

        This enables:
        - Suppressing concepts (scale -> 0): remove spurious features
        - Amplifying concepts (scale > 1): enhance task-relevant features

        Args:
            layer_idx: Layer to edit
            head_idx: Head to edit
            sv_indices: Indices of singular vectors to modify
            scale_factors: Scaling factor for each (0 = suppress, >1 = amplify)
        """
        # Get the original (non-folded) W_VO
        W_VO_all = self.extractor.compute_WVO(layer_idx, fold_ln=False, project_ones=False)
        W_VO_h = W_VO_all[head_idx]

        # SVD decompose
        U, sigma, Vt = torch.linalg.svd(W_VO_h, full_matrices=False)

        # Scale selected singular values
        for sv_idx, scale in zip(sv_indices, scale_factors):
            sigma[sv_idx] *= scale

        # Reconstruct
        W_VO_edited = U @ torch.diag(sigma) @ Vt

        # Write back to the model.
        # W_VO_h = W_V_h^T @ W_O_h^T, so the edit must be pushed into W_V and W_O.
        # W_O_h only has rank d_h, so it cannot be inverted in D x D space to solve
        # directly for a new W_V_h. Instead, W_VO_edited (rank <= d_h) is re-factorized
        # into a new per-head W_V_h / W_O_h pair via a truncated SVD.
        self._write_WVO_to_model(layer_idx, head_idx, W_VO_edited)

    def _write_WVO_to_model(
        self,
        layer_idx: int,
        head_idx: int,
        W_VO_edited: torch.Tensor,
    ):
        """
        Write an edited W_VO back to the model weights.

        Since W_VO = W_V_h^T @ W_O_h^T has rank at most d_h, we can use SVD
        to factorize W_VO_edited into new W_V_h and W_O_h:

            W_VO_edited = U_e @ S_e @ V_e^T
            W_V_h_new^T = U_e[:, :d_h] @ sqrt(S_e[:d_h])
            W_O_h_new^T = sqrt(S_e[:d_h]) @ V_e[:d_h, :]
        """
        d_h = self.extractor.head_dim

        # SVD of the edited W_VO
        U_e, S_e, Vt_e = torch.linalg.svd(W_VO_edited, full_matrices=False)

        # Keep the top d_h components
        sqrt_S = torch.sqrt(S_e[:d_h])

        # W_V_h_new^T = U_e[:, :d_h] @ diag(sqrt_S)  => [D, d_h]
        # so W_V_h_new = diag(sqrt_S) @ U_e[:, :d_h]^T => [d_h, D]
        W_V_h_new = sqrt_S.unsqueeze(1) * U_e[:, :d_h].T  # [d_h, D]

        # W_O_h_new^T = diag(sqrt_S) @ Vt_e[:d_h, :]  => [d_h, D]
        # so W_O_h_new = Vt_e[:d_h, :]^T @ diag(sqrt_S) => [D, d_h]
        W_O_h_new = Vt_e[:d_h, :].T * sqrt_S.unsqueeze(0)  # [D, d_h]

        # Write the new per-head slices back into the model weights
        _, _, W_V = self.extractor._get_qkv_weights(layer_idx)
        W_O = self.extractor._get_output_weight(layer_idx)

        h = head_idx

        # W_V is [d_model, d_model]; head h occupies rows [h*d_h : (h+1)*d_h]
        W_V[h * d_h : (h + 1) * d_h, :] = W_V_h_new

        # W_O is [d_model, d_model]; head h occupies columns [h*d_h : (h+1)*d_h]
        W_O[:, h * d_h : (h + 1) * d_h] = W_O_h_new

    def find_concept_heads(
        self,
        target_concepts: List[str],
        concept_embeddings: torch.Tensor,
        layers: Optional[List[int]] = None,
        n_singular_vectors: int = 10,
        K: int = 5,
        lambda_coh: float = 0.3,
        threshold: float = 0.3,
    ) -> List[Dict]:
        """
        Find attention heads that encode specific concepts.

        Useful for targeted model editing: find which heads encode
        "background" features, "texture" features, etc.

        Args:
            target_concepts: List of target concept descriptions
            concept_embeddings: [n_targets, d] embeddings of target concepts
            layers: Layers to search
            n_singular_vectors: SVs per head to check
            K: Concepts per SV
            lambda_coh: COMP coherence weight
            threshold: Minimum similarity to consider a match

        Returns:
            List of dicts with head locations and matching info
        """
        results = self.analyze_model(
            layers=layers,
            n_singular_vectors=n_singular_vectors,
            K=K,
            lambda_coh=lambda_coh,
        )

        matches = []
        concept_embeddings = F.normalize(concept_embeddings.to(self.device), dim=-1)

        for layer_idx, heads in results.items():
            for head_interp in heads:
                for sv_interp in head_interp.singular_vectors:
                    # Check whether any of the attributed concepts match the targets
                    for ci, concept_idx in enumerate(sv_interp.concept_indices):
                        concept_emb = self.centered_concepts[concept_idx]
                        sims = (concept_embeddings @ concept_emb).tolist()
                        max_sim = max(sims)
                        if max_sim > threshold:
                            matches.append({
                                "layer": layer_idx,
                                "head": head_interp.head_idx,
                                "sv_index": sv_interp.sv_idx,
                                "concept": sv_interp.concepts[ci],
                                "coefficient": sv_interp.coefficients[ci],
                                "target_similarity": max_sim,
                                "singular_value": sv_interp.singular_value,
                            })

        # Sort by relevance (target_similarity * singular_value * coefficient)
        matches.sort(
            key=lambda x: x["target_similarity"] * x["singular_value"] * x["coefficient"],
            reverse=True,
        )

        return matches

    @staticmethod
    def save_results(
        results: Dict[int, List[HeadInterpretation]],
        path: str,
    ):
        """Save analysis results to JSON."""
        serialized = {}
        for layer_idx, heads in results.items():
            serialized[str(layer_idx)] = [h.to_dict() for h in heads]

        os.makedirs(os.path.dirname(path) if os.path.dirname(path) else ".", exist_ok=True)
        with open(path, "w") as f:
            json.dump(serialized, f, indent=2)
        print(f"Results saved to {path}")

    @staticmethod
    def load_results(path: str) -> Dict:
        """Load analysis results from JSON."""
        with open(path) as f:
            return json.load(f)
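A hypothetical end-to-end sketch of the analyze-then-edit flow (not part of this commit), assuming a DINOv2-base backbone (12 heads, hidden size 768) and a previously cached concept pool at an assumed path:

    from transformers import AutoModel
    from unimodal_sith.concept_pool import VisualConceptPool
    from unimodal_sith.unisith import UniSITH

    model = AutoModel.from_pretrained("facebook/dinov2-base")
    pool = VisualConceptPool.load("cache/concept_pool_dinov2.pt")  # assumed cache path

    sith = UniSITH(
        model=model,
        architecture="dinov2",
        n_heads=12,
        d_model=768,
        concept_pool=pool,
    )

    # Interpret one head: top-5 singular vectors, 5 concepts each via COMP
    interp = sith.analyze_head(layer_idx=11, head_idx=3, n_singular_vectors=5, K=5)
    print(interp)

    # Suppress whatever the first singular vector of that head encodes
    sith.edit_model(layer_idx=11, head_idx=3, sv_indices=[0], scale_factors=[0.0])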
unimodal_sith/weight_extraction.py ADDED
@@ -0,0 +1,378 @@
"""
Weight extraction utilities for various ViT architectures.

Supports:
- DINOv2 (facebook/dinov2-*)
- CLIP ViT (openai/clip-vit-* via HuggingFace transformers)
- Any HuggingFace ViT (google/vit-*)

For each architecture, extracts:
- W_V (value projection) and W_O (output projection) per attention head
- W_VO = W_V^T @ W_O^T (the value-output matrix, as in SITH)
- LayerNorm parameters for folding
- Final projection matrix W_p (if present)
"""

import torch
import torch.nn as nn
from typing import Dict, List, Optional, Tuple


def fold_layernorm_into_weights(
    W: torch.Tensor,
    ln_weight: torch.Tensor,
    ln_bias: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Fold LayerNorm affine parameters into a weight matrix.

    LN(x) = (x - mean) / std * w + b, whose affine part is x * w + b (element-wise).
    For a PyTorch Linear layer, y = x @ W^T, so applying the affine part first gives

        (x * w + b) @ W^T = x @ (W * w)^T + b @ W^T

    i.e. the LN weight scales the input columns of W, and the LN bias is absorbed
    into an output bias term.

    Args:
        W: Weight matrix [out_dim, in_dim] (applied as x @ W^T)
        ln_weight: LayerNorm weight [in_dim]
        ln_bias: LayerNorm bias [in_dim]

    Returns:
        W_folded: [out_dim, in_dim]
        b_folded: [out_dim]
    """
    W_folded = W * ln_weight.unsqueeze(0)  # [out, in] * [1, in]: scale each input column
    b_folded = ln_bias @ W.t()             # [in] @ [in, out] = [out]

    return W_folded, b_folded


def project_out_ones(W: torch.Tensor) -> torch.Tensor:
    """
    Project the columns of a weight matrix onto the subspace orthogonal to the
    all-ones direction. This accounts for the centering operation of LayerNorm.

    For a matrix W [D, D] applied on the input side (x @ W), centering the input is
    equivalent to W_proj = W - (1/D) * ones @ ones^T @ W, i.e. subtracting each
    column's mean from that column.
    """
    D = W.shape[0]
    col_means = W.mean(dim=0, keepdim=True)  # [1, D]
    W_proj = W - col_means
    return W_proj


class WeightExtractor:
    """
    Extracts and processes attention head weights for SITH analysis.
    Architecture-agnostic: supports DINOv2, CLIP ViT, standard ViT.
    """

    SUPPORTED_ARCHITECTURES = ["dinov2", "clip", "vit"]

    def __init__(self, model: nn.Module, architecture: str, n_heads: int, d_model: int):
        """
        Args:
            model: The loaded model (HuggingFace transformers model)
            architecture: One of "dinov2", "clip", "vit"
            n_heads: Number of attention heads
            d_model: Hidden dimension
        """
        assert architecture in self.SUPPORTED_ARCHITECTURES, \
            f"Unsupported architecture: {architecture}. Use one of {self.SUPPORTED_ARCHITECTURES}"

        self.model = model
        self.architecture = architecture
        self.n_heads = n_heads
        self.d_model = d_model
        self.head_dim = d_model // n_heads

    def _get_layer(self, layer_idx: int):
        """Get the transformer layer by index."""
        if self.architecture == "dinov2":
            return self.model.encoder.layer[layer_idx]
        elif self.architecture == "clip":
            return self.model.vision_model.encoder.layers[layer_idx]
        elif self.architecture == "vit":
            # AutoModel for ViT doesn't have the .vit prefix
            if hasattr(self.model, 'vit'):
                return self.model.vit.encoder.layer[layer_idx]
            else:
                return self.model.encoder.layer[layer_idx]
        else:
            raise ValueError(f"Unknown architecture: {self.architecture}")

    def _get_num_layers(self) -> int:
        """Get the total number of transformer layers."""
        if self.architecture == "dinov2":
            return len(self.model.encoder.layer)
        elif self.architecture == "clip":
            return len(self.model.vision_model.encoder.layers)
        elif self.architecture == "vit":
            if hasattr(self.model, 'vit'):
                return len(self.model.vit.encoder.layer)
            else:
                return len(self.model.encoder.layer)
        else:
            raise ValueError(f"Unknown architecture: {self.architecture}")

    def _get_qkv_weights(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Extract the Q, K, V weight matrices from a layer."""
        layer = self._get_layer(layer_idx)

        if self.architecture == "dinov2":
            attn = layer.attention.attention
            W_Q = attn.query.weight.data  # [d_model, d_model]
            W_K = attn.key.weight.data
            W_V = attn.value.weight.data
        elif self.architecture == "clip":
            attn = layer.self_attn
            W_Q = attn.q_proj.weight.data
            W_K = attn.k_proj.weight.data
            W_V = attn.v_proj.weight.data
        elif self.architecture == "vit":
            attn = layer.attention.attention
            W_Q = attn.query.weight.data
            W_K = attn.key.weight.data
            W_V = attn.value.weight.data

        return W_Q, W_K, W_V

    def _get_output_weight(self, layer_idx: int) -> torch.Tensor:
        """Extract the output projection weight matrix."""
        layer = self._get_layer(layer_idx)

        if self.architecture == "dinov2":
            return layer.attention.output.dense.weight.data  # [d_model, d_model]
        elif self.architecture == "clip":
            return layer.self_attn.out_proj.weight.data
        elif self.architecture == "vit":
            return layer.attention.output.dense.weight.data

    def _get_pre_attn_layernorm(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get the LayerNorm weight and bias that precede the attention block."""
        layer = self._get_layer(layer_idx)

        if self.architecture == "dinov2":
            ln = layer.norm1
        elif self.architecture == "clip":
            ln = layer.layer_norm1
        elif self.architecture == "vit":
            ln = layer.layernorm_before

        return ln.weight.data, ln.bias.data

    def _get_final_layernorm(self) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
        """Get the final LayerNorm (applied before the projection, if present)."""
        if self.architecture == "dinov2":
            ln = self.model.layernorm
        elif self.architecture == "clip":
            ln = self.model.vision_model.post_layernorm
        elif self.architecture == "vit":
            if hasattr(self.model, 'vit'):
                ln = self.model.vit.layernorm
            else:
                ln = self.model.layernorm

        return ln.weight.data, ln.bias.data

    def _get_projection_matrix(self) -> Optional[torch.Tensor]:
        """Get the final projection matrix W_p (maps hidden dim to output dim)."""
        if self.architecture == "clip":
            # CLIP has a visual projection with weight [proj_dim, d_model],
            # applied as: features = cls_token @ W_p^T
            try:
                W_p = self.model.visual_projection.weight.data  # [proj_dim, d_model]
                return W_p.t()  # Return as [d_model, proj_dim]
            except AttributeError:
                return None
        elif self.architecture == "dinov2":
            # DINOv2 has no projection matrix
            return None
        elif self.architecture == "vit":
            return None

    def _get_layerscale(self, layer_idx: int) -> Optional[torch.Tensor]:
        """Get the LayerScale parameter (DINOv2 specific)."""
        if self.architecture == "dinov2":
            layer = self._get_layer(layer_idx)
            try:
                return layer.layer_scale1.lambda1.data  # [d_model]
            except AttributeError:
                return None
        return None

    def compute_WVO(
        self,
        layer_idx: int,
        fold_ln: bool = True,
        project_ones: bool = True,
    ) -> torch.Tensor:
        """
        Compute the Value-Output (VO) weight matrix for all heads in a layer.

        Following Eq. (4) of the paper, MHA(X) = sum_h A^h @ X @ W_VO^h with
        W_VO^h = W_V^h @ W_O^h, where W_V^h is [D, d_h] and W_O^h is [d_h, D].

        In PyTorch, Linear(in, out) stores its weight as [out, in] and applies x @ W^T:
        - Head h's value weights are rows [h*d_h : (h+1)*d_h] of W_V.weight, so
          W_V_h is [d_h, D] and is applied as X @ W_V_h^T.
        - Head h's output weights are columns [h*d_h : (h+1)*d_h] of W_O.weight, so
          W_O_h is [D, d_h] and is applied as h_out_h @ W_O_h^T.

        The per-head map is therefore output_h = A_h @ X @ W_V_h^T @ W_O_h^T, and in
        this storage convention

            W_VO_h = W_V_h^T @ W_O_h^T = [D, d_h] @ [d_h, D] = [D, D]  (rank d_h),

        which matches the paper's W_V^h @ W_O^h up to the transposed weight layout.

        Returns:
            W_VO: [n_heads, d_model, d_model]
        """
        _, _, W_V = self._get_qkv_weights(layer_idx)
        W_O = self._get_output_weight(layer_idx)

        # Optionally fold the pre-attention LayerNorm weight into W_V.
        if fold_ln:
            ln_weight, ln_bias = self._get_pre_attn_layernorm(layer_idx)
            # LN acts element-wise on the input before attention:
            # (x * w) @ W_V^T = x @ (W_V * w)^T, so scale the input columns of W_V.
            W_V = W_V * ln_weight.unsqueeze(0)  # [d_model, d_model] * [1, d_model]

        # Fold LayerScale if present (DINOv2)
        ls = self._get_layerscale(layer_idx)
        if ls is not None:
            # LayerScale multiplies the attention output element-wise by ls,
            # so the effective output weight is diag(ls) @ W_O (scale the rows of W_O).
            W_O = W_O * ls.unsqueeze(1)  # [D, D] * [D, 1] = [D, D]

        # Split into per-head matrices
        # W_V: [d_model, d_model] -> W_V_h: [d_h, d_model] for head h
        W_V_per_head = W_V.view(self.n_heads, self.head_dim, self.d_model)  # [H, d_h, D]

        # W_O: [d_model, d_model] -> W_O_h: [d_model, d_h] for head h
        W_O_per_head = W_O.view(self.d_model, self.n_heads, self.head_dim)  # [D, H, d_h]
        W_O_per_head = W_O_per_head.permute(1, 0, 2)                        # [H, D, d_h]

        # W_VO_h = W_V_h^T @ W_O_h^T = [D, d_h] @ [d_h, D] = [D, D]
        W_VO = torch.bmm(
            W_V_per_head.transpose(1, 2),  # [H, D, d_h]
            W_O_per_head.transpose(1, 2),  # [H, d_h, D]
        )  # [H, D, D]

        # Project out the all-ones direction (the centering part of LN)
        if project_ones:
            for h in range(self.n_heads):
                W_VO[h] = project_out_ones(W_VO[h])

        return W_VO

    def svd_decompose(
        self,
        W_VO_h: torch.Tensor,
        top_k: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Decompose a per-head W_VO matrix via SVD.

        W_VO = U @ diag(sigma) @ V^T

        Args:
            W_VO_h: [d_model, d_model] VO matrix for a single head
            top_k: If set, only return the top-k singular vectors

        Returns:
            U: [d_model, r] left singular vectors (reading directions)
            sigma: [r] singular values
            Vt: [r, d_model] right singular vectors (writing directions)
        """
        U, sigma, Vt = torch.linalg.svd(W_VO_h, full_matrices=False)

        if top_k is not None:
            U = U[:, :top_k]
            sigma = sigma[:top_k]
            Vt = Vt[:top_k, :]

        return U, sigma, Vt

    def project_to_feature_space(
        self,
        vectors: torch.Tensor,
    ) -> torch.Tensor:
        """
        Project singular vectors from the residual stream to the model's output feature space.

        For CLIP: apply the final LN, then the W_p projection.
        For DINOv2/ViT: apply the final LN (there is no projection matrix).

        Args:
            vectors: [n, d_model] singular vectors in residual stream space

        Returns:
            projected: [n, d_out] vectors in the output feature space, L2-normalized
        """
        # Get the final LayerNorm
        ln_w, ln_b = self._get_final_layernorm()

        # Apply the LN affine transformation (without the data-dependent normalization):
        # these are abstract directions rather than activations, so only the affine part applies.
        vectors_ln = vectors * ln_w.unsqueeze(0) + ln_b.unsqueeze(0)

        # Apply the projection if present
        W_p = self._get_projection_matrix()
        if W_p is not None:
            # W_p is [d_model, proj_dim]
            vectors_proj = vectors_ln @ W_p  # [n, proj_dim]
        else:
            vectors_proj = vectors_ln

        # L2 normalize
        vectors_proj = torch.nn.functional.normalize(vectors_proj, dim=-1)

        return vectors_proj
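A minimal shape-check sketch for `WeightExtractor` (hypothetical, not part of this commit), again assuming a DINOv2-base backbone (12 heads, hidden size 768):

    from transformers import AutoModel
    from unimodal_sith.weight_extraction import WeightExtractor

    model = AutoModel.from_pretrained("facebook/dinov2-base")
    extractor = WeightExtractor(model, architecture="dinov2", n_heads=12, d_model=768)

    # Per-head VO matrices for the last layer, with LN and LayerScale folded in
    W_VO = extractor.compute_WVO(layer_idx=11, fold_ln=True, project_ones=True)
    print(W_VO.shape)  # torch.Size([12, 768, 768]); each head's matrix has rank <= 64

    # Top-5 singular values / directions of head 0
    U, sigma, Vt = extractor.svd_decompose(W_VO[0], top_k=5)
    print(sigma)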