Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| import os | |
| from typing import Dict, List, Optional, Tuple, Union | |
| from sae_lens import ( | |
| StandardSAE, StandardSAEConfig, | |
| TopKSAE, TopKSAEConfig, | |
| SAE, SAEConfig | |
| ) | |
| from jaxtyping import Float | |
| class SAEManager: | |
| """ | |
| Handles SAE training, latent decomposition, and anomaly detection for DTs. | |
| Supports Standard (ReLU) and TopK architectures. | |
| """ | |
| def __init__(self, model: nn.Module, sae_dir: str = "artifacts/saes"): | |
| self.model = model | |
| self.sae_dir = sae_dir | |
| self.saes: Dict[str, Union[StandardSAE, TopKSAE]] = {} | |
| os.makedirs(sae_dir, exist_ok=True) | |
| def setup_sae( | |
| self, | |
| hook_point: str, | |
| d_model: int, | |
| expansion_factor: int = 8, | |
| architecture: str = "standard", | |
| k: Optional[int] = None, | |
| ) -> Union[StandardSAE, TopKSAE]: | |
| """Initializes an SAE (Standard or TopK) for a specific hook point.""" | |
| d_sae = d_model * expansion_factor | |
| device = str(next(self.model.parameters()).device) | |
| if architecture == "topk": | |
| if k is None: | |
| k = d_sae // 32 # Default sparsity | |
| cfg = TopKSAEConfig( | |
| d_in=d_model, | |
| d_sae=d_sae, | |
| k=k, | |
| device=device | |
| ) | |
| sae = TopKSAE(cfg) | |
| else: | |
| cfg = StandardSAEConfig( | |
| d_in=d_model, | |
| d_sae=d_sae, | |
| device=device | |
| ) | |
| sae = StandardSAE(cfg) | |
| self.saes[hook_point] = sae | |
| return sae | |
| def train_on_trajectories( | |
| self, | |
| hook_point: str, | |
| activations: Float[torch.Tensor, "n_samples d_model"], | |
| l1_coefficient: float = 0.0001, | |
| batch_size: int = 1024, | |
| epochs: int = 10, | |
| ): | |
| """Trains the SAE on collected activations.""" | |
| if hook_point not in self.saes: | |
| self.setup_sae(hook_point, activations.shape[-1]) | |
| sae = self.saes[hook_point] | |
| optimizer = torch.optim.Adam(sae.parameters(), lr=0.0004) | |
| sae.train() | |
| n_samples = activations.shape[0] | |
| is_topk = isinstance(sae, TopKSAE) | |
| for epoch in range(epochs): | |
| permutation = torch.randperm(n_samples) | |
| epoch_loss = 0 | |
| for i in range(0, n_samples, batch_size): | |
| indices = permutation[i:i+batch_size] | |
| batch_acts = activations[indices].to(sae.device) | |
| optimizer.zero_grad() | |
| feature_acts = sae.encode(batch_acts) | |
| sae_out = sae.decode(feature_acts) | |
| mse_loss = torch.nn.functional.mse_loss(sae_out, batch_acts) | |
| if is_topk: | |
| # TopK doesn't use L1; sparsity is enforced by architecture | |
| loss = mse_loss | |
| else: | |
| l1_loss = l1_coefficient * feature_acts.abs().sum() | |
| loss = mse_loss + l1_loss | |
| loss.backward() | |
| optimizer.step() | |
| epoch_loss += loss.item() | |
| print(f"Epoch {epoch+1}/{epochs} - Loss: {epoch_loss / (n_samples / batch_size):.4f}") | |
| def get_feature_activations( | |
| self, | |
| hook_point: str, | |
| activations: Float[torch.Tensor, "... d_model"] | |
| ) -> Float[torch.Tensor, "... d_sae"]: | |
| """Decomposes activations into latent features.""" | |
| if hook_point not in self.saes: | |
| raise ValueError(f"SAE for {hook_point} not found.") | |
| sae = self.saes[hook_point] | |
| sae.eval() | |
| with torch.no_grad(): | |
| feature_acts = sae.encode(activations.to(sae.device)) | |
| return feature_acts | |
| def reconstruct( | |
| self, | |
| hook_point: str, | |
| activations: Float[torch.Tensor, "... d_model"] | |
| ) -> Float[torch.Tensor, "... d_model"]: | |
| """Reconstructs activations from latents.""" | |
| if hook_point not in self.saes: | |
| raise ValueError(f"SAE for {hook_point} not found.") | |
| sae = self.saes[hook_point] | |
| sae.eval() | |
| with torch.no_grad(): | |
| feature_acts = sae.encode(activations.to(sae.device)) | |
| sae_out = sae.decode(feature_acts) | |
| return sae_out | |
| def compute_anomaly_score( | |
| self, | |
| hook_point: str, | |
| activations: Float[torch.Tensor, "... d_model"] | |
| ) -> Float[torch.Tensor, "..."]: | |
| """ | |
| Reconstruction error for anomaly detection. | |
| """ | |
| if hook_point not in self.saes: | |
| raise ValueError(f"SAE for {hook_point} not found.") | |
| sae = self.saes[hook_point] | |
| sae.eval() | |
| with torch.no_grad(): | |
| x = activations.to(sae.device) | |
| feature_acts = sae.encode(x) | |
| x_hat = sae.decode(feature_acts) | |
| error = torch.norm(x - x_hat, dim=-1) / (torch.norm(x, dim=-1) + 1e-8) | |
| return error | |
| def save_all_saes(self): | |
| for hook, sae in self.saes.items(): | |
| path = os.path.join(self.sae_dir, f"{hook.replace('.', '_')}_sae.pt") | |
| torch.save({ | |
| 'state_dict': sae.state_dict(), | |
| 'cfg': sae.cfg, | |
| 'type': 'topk' if isinstance(sae, TopKSAE) else 'standard' | |
| }, path) | |
| print(f"Saved SAE for {hook} to {path}") | |
| def load_sae(self, hook_point: str): | |
| path = os.path.join(self.sae_dir, f"{hook_point.replace('.', '_')}_sae.pt") | |
| if not os.path.exists(path): | |
| raise FileNotFoundError(f"No saved SAE found at {path}") | |
| checkpoint = torch.load(path, map_location=str(next(self.model.parameters()).device), weights_only=False) | |
| if checkpoint.get('type') == 'topk': | |
| sae = TopKSAE(checkpoint['cfg']) | |
| else: | |
| sae = StandardSAE(checkpoint['cfg']) | |
| sae.load_state_dict(checkpoint['state_dict']) | |
| self.saes[hook_point] = sae | |
| return sae | |