# UniSITH / run_unisith.py
# Uploaded by lorenzovaquero — "Add CLI runner script" (commit 24e41d7, verified)
#!/usr/bin/env python3
"""
UniSITH Demo: Analyze a DINOv2 model using captioned images as concept pool.
This script demonstrates the full UniSITH pipeline:
1. Load a unimodal ViT model (DINOv2-large)
2. Build a visual concept pool from Recap-COCO-30K
3. Analyze attention heads via SVD + COMP
4. Display human-interpretable concept attributions
Usage:
python run_unisith.py --model facebook/dinov2-large --max-concepts 1000
python run_unisith.py --model openai/clip-vit-large-patch14 --architecture clip
"""
import argparse
import torch
import os
import sys
import json
from transformers import AutoModel, AutoProcessor, AutoImageProcessor
from transformers import CLIPModel, CLIPProcessor
from datasets import load_dataset
# Add parent dir to path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from unimodal_sith.concept_pool import VisualConceptPool
from unimodal_sith.unisith import UniSITH
# Model configurations
MODEL_CONFIGS = {
"facebook/dinov2-large": {
"architecture": "dinov2",
"n_heads": 16,
"d_model": 1024,
},
"facebook/dinov2-base": {
"architecture": "dinov2",
"n_heads": 12,
"d_model": 768,
},
"facebook/dinov2-small": {
"architecture": "dinov2",
"n_heads": 6,
"d_model": 384,
},
"openai/clip-vit-large-patch14": {
"architecture": "clip",
"n_heads": 16,
"d_model": 1024,
},
"openai/clip-vit-base-patch16": {
"architecture": "clip",
"n_heads": 12,
"d_model": 768,
},
"google/vit-large-patch16-224": {
"architecture": "vit",
"n_heads": 16,
"d_model": 1024,
},
"google/vit-base-patch16-224": {
"architecture": "vit",
"n_heads": 12,
"d_model": 768,
},
}
def load_model_and_processor(model_name: str, architecture: str):
"""Load model and processor based on architecture type."""
print(f"Loading model: {model_name}")
if architecture == "clip":
model = CLIPModel.from_pretrained(model_name)
processor = CLIPProcessor.from_pretrained(model_name)
elif architecture == "dinov2":
model = AutoModel.from_pretrained(model_name)
processor = AutoImageProcessor.from_pretrained(model_name)
elif architecture == "vit":
model = AutoModel.from_pretrained(model_name)
processor = AutoImageProcessor.from_pretrained(model_name)
else:
raise ValueError(f"Unknown architecture: {architecture}")
model.eval()
return model, processor
def build_concept_pool(
model,
processor,
architecture: str,
max_concepts: int = 1000,
cache_path: str = None,
device: str = "cpu",
):
"""Build visual concept pool from Recap-COCO-30K."""
print(f"Building concept pool with {max_concepts} concepts...")
# Load dataset
dataset = load_dataset("UCSC-VLAA/Recap-COCO-30K", split="train")
pool = VisualConceptPool.from_dataset(
dataset=dataset,
model=model,
processor=processor,
architecture=architecture,
image_column="image",
caption_column="caption", # Short COCO captions for readability
image_id_column="image_id",
batch_size=32,
max_concepts=max_concepts,
device=device,
cache_path=cache_path,
)
return pool
def print_results(results, max_sv=3, max_heads=4):
"""Pretty-print analysis results."""
print("\n" + "=" * 80)
print("UniSITH Analysis Results")
print("=" * 80)
for layer_idx in sorted(results.keys()):
heads = results[layer_idx]
print(f"\n{'─' * 80}")
print(f"LAYER {layer_idx}")
print(f"{'─' * 80}")
for head in heads[:max_heads]:
print(f"\n Head {head.head_idx}:")
for sv in head.singular_vectors[:max_sv]:
print(f" SV {sv.sv_idx} (σ={sv.singular_value:.4f}, "
f"fidelity={sv.fidelity:.4f}):")
for caption, coeff in zip(sv.concepts, sv.coefficients):
print(f" [{coeff:.4f}] {caption}")
def main():
parser = argparse.ArgumentParser(description="UniSITH: Unimodal SITH Analysis")
parser.add_argument(
"--model", type=str, default="facebook/dinov2-base",
help="Model name/path"
)
parser.add_argument(
"--architecture", type=str, default=None,
help="Architecture type (auto-detected from model name if not set)"
)
parser.add_argument(
"--max-concepts", type=int, default=1000,
help="Maximum concepts in the pool"
)
parser.add_argument(
"--layers", type=int, nargs="+", default=None,
help="Layers to analyze (default: last 4)"
)
parser.add_argument(
"--n-sv", type=int, default=5,
help="Number of singular vectors per head"
)
parser.add_argument(
"--K", type=int, default=5,
help="Concepts per singular vector"
)
parser.add_argument(
"--lambda-coh", type=float, default=0.3,
help="COMP coherence weight"
)
parser.add_argument(
"--method", type=str, default="comp", choices=["comp", "top_k"],
help="Concept attribution method"
)
parser.add_argument(
"--device", type=str, default="cpu",
help="Device (cpu/cuda)"
)
parser.add_argument(
"--cache-dir", type=str, default="./cache",
help="Cache directory for concept embeddings"
)
parser.add_argument(
"--output", type=str, default="./results/unisith_results.json",
help="Output JSON path"
)
args = parser.parse_args()
# Auto-detect architecture
if args.architecture is None:
if args.model in MODEL_CONFIGS:
config = MODEL_CONFIGS[args.model]
args.architecture = config["architecture"]
n_heads = config["n_heads"]
d_model = config["d_model"]
else:
raise ValueError(
f"Unknown model {args.model}. Specify --architecture manually or use "
f"one of: {list(MODEL_CONFIGS.keys())}"
)
else:
if args.model in MODEL_CONFIGS:
config = MODEL_CONFIGS[args.model]
n_heads = config["n_heads"]
d_model = config["d_model"]
else:
raise ValueError(
f"Model {args.model} not in MODEL_CONFIGS. Add it or specify n_heads/d_model."
)
device = args.device
if device == "cuda" and not torch.cuda.is_available():
print("CUDA not available, falling back to CPU")
device = "cpu"
# Load model
model, processor = load_model_and_processor(args.model, args.architecture)
model = model.to(device)
# Build concept pool
cache_path = os.path.join(
args.cache_dir,
f"concept_pool_{args.model.replace('/', '_')}_{args.max_concepts}.pt"
)
pool = build_concept_pool(
model=model,
processor=processor,
architecture=args.architecture,
max_concepts=args.max_concepts,
cache_path=cache_path,
device=device,
)
print(f"Concept pool: {pool.num_concepts} concepts, dim={pool.embed_dim}")
# Create UniSITH analyzer
analyzer = UniSITH(
model=model,
architecture=args.architecture,
n_heads=n_heads,
d_model=d_model,
concept_pool=pool,
device=device,
)
# Run analysis
results = analyzer.analyze_model(
layers=args.layers,
n_singular_vectors=args.n_sv,
K=args.K,
lambda_coh=args.lambda_coh,
method=args.method,
)
# Print results
print_results(results)
# Save results
UniSITH.save_results(results, args.output)
print(f"\nDone! Results saved to {args.output}")
if __name__ == "__main__":
main()