lorenzovaquero commited on
Commit
24e41d7
·
verified ·
1 Parent(s): 2f74fc4

Add CLI runner script

Browse files
Files changed (1) hide show
  1. run_unisith.py +270 -0
run_unisith.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ UniSITH Demo: Analyze a DINOv2 model using captioned images as concept pool.
4
+
5
+ This script demonstrates the full UniSITH pipeline:
6
+ 1. Load a unimodal ViT model (DINOv2-large)
7
+ 2. Build a visual concept pool from Recap-COCO-30K
8
+ 3. Analyze attention heads via SVD + COMP
9
+ 4. Display human-interpretable concept attributions
10
+
11
+ Usage:
12
+ python run_unisith.py --model facebook/dinov2-large --max-concepts 1000
13
+ python run_unisith.py --model openai/clip-vit-large-patch14 --architecture clip
14
+ """
15
+
16
+ import argparse
17
+ import torch
18
+ import os
19
+ import sys
20
+ import json
21
+ from transformers import AutoModel, AutoProcessor, AutoImageProcessor
22
+ from transformers import CLIPModel, CLIPProcessor
23
+ from datasets import load_dataset
24
+
25
+ # Add parent dir to path
26
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
27
+
28
+ from unimodal_sith.concept_pool import VisualConceptPool
29
+ from unimodal_sith.unisith import UniSITH
30
+
31
+
32
+ # Model configurations
33
+ MODEL_CONFIGS = {
34
+ "facebook/dinov2-large": {
35
+ "architecture": "dinov2",
36
+ "n_heads": 16,
37
+ "d_model": 1024,
38
+ },
39
+ "facebook/dinov2-base": {
40
+ "architecture": "dinov2",
41
+ "n_heads": 12,
42
+ "d_model": 768,
43
+ },
44
+ "facebook/dinov2-small": {
45
+ "architecture": "dinov2",
46
+ "n_heads": 6,
47
+ "d_model": 384,
48
+ },
49
+ "openai/clip-vit-large-patch14": {
50
+ "architecture": "clip",
51
+ "n_heads": 16,
52
+ "d_model": 1024,
53
+ },
54
+ "openai/clip-vit-base-patch16": {
55
+ "architecture": "clip",
56
+ "n_heads": 12,
57
+ "d_model": 768,
58
+ },
59
+ "google/vit-large-patch16-224": {
60
+ "architecture": "vit",
61
+ "n_heads": 16,
62
+ "d_model": 1024,
63
+ },
64
+ "google/vit-base-patch16-224": {
65
+ "architecture": "vit",
66
+ "n_heads": 12,
67
+ "d_model": 768,
68
+ },
69
+ }
70
+
71
+
72
+ def load_model_and_processor(model_name: str, architecture: str):
73
+ """Load model and processor based on architecture type."""
74
+ print(f"Loading model: {model_name}")
75
+
76
+ if architecture == "clip":
77
+ model = CLIPModel.from_pretrained(model_name)
78
+ processor = CLIPProcessor.from_pretrained(model_name)
79
+ elif architecture == "dinov2":
80
+ model = AutoModel.from_pretrained(model_name)
81
+ processor = AutoImageProcessor.from_pretrained(model_name)
82
+ elif architecture == "vit":
83
+ model = AutoModel.from_pretrained(model_name)
84
+ processor = AutoImageProcessor.from_pretrained(model_name)
85
+ else:
86
+ raise ValueError(f"Unknown architecture: {architecture}")
87
+
88
+ model.eval()
89
+ return model, processor
90
+
91
+
92
+ def build_concept_pool(
93
+ model,
94
+ processor,
95
+ architecture: str,
96
+ max_concepts: int = 1000,
97
+ cache_path: str = None,
98
+ device: str = "cpu",
99
+ ):
100
+ """Build visual concept pool from Recap-COCO-30K."""
101
+ print(f"Building concept pool with {max_concepts} concepts...")
102
+
103
+ # Load dataset
104
+ dataset = load_dataset("UCSC-VLAA/Recap-COCO-30K", split="train")
105
+
106
+ pool = VisualConceptPool.from_dataset(
107
+ dataset=dataset,
108
+ model=model,
109
+ processor=processor,
110
+ architecture=architecture,
111
+ image_column="image",
112
+ caption_column="caption", # Short COCO captions for readability
113
+ image_id_column="image_id",
114
+ batch_size=32,
115
+ max_concepts=max_concepts,
116
+ device=device,
117
+ cache_path=cache_path,
118
+ )
119
+
120
+ return pool
121
+
122
+
123
+ def print_results(results, max_sv=3, max_heads=4):
124
+ """Pretty-print analysis results."""
125
+ print("\n" + "=" * 80)
126
+ print("UniSITH Analysis Results")
127
+ print("=" * 80)
128
+
129
+ for layer_idx in sorted(results.keys()):
130
+ heads = results[layer_idx]
131
+ print(f"\n{'─' * 80}")
132
+ print(f"LAYER {layer_idx}")
133
+ print(f"{'─' * 80}")
134
+
135
+ for head in heads[:max_heads]:
136
+ print(f"\n Head {head.head_idx}:")
137
+ for sv in head.singular_vectors[:max_sv]:
138
+ print(f" SV {sv.sv_idx} (σ={sv.singular_value:.4f}, "
139
+ f"fidelity={sv.fidelity:.4f}):")
140
+ for caption, coeff in zip(sv.concepts, sv.coefficients):
141
+ print(f" [{coeff:.4f}] {caption}")
142
+
143
+
144
+ def main():
145
+ parser = argparse.ArgumentParser(description="UniSITH: Unimodal SITH Analysis")
146
+ parser.add_argument(
147
+ "--model", type=str, default="facebook/dinov2-base",
148
+ help="Model name/path"
149
+ )
150
+ parser.add_argument(
151
+ "--architecture", type=str, default=None,
152
+ help="Architecture type (auto-detected from model name if not set)"
153
+ )
154
+ parser.add_argument(
155
+ "--max-concepts", type=int, default=1000,
156
+ help="Maximum concepts in the pool"
157
+ )
158
+ parser.add_argument(
159
+ "--layers", type=int, nargs="+", default=None,
160
+ help="Layers to analyze (default: last 4)"
161
+ )
162
+ parser.add_argument(
163
+ "--n-sv", type=int, default=5,
164
+ help="Number of singular vectors per head"
165
+ )
166
+ parser.add_argument(
167
+ "--K", type=int, default=5,
168
+ help="Concepts per singular vector"
169
+ )
170
+ parser.add_argument(
171
+ "--lambda-coh", type=float, default=0.3,
172
+ help="COMP coherence weight"
173
+ )
174
+ parser.add_argument(
175
+ "--method", type=str, default="comp", choices=["comp", "top_k"],
176
+ help="Concept attribution method"
177
+ )
178
+ parser.add_argument(
179
+ "--device", type=str, default="cpu",
180
+ help="Device (cpu/cuda)"
181
+ )
182
+ parser.add_argument(
183
+ "--cache-dir", type=str, default="./cache",
184
+ help="Cache directory for concept embeddings"
185
+ )
186
+ parser.add_argument(
187
+ "--output", type=str, default="./results/unisith_results.json",
188
+ help="Output JSON path"
189
+ )
190
+
191
+ args = parser.parse_args()
192
+
193
+ # Auto-detect architecture
194
+ if args.architecture is None:
195
+ if args.model in MODEL_CONFIGS:
196
+ config = MODEL_CONFIGS[args.model]
197
+ args.architecture = config["architecture"]
198
+ n_heads = config["n_heads"]
199
+ d_model = config["d_model"]
200
+ else:
201
+ raise ValueError(
202
+ f"Unknown model {args.model}. Specify --architecture manually or use "
203
+ f"one of: {list(MODEL_CONFIGS.keys())}"
204
+ )
205
+ else:
206
+ if args.model in MODEL_CONFIGS:
207
+ config = MODEL_CONFIGS[args.model]
208
+ n_heads = config["n_heads"]
209
+ d_model = config["d_model"]
210
+ else:
211
+ raise ValueError(
212
+ f"Model {args.model} not in MODEL_CONFIGS. Add it or specify n_heads/d_model."
213
+ )
214
+
215
+ device = args.device
216
+ if device == "cuda" and not torch.cuda.is_available():
217
+ print("CUDA not available, falling back to CPU")
218
+ device = "cpu"
219
+
220
+ # Load model
221
+ model, processor = load_model_and_processor(args.model, args.architecture)
222
+ model = model.to(device)
223
+
224
+ # Build concept pool
225
+ cache_path = os.path.join(
226
+ args.cache_dir,
227
+ f"concept_pool_{args.model.replace('/', '_')}_{args.max_concepts}.pt"
228
+ )
229
+
230
+ pool = build_concept_pool(
231
+ model=model,
232
+ processor=processor,
233
+ architecture=args.architecture,
234
+ max_concepts=args.max_concepts,
235
+ cache_path=cache_path,
236
+ device=device,
237
+ )
238
+
239
+ print(f"Concept pool: {pool.num_concepts} concepts, dim={pool.embed_dim}")
240
+
241
+ # Create UniSITH analyzer
242
+ analyzer = UniSITH(
243
+ model=model,
244
+ architecture=args.architecture,
245
+ n_heads=n_heads,
246
+ d_model=d_model,
247
+ concept_pool=pool,
248
+ device=device,
249
+ )
250
+
251
+ # Run analysis
252
+ results = analyzer.analyze_model(
253
+ layers=args.layers,
254
+ n_singular_vectors=args.n_sv,
255
+ K=args.K,
256
+ lambda_coh=args.lambda_coh,
257
+ method=args.method,
258
+ )
259
+
260
+ # Print results
261
+ print_results(results)
262
+
263
+ # Save results
264
+ UniSITH.save_results(results, args.output)
265
+
266
+ print(f"\nDone! Results saved to {args.output}")
267
+
268
+
269
+ if __name__ == "__main__":
270
+ main()