Spaces:
Running
Running
File size: 6,133 Bytes
0346604 8577352 0346604 11dbbc6 8577352 0346604 8577352 0346604 8577352 0346604 8577352 0346604 8577352 0346604 8577352 0346604 11dbbc6 0346604 8577352 0346604 11dbbc6 0346604 fa350cc 8577352 fa350cc 0346604 8577352 0346604 8577352 0346604 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 | import torch
import torch.nn as nn
import os
from typing import Dict, List, Optional, Tuple, Union
from sae_lens import (
StandardSAE, StandardSAEConfig,
TopKSAE, TopKSAEConfig,
SAE, SAEConfig
)
from jaxtyping import Float
class SAEManager:
    """
    Handles SAE training, latent decomposition, and anomaly detection for DTs.
    Supports Standard (ReLU) and TopK architectures.

    SAEs are registered in ``self.saes`` keyed by hook-point name and are
    persisted as one ``<hook>_sae.pt`` checkpoint per hook point in ``sae_dir``.
    """

    def __init__(self, model: nn.Module, sae_dir: str = "artifacts/saes"):
        """
        Args:
            model: Model whose activations the SAEs decompose. The device of its
                first parameter is where SAEs are created/loaded.
            sae_dir: Directory for SAE checkpoints; created if it does not exist.
        """
        self.model = model
        self.sae_dir = sae_dir
        # hook point name -> SAE created by setup_sae or restored by load_sae
        self.saes: Dict[str, Union[StandardSAE, TopKSAE]] = {}
        os.makedirs(sae_dir, exist_ok=True)

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    def _model_device(self) -> str:
        """Return the device of the model's first parameter, or "cpu" if it has none."""
        first_param = next(self.model.parameters(), None)
        return "cpu" if first_param is None else str(first_param.device)

    def _sae_path(self, hook_point: str) -> str:
        """Return the checkpoint path shared by save_all_saes and load_sae."""
        # NOTE: '.' is mapped to '_', so hook names differing only in those
        # characters would share one checkpoint file.
        return os.path.join(self.sae_dir, f"{hook_point.replace('.', '_')}_sae.pt")

    def _get_sae(self, hook_point: str) -> Union[StandardSAE, TopKSAE]:
        """Return the registered SAE for ``hook_point``; raise ValueError if absent."""
        try:
            return self.saes[hook_point]
        except KeyError:
            raise ValueError(f"SAE for {hook_point} not found.") from None

    # ------------------------------------------------------------------
    # Construction and training
    # ------------------------------------------------------------------
    def setup_sae(
        self,
        hook_point: str,
        d_model: int,
        expansion_factor: int = 8,
        architecture: str = "standard",
        k: Optional[int] = None,
    ) -> Union[StandardSAE, TopKSAE]:
        """Initializes an SAE (Standard or TopK) for a specific hook point.

        Args:
            hook_point: Key the SAE is registered under (replaces any existing one).
            d_model: Width of the input activations.
            expansion_factor: Dictionary size multiplier; ``d_sae = d_model * expansion_factor``.
            architecture: ``"topk"`` builds a TopKSAE; any other value builds a StandardSAE.
            k: Active latents per input for TopK. Defaults to ``max(1, d_sae // 32)``.
                Ignored for the standard architecture.

        Returns:
            The new SAE, created on the model's device.

        Raises:
            ValueError: If a TopK SAE is requested with ``k`` outside ``[1, d_sae]``.
        """
        d_sae = d_model * expansion_factor
        device = self._model_device()
        if architecture == "topk":
            if k is None:
                # Default sparsity: 1/32 of the dictionary, floored at 1 so that
                # small dictionaries (d_sae < 32) don't end up with k == 0.
                k = max(1, d_sae // 32)
            elif not 1 <= k <= d_sae:
                raise ValueError(f"k must be in [1, {d_sae}] for a TopK SAE, got {k}")
            cfg = TopKSAEConfig(d_in=d_model, d_sae=d_sae, k=k, device=device)
            sae = TopKSAE(cfg)
        else:
            cfg = StandardSAEConfig(d_in=d_model, d_sae=d_sae, device=device)
            sae = StandardSAE(cfg)
        self.saes[hook_point] = sae
        return sae

    def train_on_trajectories(
        self,
        hook_point: str,
        activations: Float[torch.Tensor, "n_samples d_model"],
        l1_coefficient: float = 0.0001,
        batch_size: int = 1024,
        epochs: int = 10,
        lr: float = 0.0004,
    ):
        """Trains the SAE on collected activations.

        If no SAE is registered for ``hook_point``, a standard SAE with default
        settings is created first. Prints the mean per-batch loss each epoch.

        Args:
            hook_point: Which SAE to train.
            activations: Training data, first dim = samples, last dim = d_model.
            l1_coefficient: Weight of the L1 sparsity penalty (standard SAE only).
            batch_size: Minibatch size (the final batch of an epoch may be smaller).
            epochs: Number of shuffled passes over ``activations``.
            lr: Adam learning rate.

        Raises:
            ValueError: If ``activations`` has no samples or ``batch_size`` < 1.
        """
        n_samples = activations.shape[0]
        if n_samples == 0:
            raise ValueError("activations must contain at least one sample.")
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        if hook_point not in self.saes:
            self.setup_sae(hook_point, activations.shape[-1])
        sae = self.saes[hook_point]
        optimizer = torch.optim.Adam(sae.parameters(), lr=lr)
        sae.train()
        is_topk = isinstance(sae, TopKSAE)

        for epoch in range(epochs):
            permutation = torch.randperm(n_samples)
            epoch_loss = 0.0
            n_batches = 0
            for i in range(0, n_samples, batch_size):
                indices = permutation[i:i + batch_size]
                # detach(): only the SAE is trained. Without it, activations that
                # still carry an autograd graph would backprop into the model that
                # produced them (and fail when that graph is reused next batch).
                batch_acts = activations[indices].detach().to(sae.device)
                optimizer.zero_grad()
                feature_acts = sae.encode(batch_acts)
                sae_out = sae.decode(feature_acts)
                mse_loss = torch.nn.functional.mse_loss(sae_out, batch_acts)
                if is_topk:
                    # TopK doesn't use L1; sparsity is enforced by architecture
                    loss = mse_loss
                else:
                    # NOTE: the L1 term is summed over the whole batch (while MSE
                    # is a mean), so its effective weight scales with batch size.
                    # Kept as-is so existing l1_coefficient values keep their meaning.
                    l1_loss = l1_coefficient * feature_acts.abs().sum()
                    loss = mse_loss + l1_loss
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
                n_batches += 1
            # Average over the batches actually run; the last one may be partial.
            print(f"Epoch {epoch+1}/{epochs} - Loss: {epoch_loss / n_batches:.4f}")

    # ------------------------------------------------------------------
    # Inference
    # ------------------------------------------------------------------
    def get_feature_activations(
        self,
        hook_point: str,
        activations: Float[torch.Tensor, "... d_model"]
    ) -> Float[torch.Tensor, "... d_sae"]:
        """Decomposes activations into latent features.

        Switches the SAE to eval mode and encodes without tracking gradients.

        Raises:
            ValueError: If no SAE is registered for ``hook_point``.
        """
        sae = self._get_sae(hook_point)
        sae.eval()
        with torch.no_grad():
            feature_acts = sae.encode(activations.to(sae.device))
        return feature_acts

    def reconstruct(
        self,
        hook_point: str,
        activations: Float[torch.Tensor, "... d_model"]
    ) -> Float[torch.Tensor, "... d_model"]:
        """Reconstructs activations from latents.

        Encodes then decodes in eval mode without tracking gradients.

        Raises:
            ValueError: If no SAE is registered for ``hook_point``.
        """
        sae = self._get_sae(hook_point)
        sae.eval()
        with torch.no_grad():
            feature_acts = sae.encode(activations.to(sae.device))
            sae_out = sae.decode(feature_acts)
        return sae_out

    def compute_anomaly_score(
        self,
        hook_point: str,
        activations: Float[torch.Tensor, "... d_model"],
        eps: float = 1e-8,
    ) -> Float[torch.Tensor, "..."]:
        """
        Reconstruction error for anomaly detection.

        The score is the relative L2 reconstruction error over the last
        dimension, ``||x - x_hat|| / (||x|| + eps)``, so the result has the
        input's leading shape. ``eps`` guards against all-zero inputs.

        Raises:
            ValueError: If no SAE is registered for ``hook_point``.
        """
        sae = self._get_sae(hook_point)
        sae.eval()
        with torch.no_grad():
            x = activations.to(sae.device)
            x_hat = sae.decode(sae.encode(x))
            error = torch.norm(x - x_hat, dim=-1) / (torch.norm(x, dim=-1) + eps)
        return error

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------
    def save_all_saes(self):
        """Save every registered SAE (state dict, cfg and architecture tag) to ``sae_dir``."""
        for hook, sae in self.saes.items():
            path = self._sae_path(hook)
            torch.save(
                {
                    'state_dict': sae.state_dict(),
                    'cfg': sae.cfg,
                    'type': 'topk' if isinstance(sae, TopKSAE) else 'standard',
                },
                path,
            )
            print(f"Saved SAE for {hook} to {path}")

    def load_sae(self, hook_point: str):
        """Load the checkpoint written by ``save_all_saes`` for ``hook_point`` and register it.

        The SAE is rebuilt on the model's current device, not the device that
        was recorded in its config at save time.

        Returns:
            The restored SAE.

        Raises:
            FileNotFoundError: If no checkpoint exists for ``hook_point``.
        """
        path = self._sae_path(hook_point)
        if not os.path.exists(path):
            raise FileNotFoundError(f"No saved SAE found at {path}")
        device = self._model_device()
        # SECURITY: weights_only=False unpickles arbitrary objects (required to
        # restore the stored cfg). Only load checkpoints from trusted sources.
        checkpoint = torch.load(path, map_location=device, weights_only=False)
        cfg = checkpoint['cfg']
        # The saved cfg remembers the training device, which may not exist here
        # (e.g. saved on CUDA, loaded on a CPU-only host); follow the model instead.
        cfg.device = device
        sae_cls = TopKSAE if checkpoint.get('type') == 'topk' else StandardSAE
        sae = sae_cls(cfg)
        sae.load_state_dict(checkpoint['state_dict'])
        self.saes[hook_point] = sae
        return sae
|