# DT-Explorer / src / interpretability / sae_manager.py
# Last commit by sadhumitha-s (8577352):
#   feat: implement NLA explainer and universality probe and refactor path patching engine
import torch
import torch.nn as nn
import os
from typing import Dict, List, Optional, Tuple, Union
from sae_lens import (
StandardSAE, StandardSAEConfig,
TopKSAE, TopKSAEConfig,
SAE, SAEConfig
)
from jaxtyping import Float
class SAEManager:
    """
    Handles SAE training, latent decomposition, and anomaly detection for DTs.
    Supports Standard (ReLU) and TopK architectures.

    SAEs are kept in ``self.saes``, keyed by the hook-point name they were
    created or loaded for. Checkpoints are written to and read from
    ``sae_dir``.
    """

    def __init__(self, model: nn.Module, sae_dir: str = "artifacts/saes"):
        """Store the model, and create ``sae_dir`` if it does not exist yet.

        Args:
            model: Module whose parameters decide the device SAEs are built on.
            sae_dir: Directory used for saving and loading SAE checkpoints.
        """
        self.model = model
        self.sae_dir = sae_dir
        self.saes: Dict[str, Union[StandardSAE, TopKSAE]] = {}
        os.makedirs(sae_dir, exist_ok=True)

    def _model_device(self) -> str:
        """Return the device of the model's first parameter as a string.

        Falls back to ``"cpu"`` when the model has no parameters, instead of
        leaking a bare ``StopIteration`` out of ``next()``.
        """
        first_param = next(self.model.parameters(), None)
        return "cpu" if first_param is None else str(first_param.device)

    def _checkpoint_path(self, hook_point: str) -> str:
        """Return the checkpoint file path for ``hook_point``.

        Dots in the hook name are replaced by underscores. Saving and loading
        both go through this helper so the naming scheme cannot diverge.
        """
        return os.path.join(self.sae_dir, f"{hook_point.replace('.', '_')}_sae.pt")

    def _get_sae(self, hook_point: str) -> Union[StandardSAE, TopKSAE]:
        """Return the SAE registered for ``hook_point``.

        Raises:
            ValueError: If no SAE has been set up or loaded for ``hook_point``.
        """
        try:
            return self.saes[hook_point]
        except KeyError:
            raise ValueError(f"SAE for {hook_point} not found.") from None

    def setup_sae(
        self,
        hook_point: str,
        d_model: int,
        expansion_factor: int = 8,
        architecture: str = "standard",
        k: Optional[int] = None,
    ) -> Union[StandardSAE, TopKSAE]:
        """Initializes an SAE (Standard or TopK) for a specific hook point.

        Args:
            hook_point: Key under which the SAE is registered in ``self.saes``.
                An existing entry for the same key is replaced.
            d_model: Input width of the SAE.
            expansion_factor: The dictionary size is ``d_model * expansion_factor``.
            architecture: ``"topk"`` builds a ``TopKSAE``; any other value
                builds a ``StandardSAE``.
            k: Number of active latents for TopK. Defaults to ``d_sae // 32``,
                but never less than 1. Ignored for the standard architecture.

        Returns:
            The newly created SAE.
        """
        d_sae = d_model * expansion_factor
        device = self._model_device()
        if architecture == "topk":
            if k is None:
                # Default sparsity; clamp so small dictionaries (< 32 latents)
                # do not end up with k == 0.
                k = max(1, d_sae // 32)
            cfg = TopKSAEConfig(
                d_in=d_model,
                d_sae=d_sae,
                k=k,
                device=device,
            )
            sae = TopKSAE(cfg)
        else:
            cfg = StandardSAEConfig(
                d_in=d_model,
                d_sae=d_sae,
                device=device,
            )
            sae = StandardSAE(cfg)
        self.saes[hook_point] = sae
        return sae

    def train_on_trajectories(
        self,
        hook_point: str,
        activations: Float[torch.Tensor, "n_samples d_model"],
        l1_coefficient: float = 0.0001,
        batch_size: int = 1024,
        epochs: int = 10,
    ) -> None:
        """Trains the SAE on collected activations.

        If no SAE exists for ``hook_point`` yet, one is created with the
        default ``setup_sae`` settings and ``d_in = activations.shape[-1]``.
        The SAE is left in training mode afterwards. The mean per-batch loss
        is printed after every epoch.

        Args:
            hook_point: Key of the SAE to train.
            activations: Training samples; the first axis is iterated in
                shuffled mini-batches.
            l1_coefficient: Weight of the L1 penalty on the latents
                (standard architecture only).
            batch_size: Number of samples per optimizer step.
            epochs: Number of passes over ``activations``.

        Raises:
            ValueError: If ``activations`` is empty or ``batch_size`` < 1.
        """
        n_samples = activations.shape[0]
        if n_samples == 0:
            raise ValueError("Cannot train an SAE on an empty activation tensor.")
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}.")

        if hook_point not in self.saes:
            self.setup_sae(hook_point, activations.shape[-1])
        sae = self.saes[hook_point]
        optimizer = torch.optim.Adam(sae.parameters(), lr=0.0004)
        sae.train()
        is_topk = isinstance(sae, TopKSAE)
        # True number of optimizer steps per epoch (the last batch may be
        # partial). The reported loss is averaged over this count.
        num_batches = (n_samples + batch_size - 1) // batch_size

        for epoch in range(epochs):
            permutation = torch.randperm(n_samples)
            epoch_loss = 0.0
            for start in range(0, n_samples, batch_size):
                indices = permutation[start:start + batch_size]
                batch_acts = activations[indices].to(sae.device)
                optimizer.zero_grad()
                feature_acts = sae.encode(batch_acts)
                sae_out = sae.decode(feature_acts)
                mse_loss = torch.nn.functional.mse_loss(sae_out, batch_acts)
                if is_topk:
                    # TopK doesn't use L1; sparsity is enforced by architecture
                    loss = mse_loss
                else:
                    # NOTE(review): the L1 term is summed over the whole batch
                    # while the MSE term is averaged, so the sparsity pressure
                    # scales with batch size. Left unchanged so existing
                    # l1_coefficient tuning keeps its meaning - confirm intent.
                    l1_loss = l1_coefficient * feature_acts.abs().sum()
                    loss = mse_loss + l1_loss
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
            print(f"Epoch {epoch+1}/{epochs} - Loss: {epoch_loss / num_batches:.4f}")

    def get_feature_activations(
        self,
        hook_point: str,
        activations: Float[torch.Tensor, "... d_model"]
    ) -> Float[torch.Tensor, "... d_sae"]:
        """Decomposes activations into latent features.

        Switches the SAE to eval mode and runs its encoder without gradients.

        Raises:
            ValueError: If no SAE is registered for ``hook_point``.
        """
        sae = self._get_sae(hook_point)
        sae.eval()
        with torch.no_grad():
            feature_acts = sae.encode(activations.to(sae.device))
        return feature_acts

    def reconstruct(
        self,
        hook_point: str,
        activations: Float[torch.Tensor, "... d_model"]
    ) -> Float[torch.Tensor, "... d_model"]:
        """Reconstructs activations from latents.

        Switches the SAE to eval mode, then encodes and decodes the input
        without gradients.

        Raises:
            ValueError: If no SAE is registered for ``hook_point``.
        """
        sae = self._get_sae(hook_point)
        sae.eval()
        with torch.no_grad():
            feature_acts = sae.encode(activations.to(sae.device))
            sae_out = sae.decode(feature_acts)
        return sae_out

    def compute_anomaly_score(
        self,
        hook_point: str,
        activations: Float[torch.Tensor, "... d_model"]
    ) -> Float[torch.Tensor, "..."]:
        """
        Reconstruction error for anomaly detection.

        The score is the norm of the reconstruction residual along the last
        axis, divided by the norm of the input along that axis.

        Raises:
            ValueError: If no SAE is registered for ``hook_point``.
        """
        sae = self._get_sae(hook_point)
        sae.eval()
        with torch.no_grad():
            x = activations.to(sae.device)
            feature_acts = sae.encode(x)
            x_hat = sae.decode(feature_acts)
            # The small epsilon keeps all-zero inputs from dividing by zero.
            error = torch.norm(x - x_hat, dim=-1) / (torch.norm(x, dim=-1) + 1e-8)
        return error

    def save_all_saes(self) -> None:
        """Write every registered SAE to ``sae_dir`` and print each path.

        Each checkpoint stores the state dict, the config object, and a
        ``'type'`` tag (``'topk'`` or ``'standard'``) that ``load_sae`` uses
        to pick the class to rebuild.
        """
        for hook, sae in self.saes.items():
            path = self._checkpoint_path(hook)
            torch.save({
                'state_dict': sae.state_dict(),
                'cfg': sae.cfg,
                'type': 'topk' if isinstance(sae, TopKSAE) else 'standard'
            }, path)
            print(f"Saved SAE for {hook} to {path}")

    def load_sae(self, hook_point: str) -> Union[StandardSAE, TopKSAE]:
        """Load the SAE checkpoint for ``hook_point`` and register it.

        Returns:
            The restored SAE, also stored in ``self.saes[hook_point]``.

        Raises:
            FileNotFoundError: If no checkpoint file exists for ``hook_point``.
        """
        path = self._checkpoint_path(hook_point)
        if not os.path.exists(path):
            raise FileNotFoundError(f"No saved SAE found at {path}")
        # SECURITY: weights_only=False is needed because the checkpoint holds
        # a pickled config object, but unpickling can execute arbitrary code.
        # Only load checkpoints from trusted sources.
        checkpoint = torch.load(path, map_location=self._model_device(), weights_only=False)
        # NOTE(review): the SAE is rebuilt from the saved cfg, so it is
        # presumably created on the device recorded at save time - confirm
        # this is acceptable when loading on a different machine.
        if checkpoint.get('type') == 'topk':
            sae = TopKSAE(checkpoint['cfg'])
        else:
            sae = StandardSAE(checkpoint['cfg'])
        sae.load_state_dict(checkpoint['state_dict'])
        self.saes[hook_point] = sae
        return sae