Add UniSITH source code: weight_extraction, concept_pool, comp, unisith
- unimodal_sith/__pycache__/comp.cpython-312.pyc +0 -0
- unimodal_sith/__pycache__/concept_pool.cpython-312.pyc +0 -0
- unimodal_sith/__pycache__/unisith.cpython-312.pyc +0 -0
- unimodal_sith/__pycache__/weight_extraction.cpython-312.pyc +0 -0
- unimodal_sith/comp.py +158 -0
- unimodal_sith/concept_pool.py +194 -0
- unimodal_sith/unisith.py +450 -0
- unimodal_sith/weight_extraction.py +378 -0
unimodal_sith/__pycache__/comp.cpython-312.pyc
ADDED
Binary file (5.88 kB)

unimodal_sith/__pycache__/concept_pool.cpython-312.pyc
ADDED
Binary file (8.44 kB)

unimodal_sith/__pycache__/unisith.cpython-312.pyc
ADDED
Binary file (18.6 kB)

unimodal_sith/__pycache__/weight_extraction.cpython-312.pyc
ADDED
Binary file (16.6 kB)
unimodal_sith/comp.py
ADDED
@@ -0,0 +1,158 @@
"""
COMP: Coherent Orthogonal Matching Pursuit

Adapted from SITH (Vaquero et al., 2025), Algorithm 1.

Given a singular vector v_hat and a concept dictionary Gamma_hat, COMP finds
a sparse, semantically coherent combination of K concepts that best
approximates v_hat.

This implementation works with both text concept embeddings (original SITH)
and image concept embeddings (UniSITH).
"""

import torch
import numpy as np
from scipy.optimize import nnls
from typing import List, Tuple, Optional


def comp(
    v_hat: torch.Tensor,
    Gamma_hat: torch.Tensor,
    K: int = 5,
    lambda_coh: float = 0.3,
) -> Tuple[torch.Tensor, List[int]]:
    """
    Coherent Orthogonal Matching Pursuit (COMP).

    Extends Non-Negative Orthogonal Matching Pursuit (NNOMP) by incorporating
    a coherence term that encourages semantically coherent concept selections.

    Args:
        v_hat: [d] projected singular vector (L2-normalized)
        Gamma_hat: [C, d] concept embedding matrix (L2-normalized rows)
        K: Sparsity level (number of concepts to select)
        lambda_coh: Coherence weight (λ in the paper, default 0.3)

    Returns:
        c: [C] sparse coefficient vector (non-negative)
        support: List of K selected concept indices
    """
    C, d = Gamma_hat.shape
    device = v_hat.device

    # Move to CPU for scipy nnls
    v_hat_np = v_hat.cpu().numpy().astype(np.float64)
    Gamma_np = Gamma_hat.cpu().numpy().astype(np.float64)

    # Initialize
    r = v_hat_np.copy()  # Residual
    S = []  # Support set (selected concept indices)
    c = np.zeros(C)

    # Concept-concept similarities (for coherence) are computed on the fly
    # inside the loop rather than precomputed, since C can be very large.

    for k in range(K):
        # Step 1: Compute correlations with residual
        s_res = Gamma_np @ r  # [C]

        # Step 2: Compute coherence scores
        s_coh = np.zeros(C)
        if len(S) > 0:
            # Average similarity of each candidate to already-selected concepts
            S_embeddings = Gamma_np[S]  # [|S|, d]
            # Similarity of all concepts to selected ones
            sim_to_selected = Gamma_np @ S_embeddings.T  # [C, |S|]
            s_coh = sim_to_selected.mean(axis=1)  # [C]
            # Exclude already-selected concepts
            for idx in S:
                s_coh[idx] = -np.inf

        # Step 3: Combined score
        s_final = s_res + lambda_coh * s_coh

        # Mask already selected concepts
        for idx in S:
            s_final[idx] = -np.inf

        # Step 4: Greedy selection
        j_k = int(np.argmax(s_final))
        S.append(j_k)

        # Step 5: Non-negative least squares on current support
        G_S = Gamma_np[S].T  # [d, |S|] - columns are selected concept embeddings
        c_S, _ = nnls(G_S, v_hat_np)  # min ||v_hat - G_S @ c_S||^2, c_S >= 0

        # Step 6: Update residual
        r = v_hat_np - G_S @ c_S

    # Construct final coefficient vector
    c = np.zeros(C)
    for i, j in enumerate(S):
        c[j] = c_S[i]

    return torch.tensor(c, dtype=torch.float32, device=device), S


def comp_batch(
    V_hat: torch.Tensor,
    Gamma_hat: torch.Tensor,
    K: int = 5,
    lambda_coh: float = 0.3,
) -> Tuple[torch.Tensor, List[List[int]]]:
    """
    Apply COMP to multiple singular vectors.

    Args:
        V_hat: [n, d] batch of projected singular vectors
        Gamma_hat: [C, d] concept embedding matrix
        K: Sparsity level
        lambda_coh: Coherence weight

    Returns:
        C_mat: [n, C] coefficient matrix
        supports: List of n support sets
    """
    n = V_hat.shape[0]
    C = Gamma_hat.shape[0]

    C_mat = torch.zeros(n, C, device=V_hat.device)
    supports = []

    for i in range(n):
        c_i, support_i = comp(V_hat[i], Gamma_hat, K=K, lambda_coh=lambda_coh)
        C_mat[i] = c_i
        supports.append(support_i)

    return C_mat, supports


def top_k_selection(
    v_hat: torch.Tensor,
    Gamma_hat: torch.Tensor,
    K: int = 5,
) -> Tuple[torch.Tensor, List[int]]:
    """
    Simple top-K selection baseline: pick the K most similar concepts.

    Args:
        v_hat: [d] projected singular vector
        Gamma_hat: [C, d] concept embedding matrix
        K: Number of concepts to select

    Returns:
        c: [C] coefficient vector (similarity scores for top-K, 0 elsewhere)
        support: List of K selected concept indices
    """
    similarities = Gamma_hat @ v_hat  # [C]
    top_k_vals, top_k_idx = torch.topk(similarities, K)

    c = torch.zeros(Gamma_hat.shape[0], device=v_hat.device)
    support = top_k_idx.tolist()
    for i, idx in enumerate(support):
        c[idx] = max(0, top_k_vals[i].item())  # Non-negative

    return c, support
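
A minimal usage sketch for comp and the top_k_selection baseline on synthetic data. The dictionary size, embedding dimension, and the assumption that unimodal_sith is importable as a package are illustrative, not part of the commit:

import torch
import torch.nn.functional as F
from unimodal_sith.comp import comp, top_k_selection

# Synthetic L2-normalized concept dictionary and target direction
Gamma_hat = F.normalize(torch.randn(1000, 512), dim=-1)  # [C, d]
v_hat = F.normalize(torch.randn(512), dim=-1)            # [d]

# Sparse, coherence-regularized attribution
c, support = comp(v_hat, Gamma_hat, K=5, lambda_coh=0.3)
recon = F.normalize(c @ Gamma_hat, dim=-1)               # reconstruction from K concepts
print("selected indices:", support)
print("reconstruction fidelity:", torch.dot(recon, v_hat).item())

# Plain top-K baseline for comparison
c_tk, support_tk = top_k_selection(v_hat, Gamma_hat, K=5)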
unimodal_sith/concept_pool.py
ADDED
@@ -0,0 +1,194 @@
"""
Visual Concept Pool for UniSITH.

Instead of text concepts (ConceptNet strings + CLIP text encoder),
we use captioned images as the concept pool.

Each concept is an image from a captioned dataset, and the corresponding
caption provides human-interpretable meaning.

The concept embeddings are computed by encoding each image through the
same unimodal vision model being analyzed.
"""

import torch
import torch.nn.functional as F
import numpy as np
from typing import Dict, List, Optional, Tuple
from PIL import Image
from tqdm import tqdm
import os
import json


class VisualConceptPool:
    """
    A pool of visual concepts, each represented by:
    - An image embedding (computed by the model being analyzed)
    - A caption (for human interpretability)
    - Optionally, the original image

    Analogous to Γ = {γ_1, ..., γ_C} in SITH, but each γ_i is an image
    embedding rather than a text embedding.
    """

    def __init__(
        self,
        embeddings: torch.Tensor,
        captions: List[str],
        image_ids: Optional[List[int]] = None,
        metadata: Optional[Dict] = None,
    ):
        """
        Args:
            embeddings: [C, d] tensor of L2-normalized concept embeddings
            captions: List of C caption strings
            image_ids: Optional list of C image IDs for retrieval
            metadata: Optional metadata dict
        """
        assert embeddings.shape[0] == len(captions), \
            f"Embeddings ({embeddings.shape[0]}) and captions ({len(captions)}) must match"

        self.embeddings = embeddings  # [C, d]
        self.captions = captions
        self.image_ids = image_ids
        self.metadata = metadata or {}
        self.num_concepts = len(captions)
        self.embed_dim = embeddings.shape[1]

    @classmethod
    def from_dataset(
        cls,
        dataset,
        model,
        processor,
        architecture: str,
        image_column: str = "image",
        caption_column: str = "caption",
        image_id_column: str = "image_id",
        batch_size: int = 64,
        max_concepts: Optional[int] = None,
        device: str = "cpu",
        cache_path: Optional[str] = None,
    ) -> "VisualConceptPool":
        """
        Build a concept pool from a HuggingFace dataset.

        Args:
            dataset: HF dataset with image and caption columns
            model: Vision model (HuggingFace transformers)
            processor: Image processor/transform
            architecture: Model architecture type
            image_column: Column name for images
            caption_column: Column name for captions
            image_id_column: Column name for image IDs
            batch_size: Batch size for encoding
            max_concepts: Max number of concepts to use
            device: Device for computation
            cache_path: If set, cache embeddings to/from this path
        """
        # Check for cached embeddings
        if cache_path and os.path.exists(cache_path):
            print(f"Loading cached concept pool from {cache_path}")
            return cls.load(cache_path)

        if max_concepts is not None:
            dataset = dataset.select(range(min(max_concepts, len(dataset))))

        captions = dataset[caption_column]
        image_ids = None
        if image_id_column in dataset.column_names:
            image_ids = dataset[image_id_column]

        # Encode all images
        model = model.to(device)
        model.eval()

        all_embeddings = []

        print(f"Encoding {len(dataset)} concept images...")
        for i in tqdm(range(0, len(dataset), batch_size)):
            batch_end = min(i + batch_size, len(dataset))
            batch_images = [dataset[j][image_column] for j in range(i, batch_end)]

            # Ensure images are RGB
            batch_images = [img.convert("RGB") if img.mode != "RGB" else img for img in batch_images]

            # Process images
            inputs = processor(images=batch_images, return_tensors="pt").to(device)

            with torch.no_grad():
                if architecture == "dinov2":
                    outputs = model(**inputs)
                    embeddings = outputs.last_hidden_state[:, 0, :]  # CLS token
                elif architecture == "clip":
                    # For CLIP, get the vision features
                    outputs = model.vision_model(**inputs)
                    # CLS token, already pooled + post-LN
                    pooled = outputs.pooler_output
                    # Apply visual projection
                    embeddings = model.visual_projection(pooled)
                elif architecture == "vit":
                    outputs = model(**inputs)
                    embeddings = outputs.last_hidden_state[:, 0, :]  # CLS token

            # L2 normalize
            embeddings = F.normalize(embeddings, dim=-1)
            all_embeddings.append(embeddings.cpu())

        embeddings = torch.cat(all_embeddings, dim=0)

        pool = cls(
            embeddings=embeddings,
            captions=captions,
            image_ids=image_ids,
            metadata={
                "architecture": architecture,
                "num_concepts": len(captions),
                "embed_dim": embeddings.shape[1],
            },
        )

        # Cache if requested
        if cache_path:
            pool.save(cache_path)

        return pool

    def get_centered_embeddings(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Return the embeddings after mean-centering and re-normalization,
        together with the mean vector mu.

        This is analogous to the modality gap correction in SITH (Eq. 18-19),
        but for unimodal models we center within the image embedding distribution
        so that the concept embeddings are centered around the origin.

        This is important for matching with singular vectors, which themselves
        are zero-centered directions.
        """
        mu = self.embeddings.mean(dim=0, keepdim=True)  # [1, d]
        centered = self.embeddings - mu  # [C, d]
        centered = F.normalize(centered, dim=-1)  # Re-normalize
        return centered, mu

    def save(self, path: str):
        """Save concept pool to disk."""
        os.makedirs(os.path.dirname(path) if os.path.dirname(path) else ".", exist_ok=True)
        torch.save({
            "embeddings": self.embeddings,
            "captions": self.captions,
            "image_ids": self.image_ids,
            "metadata": self.metadata,
        }, path)
        print(f"Saved concept pool to {path}")

    @classmethod
    def load(cls, path: str) -> "VisualConceptPool":
        """Load concept pool from disk."""
        data = torch.load(path, weights_only=False)
        return cls(
            embeddings=data["embeddings"],
            captions=data["captions"],
            image_ids=data.get("image_ids"),
            metadata=data.get("metadata", {}),
        )
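
A sketch of building a concept pool from a captioned dataset with a DINOv2 backbone. The dataset identifier, column names, concept budget, and cache path below are placeholders to substitute with your own, and the unimodal_sith package is assumed to be importable:

import torch
from datasets import load_dataset
from transformers import AutoImageProcessor, AutoModel
from unimodal_sith.concept_pool import VisualConceptPool

device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModel.from_pretrained("facebook/dinov2-base")
processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")

# Any captioned image dataset works; "your/captioned-dataset" and the column
# names are placeholders for whatever dataset you actually use.
dataset = load_dataset("your/captioned-dataset", split="train")

pool = VisualConceptPool.from_dataset(
    dataset,
    model=model,
    processor=processor,
    architecture="dinov2",
    image_column="image",
    caption_column="caption",
    max_concepts=5000,
    device=device,
    cache_path="cache/concept_pool_dinov2.pt",
)
centered, mu = pool.get_centered_embeddings()  # [C, d], [1, d]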
unimodal_sith/unisith.py
ADDED
@@ -0,0 +1,450 @@
"""
UniSITH: Unimodal Semantic Inspection of Transformer Heads

Main analysis class that orchestrates:
1. Weight extraction (W_VO matrices from attention heads)
2. SVD decomposition (finding principal directions)
3. Projection to feature space
4. Concept attribution via COMP (matching to visual concepts)
5. Model editing (amplifying/suppressing concepts)

Key differences from original SITH:
- Works with ANY ViT (not just CLIP)
- Uses captioned images as the concept pool (not text from ConceptNet)
- Captions provide human interpretability
- No cross-modal projection needed (the same model encodes both the weights and the concepts)
"""

import torch
import torch.nn.functional as F
import numpy as np
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass, field
import json
import os

from .weight_extraction import WeightExtractor
from .concept_pool import VisualConceptPool
from .comp import comp, comp_batch, top_k_selection


@dataclass
class SingularVectorInterpretation:
    """Interpretation of a single singular vector."""
    layer_idx: int
    head_idx: int
    sv_idx: int
    singular_value: float
    concepts: List[str]  # Captions of matched concepts
    concept_indices: List[int]  # Indices into concept pool
    coefficients: List[float]  # COMP coefficients
    fidelity: float  # Cosine similarity between original and reconstruction
    image_ids: Optional[List[int]] = None  # IDs for retrieving original images

    def to_dict(self) -> Dict:
        return {
            "layer": self.layer_idx,
            "head": self.head_idx,
            "sv_index": self.sv_idx,
            "singular_value": self.singular_value,
            "concepts": [
                {"caption": c, "coefficient": w, "concept_idx": idx}
                for c, w, idx in zip(self.concepts, self.coefficients, self.concept_indices)
            ],
            "fidelity": self.fidelity,
            "image_ids": self.image_ids,
        }

    def __repr__(self) -> str:
        lines = [f"Layer {self.layer_idx}, Head {self.head_idx}, SV {self.sv_idx} "
                 f"(σ={self.singular_value:.4f}, fidelity={self.fidelity:.4f})"]
        for c, w in zip(self.concepts, self.coefficients):
            lines.append(f"  [{w:.4f}] {c}")
        return "\n".join(lines)


@dataclass
class HeadInterpretation:
    """Full interpretation of an attention head."""
    layer_idx: int
    head_idx: int
    singular_vectors: List[SingularVectorInterpretation]

    def to_dict(self) -> Dict:
        return {
            "layer": self.layer_idx,
            "head": self.head_idx,
            "singular_vectors": [sv.to_dict() for sv in self.singular_vectors],
        }

    def __repr__(self) -> str:
        lines = [f"=== Layer {self.layer_idx}, Head {self.head_idx} ==="]
        for sv in self.singular_vectors:
            lines.append(str(sv))
            lines.append("")
        return "\n".join(lines)


class UniSITH:
    """
    Unimodal Semantic Inspection of Transformer Heads.

    Analyzes the internal representations of ViT attention heads by:
    1. Decomposing W_VO matrices via SVD
    2. Projecting singular vectors to the model's feature space
    3. Attributing visual concepts from a captioned image pool
    """

    def __init__(
        self,
        model: torch.nn.Module,
        architecture: str,
        n_heads: int,
        d_model: int,
        concept_pool: VisualConceptPool,
        device: str = "cpu",
    ):
        """
        Args:
            model: Vision transformer model
            architecture: One of "dinov2", "clip", "vit"
            n_heads: Number of attention heads
            d_model: Hidden dimension
            concept_pool: Visual concept pool with embeddings and captions
            device: Computation device
        """
        self.model = model
        self.architecture = architecture
        self.device = device
        self.concept_pool = concept_pool

        self.extractor = WeightExtractor(model, architecture, n_heads, d_model)
        self.n_heads = n_heads
        self.d_model = d_model

        # Precompute centered concept embeddings
        self.centered_concepts, self.concept_mean = concept_pool.get_centered_embeddings()
        self.centered_concepts = self.centered_concepts.to(device)
        self.concept_mean = self.concept_mean.to(device)

    def analyze_head(
        self,
        layer_idx: int,
        head_idx: int,
        n_singular_vectors: int = 5,
        K: int = 5,
        lambda_coh: float = 0.3,
        method: str = "comp",
    ) -> HeadInterpretation:
        """
        Analyze a single attention head: decompose its W_VO matrix and
        interpret the top singular vectors.

        Args:
            layer_idx: Transformer layer index
            head_idx: Attention head index
            n_singular_vectors: Number of top singular vectors to interpret
            K: Number of concepts per singular vector
            lambda_coh: COMP coherence weight
            method: "comp" or "top_k"

        Returns:
            HeadInterpretation with concept attributions for each singular vector
        """
        # Step 1: Extract W_VO and decompose via SVD
        W_VO_all = self.extractor.compute_WVO(layer_idx, fold_ln=True, project_ones=True)
        W_VO_h = W_VO_all[head_idx]  # [D, D]

        U, sigma, Vt = self.extractor.svd_decompose(W_VO_h, top_k=n_singular_vectors)
        # U: [D, n_sv], sigma: [n_sv], Vt: [n_sv, D]

        # Step 2: Project right singular vectors to feature space
        V_projected = self.extractor.project_to_feature_space(Vt)  # [n_sv, d_out]

        # Step 3: Center the projected vectors (analogous to modality gap correction)
        V_centered = V_projected - self.concept_mean
        V_centered = F.normalize(V_centered, dim=-1)

        # Step 4: Attribute concepts via COMP (or top-k)
        sv_interpretations = []
        for i in range(n_singular_vectors):
            v_hat = V_centered[i]  # [d_out]

            if method == "comp":
                coeffs, support = comp(
                    v_hat, self.centered_concepts, K=K, lambda_coh=lambda_coh
                )
            elif method == "top_k":
                coeffs, support = top_k_selection(
                    v_hat, self.centered_concepts, K=K
                )
            else:
                raise ValueError(f"Unknown method: {method}")

            # Extract concept captions and coefficients
            concept_captions = [self.concept_pool.captions[idx] for idx in support]
            concept_coeffs = [coeffs[idx].item() for idx in support]
            concept_image_ids = None
            if self.concept_pool.image_ids is not None:
                concept_image_ids = [self.concept_pool.image_ids[idx] for idx in support]

            # Compute fidelity: cosine similarity between original and reconstruction
            reconstruction = torch.zeros_like(v_hat)
            for idx, coeff in zip(support, concept_coeffs):
                reconstruction += coeff * self.centered_concepts[idx]
            fidelity = F.cosine_similarity(
                v_hat.unsqueeze(0), reconstruction.unsqueeze(0)
            ).item()

            sv_interpretations.append(SingularVectorInterpretation(
                layer_idx=layer_idx,
                head_idx=head_idx,
                sv_idx=i,
                singular_value=sigma[i].item(),
                concepts=concept_captions,
                concept_indices=support,
                coefficients=concept_coeffs,
                fidelity=fidelity,
                image_ids=concept_image_ids,
            ))

        return HeadInterpretation(
            layer_idx=layer_idx,
            head_idx=head_idx,
            singular_vectors=sv_interpretations,
        )

    def analyze_layer(
        self,
        layer_idx: int,
        n_singular_vectors: int = 5,
        K: int = 5,
        lambda_coh: float = 0.3,
        method: str = "comp",
    ) -> List[HeadInterpretation]:
        """Analyze all heads in a layer."""
        results = []
        for h in range(self.n_heads):
            print(f"  Analyzing head {h}/{self.n_heads}...")
            result = self.analyze_head(
                layer_idx, h, n_singular_vectors, K, lambda_coh, method
            )
            results.append(result)
        return results

    def analyze_model(
        self,
        layers: Optional[List[int]] = None,
        n_singular_vectors: int = 5,
        K: int = 5,
        lambda_coh: float = 0.3,
        method: str = "comp",
    ) -> Dict[int, List[HeadInterpretation]]:
        """
        Analyze multiple layers of the model.

        Args:
            layers: List of layer indices. If None, analyzes the last 4 layers.
            n_singular_vectors: Number of top singular vectors per head
            K: Concepts per singular vector
            lambda_coh: COMP coherence weight
            method: "comp" or "top_k"

        Returns:
            Dict mapping layer_idx -> list of HeadInterpretations
        """
        if layers is None:
            n_layers = self.extractor._get_num_layers()
            layers = list(range(max(0, n_layers - 4), n_layers))

        results = {}
        for layer_idx in layers:
            print(f"Analyzing layer {layer_idx}...")
            results[layer_idx] = self.analyze_layer(
                layer_idx, n_singular_vectors, K, lambda_coh, method
            )

        return results

    def edit_model(
        self,
        layer_idx: int,
        head_idx: int,
        sv_indices: List[int],
        scale_factors: List[float],
    ) -> None:
        """
        Edit the model by scaling specific singular values.

        This enables:
        - Suppressing concepts (scale -> 0): remove spurious features
        - Amplifying concepts (scale > 1): enhance task-relevant features

        Args:
            layer_idx: Layer to edit
            head_idx: Head to edit
            sv_indices: Indices of singular vectors to modify
            scale_factors: Scaling factor for each (0 = suppress, >1 = amplify)
        """
        # Get original W_VO
        W_VO_all = self.extractor.compute_WVO(layer_idx, fold_ln=False, project_ones=False)
        W_VO_h = W_VO_all[head_idx]

        # SVD decompose
        U, sigma, Vt = torch.linalg.svd(W_VO_h, full_matrices=False)

        # Scale selected singular values
        for sv_idx, scale in zip(sv_indices, scale_factors):
            sigma[sv_idx] *= scale

        # Reconstruct
        W_VO_edited = U @ torch.diag(sigma) @ Vt

        # Write back to the model.
        # W_VO_h = W_V_h^T @ W_O_h^T, so the edit must be pushed back into W_V and
        # W_O. W_O_h has rank d_h and is not invertible in D x D space, so we cannot
        # simply solve for a new W_V_h. Instead, we re-factorize the edited
        # (rank <= d_h) W_VO into new per-head W_V_h and W_O_h via SVD
        # (see _write_WVO_to_model).
        self._write_WVO_to_model(layer_idx, head_idx, W_VO_edited)

    def _write_WVO_to_model(
        self,
        layer_idx: int,
        head_idx: int,
        W_VO_edited: torch.Tensor,
    ):
        """
        Write an edited W_VO back to the model weights.

        Since W_VO = W_V_h^T @ W_O_h^T and has rank d_h, we can use SVD
        to factorize W_VO_edited into new W_V_h and W_O_h.

        W_VO_edited = U_e @ S_e @ V_e^T
        Take top-d_h components:
        W_V_h_new^T = U_e[:, :d_h] @ sqrt(S_e[:d_h])
        W_O_h_new^T = sqrt(S_e[:d_h]) @ V_e[:d_h, :]
        """
        d_h = self.extractor.head_dim

        # SVD of edited W_VO
        U_e, S_e, Vt_e = torch.linalg.svd(W_VO_edited, full_matrices=False)

        # Keep top d_h components
        sqrt_S = torch.sqrt(S_e[:d_h])

        # New W_V_h^T = U_e[:, :d_h] @ diag(sqrt_S) => shape [D, d_h]
        # So W_V_h = diag(sqrt_S) @ U_e[:, :d_h]^T => shape [d_h, D]
        W_V_h_new = sqrt_S.unsqueeze(1) * U_e[:, :d_h].T  # [d_h, D]

        # New W_O_h^T = diag(sqrt_S) @ Vt_e[:d_h, :] => shape [d_h, D]
        # So W_O_h = Vt_e[:d_h, :]^T @ diag(sqrt_S) => shape [D, d_h]
        W_O_h_new = Vt_e[:d_h, :].T * sqrt_S.unsqueeze(0)  # [D, d_h]

        # Fetch the live weight tensors and write the new head slices back
        _, _, W_V = self.extractor._get_qkv_weights(layer_idx)
        W_O = self.extractor._get_output_weight(layer_idx)

        h = head_idx

        # W_V is [d_model, d_model], head h occupies rows [h*d_h : (h+1)*d_h]
        W_V[h * d_h : (h + 1) * d_h, :] = W_V_h_new

        # W_O is [d_model, d_model], head h occupies columns [h*d_h : (h+1)*d_h]
        W_O[:, h * d_h : (h + 1) * d_h] = W_O_h_new

    def find_concept_heads(
        self,
        target_concepts: List[str],
        concept_embeddings: torch.Tensor,
        layers: Optional[List[int]] = None,
        n_singular_vectors: int = 10,
        K: int = 5,
        lambda_coh: float = 0.3,
        threshold: float = 0.3,
    ) -> List[Dict]:
        """
        Find attention heads that encode specific concepts.

        Useful for targeted model editing: find which heads encode
        "background" features, "texture" features, etc.

        Args:
            target_concepts: List of target concept descriptions
            concept_embeddings: [n_targets, d] embeddings of target concepts
            layers: Layers to search
            n_singular_vectors: SVs per head to check
            K: Concepts per SV
            lambda_coh: COMP coherence weight
            threshold: Minimum similarity to consider a match

        Returns:
            List of dicts with head locations and matching info
        """
        results = self.analyze_model(
            layers=layers,
            n_singular_vectors=n_singular_vectors,
            K=K,
            lambda_coh=lambda_coh,
        )

        matches = []
        concept_embeddings = F.normalize(concept_embeddings.to(self.device), dim=-1)

        for layer_idx, heads in results.items():
            for head_interp in heads:
                for sv_interp in head_interp.singular_vectors:
                    # Check if any of the attributed concepts match targets
                    for ci, concept_idx in enumerate(sv_interp.concept_indices):
                        concept_emb = self.centered_concepts[concept_idx]
                        sims = (concept_embeddings @ concept_emb).tolist()
                        max_sim = max(sims)
                        if max_sim > threshold:
                            matches.append({
                                "layer": layer_idx,
                                "head": head_interp.head_idx,
                                "sv_index": sv_interp.sv_idx,
                                "concept": sv_interp.concepts[ci],
                                "coefficient": sv_interp.coefficients[ci],
                                "target_similarity": max_sim,
                                "singular_value": sv_interp.singular_value,
                            })

        # Sort by relevance (target_similarity * singular_value * coefficient)
        matches.sort(
            key=lambda x: x["target_similarity"] * x["singular_value"] * x["coefficient"],
            reverse=True,
        )

        return matches

    @staticmethod
    def save_results(
        results: Dict[int, List[HeadInterpretation]],
        path: str,
    ):
        """Save analysis results to JSON."""
        serialized = {}
        for layer_idx, heads in results.items():
            serialized[str(layer_idx)] = [h.to_dict() for h in heads]

        os.makedirs(os.path.dirname(path) if os.path.dirname(path) else ".", exist_ok=True)
        with open(path, "w") as f:
            json.dump(serialized, f, indent=2)
        print(f"Results saved to {path}")

    @staticmethod
    def load_results(path: str) -> Dict:
        """Load analysis results from JSON."""
        with open(path) as f:
            return json.load(f)
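
A sketch of the end-to-end analysis loop, assuming a DINOv2-base backbone and a concept pool cached as in the previous sketch; the layer/head indices and the cache path are arbitrary illustrations:

import torch
from transformers import AutoModel
from unimodal_sith.concept_pool import VisualConceptPool
from unimodal_sith.unisith import UniSITH

device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModel.from_pretrained("facebook/dinov2-base")

# Concept pool built and cached earlier (path is an assumption)
pool = VisualConceptPool.load("cache/concept_pool_dinov2.pt")

analyzer = UniSITH(
    model=model,
    architecture="dinov2",
    n_heads=model.config.num_attention_heads,
    d_model=model.config.hidden_size,
    concept_pool=pool,
    device=device,
)

# Interpret the top singular vectors of a single head
interp = analyzer.analyze_head(layer_idx=11, head_idx=3, n_singular_vectors=5, K=5)
print(interp)

# Suppress the first singular direction of that head (scale factor 0.0)
analyzer.edit_model(layer_idx=11, head_idx=3, sv_indices=[0], scale_factors=[0.0])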
unimodal_sith/weight_extraction.py
ADDED
@@ -0,0 +1,378 @@
"""
Weight extraction utilities for various ViT architectures.

Supports:
- DINOv2 (facebook/dinov2-*)
- CLIP ViT (openai/clip-vit-* via HuggingFace transformers)
- Any HuggingFace ViT (google/vit-*)

For each architecture, extracts:
- W_V (value projection) and W_O (output projection) per attention head
- W_VO = W_V^T @ W_O^T (the value-output matrix, as in SITH)
- LayerNorm parameters for folding
- Final projection matrix W_p (if present)
"""

import torch
import torch.nn as nn
from typing import Dict, List, Optional, Tuple


def fold_layernorm_into_weights(
    W: torch.Tensor,
    ln_weight: torch.Tensor,
    ln_bias: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Fold LayerNorm affine parameters into a weight matrix.

    Given: LN(x) = (x - mean) / std * w + b
    The affine part: x_affine = x * w + b (element-wise)
    Folded: W' = W @ diag(w) (scale input columns), b' = W @ b (absorbed into the projection bias)

    Args:
        W: Weight matrix [out_dim, in_dim] (applied as x @ W^T)
        ln_weight: LayerNorm weight [in_dim]
        ln_bias: LayerNorm bias [in_dim]

    Returns:
        W_folded: [out_dim, in_dim]
        b_folded: [out_dim]
    """
    # For a PyTorch Linear weight W [out, in], applied as y = x @ W^T:
    #   LN(x) @ W^T = (x * w + b) @ W^T = x @ diag(w) @ W^T + b @ W^T
    # We need W_folded with x @ W_folded^T = x @ diag(w) @ W^T,
    # i.e. W_folded = W @ diag(w): scale each input column of W by w.
    # The bias contribution b @ W^T is returned separately as b_folded.
    W_folded = W * ln_weight.unsqueeze(0)  # [out, in] * [1, in] = [out, in]
    b_folded = ln_bias @ W.t()  # [in] @ [in, out] = [out]

    return W_folded, b_folded


def project_out_ones(W: torch.Tensor) -> torch.Tensor:
    """
    Project weight matrix columns onto the subspace orthogonal to the all-ones direction.
    This accounts for the centering operation of LayerNorm.

    For a matrix W [D, D], we subtract the mean of each column from itself.
    Equivalently: W_proj = W - (1/D) * ones @ ones^T @ W
    """
    D = W.shape[0]
    col_means = W.mean(dim=0, keepdim=True)  # [1, D]
    W_proj = W - col_means
    return W_proj


class WeightExtractor:
    """
    Extracts and processes attention head weights for SITH analysis.
    Architecture-agnostic: supports DINOv2, CLIP ViT, standard ViT.
    """

    SUPPORTED_ARCHITECTURES = ["dinov2", "clip", "vit"]

    def __init__(self, model: nn.Module, architecture: str, n_heads: int, d_model: int):
        """
        Args:
            model: The loaded model (HuggingFace transformers model)
            architecture: One of "dinov2", "clip", "vit"
            n_heads: Number of attention heads
            d_model: Hidden dimension
        """
        assert architecture in self.SUPPORTED_ARCHITECTURES, \
            f"Unsupported architecture: {architecture}. Use one of {self.SUPPORTED_ARCHITECTURES}"

        self.model = model
        self.architecture = architecture
        self.n_heads = n_heads
        self.d_model = d_model
        self.head_dim = d_model // n_heads

    def _get_layer(self, layer_idx: int):
        """Get the transformer layer by index."""
        if self.architecture == "dinov2":
            return self.model.encoder.layer[layer_idx]
        elif self.architecture == "clip":
            return self.model.vision_model.encoder.layers[layer_idx]
        elif self.architecture == "vit":
            # AutoModel for ViT doesn't have the .vit prefix
            if hasattr(self.model, 'vit'):
                return self.model.vit.encoder.layer[layer_idx]
            else:
                return self.model.encoder.layer[layer_idx]
        else:
            raise ValueError(f"Unknown architecture: {self.architecture}")

    def _get_num_layers(self) -> int:
        """Get total number of transformer layers."""
        if self.architecture == "dinov2":
            return len(self.model.encoder.layer)
        elif self.architecture == "clip":
            return len(self.model.vision_model.encoder.layers)
        elif self.architecture == "vit":
            if hasattr(self.model, 'vit'):
                return len(self.model.vit.encoder.layer)
            else:
                return len(self.model.encoder.layer)
        else:
            raise ValueError(f"Unknown architecture: {self.architecture}")

    def _get_qkv_weights(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Extract Q, K, V weight matrices from a layer."""
        layer = self._get_layer(layer_idx)

        if self.architecture == "dinov2":
            attn = layer.attention.attention
            W_Q = attn.query.weight.data  # [d_model, d_model]
            W_K = attn.key.weight.data
            W_V = attn.value.weight.data
        elif self.architecture == "clip":
            attn = layer.self_attn
            W_Q = attn.q_proj.weight.data
            W_K = attn.k_proj.weight.data
            W_V = attn.v_proj.weight.data
        elif self.architecture == "vit":
            attn = layer.attention.attention
            W_Q = attn.query.weight.data
            W_K = attn.key.weight.data
            W_V = attn.value.weight.data

        return W_Q, W_K, W_V

    def _get_output_weight(self, layer_idx: int) -> torch.Tensor:
        """Extract output projection weight matrix."""
        layer = self._get_layer(layer_idx)

        if self.architecture == "dinov2":
            return layer.attention.output.dense.weight.data  # [d_model, d_model]
        elif self.architecture == "clip":
            return layer.self_attn.out_proj.weight.data
        elif self.architecture == "vit":
            return layer.attention.output.dense.weight.data

    def _get_pre_attn_layernorm(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get the LayerNorm weight and bias that precedes the attention block."""
        layer = self._get_layer(layer_idx)

        if self.architecture == "dinov2":
            ln = layer.norm1
        elif self.architecture == "clip":
            ln = layer.layer_norm1
        elif self.architecture == "vit":
            ln = layer.layernorm_before

        return ln.weight.data, ln.bias.data

    def _get_final_layernorm(self) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
        """Get the final LayerNorm (applied before projection, if present)."""
        if self.architecture == "dinov2":
            ln = self.model.layernorm
        elif self.architecture == "clip":
            ln = self.model.vision_model.post_layernorm
        elif self.architecture == "vit":
            if hasattr(self.model, 'vit'):
                ln = self.model.vit.layernorm
            else:
                ln = self.model.layernorm

        return ln.weight.data, ln.bias.data

    def _get_projection_matrix(self) -> Optional[torch.Tensor]:
        """Get the final projection matrix W_p (maps hidden dim to output dim)."""
        if self.architecture == "clip":
            # CLIP has a visual projection: [proj_dim, d_model]
            # Applied as: features = cls_token @ W_p^T
            try:
                W_p = self.model.visual_projection.weight.data  # [proj_dim, d_model]
                return W_p.t()  # Return as [d_model, proj_dim]
            except AttributeError:
                return None
        elif self.architecture == "dinov2":
            # DINOv2 has no projection matrix
            return None
        elif self.architecture == "vit":
            return None

    def _get_layerscale(self, layer_idx: int) -> Optional[torch.Tensor]:
        """Get LayerScale parameter (DINOv2 specific)."""
        if self.architecture == "dinov2":
            layer = self._get_layer(layer_idx)
            try:
                return layer.layer_scale1.lambda1.data  # [d_model]
            except AttributeError:
                return None
        return None

    def compute_WVO(
        self,
        layer_idx: int,
        fold_ln: bool = True,
        project_ones: bool = True,
    ) -> torch.Tensor:
        """
        Compute the Value-Output (VO) weight matrix for all heads in a layer.

        W_VO_h = W_V_h^T @ W_O_h^T where:
        - W_V_h is [head_dim, d_model] (head h's slice of W_V)
        - W_O_h is [d_model, head_dim] (head h's slice of W_O)
        - W_VO_h is [d_model, d_model] (rank head_dim)

        Following Eq. (4) of the paper: MHA(X) = sum_h A^h @ X @ W_VO^h with
        W_VO^h = W_V^h @ W_O^h, W_V^h: [D, d_h], W_O^h: [d_h, D].

        In PyTorch, Linear(in, out) stores weight as [out, in] and applies x @ W^T:
        - head h's value slice is W_V.weight[h*d_h:(h+1)*d_h, :] (shape [d_h, D]),
          applied as x @ W_V_h^T;
        - head h's output slice is W_O.weight[:, h*d_h:(h+1)*d_h] (shape [D, d_h]),
          applied as h_out_h @ W_O_h^T.
        The per-head map is therefore output_h = A_h @ X @ W_V_h^T @ W_O_h^T,
        i.e. W_VO_h = W_V_h^T @ W_O_h^T = [D, d_h] @ [d_h, D] = [D, D].

        Returns:
            W_VO: [n_heads, d_model, d_model]
        """
        _, _, W_V = self._get_qkv_weights(layer_idx)
        W_O = self._get_output_weight(layer_idx)

        # Optionally fold LayerNorm into W_V
        if fold_ln:
            ln_weight, ln_bias = self._get_pre_attn_layernorm(layer_idx)
            # LN acts on x before attention: x_ln = x * w + b (element-wise),
            # so x_ln @ W_V^T = x @ diag(w) @ W_V^T + b @ W_V^T.
            # Fold the multiplicative part: W_V_folded = W_V @ diag(w),
            # i.e. scale W_V's input columns by the LN weight.
            W_V = W_V * ln_weight.unsqueeze(0)  # [d_model, d_model] * [1, d_model]

        # Fold LayerScale if present (DINOv2)
        ls = self._get_layerscale(layer_idx)
        if ls is not None:
            # LayerScale multiplies the attention output element-wise:
            # output = ls * attn_output, equivalent to replacing W_O by diag(ls) @ W_O,
            # i.e. scaling W_O's rows by ls.
            W_O = W_O * ls.unsqueeze(1)  # [D, D] * [D, 1] = [D, D]

        # Split into per-head matrices
        # W_V: [d_model, d_model] -> W_V_h: [d_h, d_model] for head h
        W_V_per_head = W_V.view(self.n_heads, self.head_dim, self.d_model)  # [H, d_h, D]

        # W_O: [d_model, d_model] -> W_O_h: [d_model, d_h] for head h
        # (W_O[:, h*d_h:(h+1)*d_h], obtained by reshaping and permuting)
        W_O_per_head = W_O.view(self.d_model, self.n_heads, self.head_dim)  # [D, H, d_h]
        W_O_per_head = W_O_per_head.permute(1, 0, 2)  # [H, D, d_h]

        # W_VO_h = W_V_h^T @ W_O_h^T = [D, d_h] @ [d_h, D] = [D, D]
        W_VO = torch.bmm(
            W_V_per_head.transpose(1, 2),  # [H, D, d_h]
            W_O_per_head.transpose(1, 2),  # [H, d_h, D]
        )  # [H, D, D]

        # Project out the all-ones direction (centering from LN)
        if project_ones:
            for h in range(self.n_heads):
                W_VO[h] = project_out_ones(W_VO[h])

        return W_VO

    def svd_decompose(
        self,
        W_VO_h: torch.Tensor,
        top_k: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Decompose a per-head W_VO matrix via SVD.

        W_VO = U @ diag(sigma) @ V^T

        Args:
            W_VO_h: [d_model, d_model] VO matrix for a single head
            top_k: If set, only return top-k singular vectors

        Returns:
            U: [d_model, r] left singular vectors (reading directions)
            sigma: [r] singular values
            Vt: [r, d_model] right singular vectors (writing directions)
        """
        U, sigma, Vt = torch.linalg.svd(W_VO_h, full_matrices=False)

        if top_k is not None:
            U = U[:, :top_k]
            sigma = sigma[:top_k]
            Vt = Vt[:top_k, :]

        return U, sigma, Vt

    def project_to_feature_space(
        self,
        vectors: torch.Tensor,
    ) -> torch.Tensor:
        """
        Project singular vectors from the residual stream to the model's output feature space.

        For CLIP: apply final LN then the W_p projection
        For DINOv2/ViT: apply final LN (no projection matrix)

        Args:
            vectors: [n, d_model] singular vectors in residual stream space

        Returns:
            projected: [n, d_out] vectors in the output feature space, L2-normalized
        """
        # Get final LayerNorm
        ln_w, ln_b = self._get_final_layernorm()

        # Apply the LN affine transformation (without the data-dependent normalization):
        # since these are abstract directions, not activations, only the affine part applies.
        # v_ln = v * ln_weight + ln_bias
        vectors_ln = vectors * ln_w.unsqueeze(0) + ln_b.unsqueeze(0)

        # Apply projection if present
        W_p = self._get_projection_matrix()
        if W_p is not None:
            # W_p is [d_model, proj_dim]
            vectors_proj = vectors_ln @ W_p  # [n, proj_dim]
        else:
            vectors_proj = vectors_ln

        # L2 normalize
        vectors_proj = torch.nn.functional.normalize(vectors_proj, dim=-1)

        return vectors_proj