"""Minimal loader for the TopK SAEs in this repository.
Usage:
from loader import load_sae
sae, cfg = load_sae("layer6/sae_instruct_base_layer6.pt", device="cuda")
x_hat, z = sae(x) # x: (N, d_model=896)
The `sae(x)` forward returns:
x_hat: (N, d_model) reconstruction
z_sparse: (N, d_sae) sparse code, exactly `k` non-zeros per row
When splicing into the base model's residual stream, only replace
real-token positions (see README for the rationale):
patched = torch.where(mask.unsqueeze(-1).bool(), sae(h)[0], h)
"""
import torch
import torch.nn as nn


class TopKSAE(nn.Module):
    """Vanilla TopK sparse autoencoder: d_model -> d_sae -> d_model."""

    def __init__(self, d_model: int, d_sae: int, k: int):
        super().__init__()
        self.k = k
        self.d_model = d_model
        self.d_sae = d_sae
        self.b_pre = nn.Parameter(torch.zeros(d_model))
        self.encoder = nn.Linear(d_model, d_sae, bias=True)
        self.decoder = nn.Linear(d_sae, d_model, bias=True)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        # Subtract the pre-encoder bias, project up to d_sae, then keep only
        # the k largest activations per row; all other entries are zeroed.
        z = self.encoder(x - self.b_pre)
        topk_values, topk_indices = torch.topk(z, self.k, dim=-1)
        z_sparse = torch.zeros_like(z)
        z_sparse.scatter_(-1, topk_indices, topk_values)
        return z_sparse

    def forward(self, x: torch.Tensor):
        z_sparse = self.encode(x)
        return self.decoder(z_sparse), z_sparse
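

# ---------------------------------------------------------------------------
# Hedged sketch: a convenience wrapper for the residual-stream splice shown
# in the module docstring. `splice_sae` is a hypothetical helper (not part of
# the training pipeline); it just packages the torch.where pattern, assuming
# `h` is (N, T, d_model) and `mask` is (N, T) with nonzero entries at
# real-token positions.
# ---------------------------------------------------------------------------
@torch.no_grad()
def splice_sae(sae: TopKSAE, h: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Replace real-token positions of `h` with their SAE reconstructions."""
    x_hat, _ = sae(h)  # Linear/topk act on the last dim, so (N, T, d) is fine.
    return torch.where(mask.unsqueeze(-1).bool(), x_hat, h)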


def load_sae(path: str, device: str = "cpu"):
    """Load a checkpoint saved by the training pipeline.

    Checkpoint format:
        {"state_dict": ..., "config": {"d_model", "d_sae", "k", "source"}}
    """
    # weights_only=False lets torch.load unpickle the plain-Python config
    # dict alongside the tensors; only use it on checkpoints you trust.
    ckpt = torch.load(path, map_location=device, weights_only=False)
    cfg = ckpt["config"]
    sae = TopKSAE(cfg["d_model"], cfg["d_sae"], cfg["k"])
    # strict=False ignores missing/unexpected keys instead of raising;
    # parameters whose keys do match must still have matching shapes.
    sae.load_state_dict(ckpt["state_dict"], strict=False)
    return sae.to(device).eval(), cfg
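

# ---------------------------------------------------------------------------
# Minimal smoke test: a sketch that sanity-checks the forward pass on random
# data without needing a checkpoint on disk. d_model=896 matches the module
# docstring; the d_sae and k values below are illustrative, not the ones used
# by the released checkpoints.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    sae = TopKSAE(d_model=896, d_sae=16 * 896, k=32).eval()
    x = torch.randn(8, 896)
    with torch.no_grad():
        x_hat, z = sae(x)
    assert x_hat.shape == x.shape
    # At most k non-zeros per row (exactly k unless a top-k value is 0.0).
    assert (z != 0).sum(dim=-1).le(sae.k).all()
    print("ok:", tuple(x_hat.shape), tuple(z.shape))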