"""Minimal loader for the TopK SAEs in this repository.

Usage:
    from loader import load_sae
    sae, cfg = load_sae("layer6/sae_instruct_base_layer6.pt", device="cuda")
    x_hat, z = sae(x)  # x: (N, d_model=896)

The `sae(x)` forward returns:
    x_hat:    (N, d_model) reconstruction
    z_sparse: (N, d_sae) sparse code, exactly `k` non-zeros per row

When splicing into the base model's residual stream, only replace
real-token positions (see README for the rationale):
    patched = torch.where(mask.unsqueeze(-1).bool(), sae(h)[0], h)
"""

import torch
import torch.nn as nn


class TopKSAE(nn.Module):
    """Sparse autoencoder whose code keeps only the top-k pre-activations.

    Args:
        d_model: dimensionality of the input / reconstruction space.
        d_sae: number of dictionary features (width of the sparse code).
        k: number of features kept per row; all others are zeroed.
    """

    def __init__(self, d_model: int, d_sae: int, k: int):
        super().__init__()
        self.k = k
        self.d_model = d_model
        self.d_sae = d_sae
        # Learned pre-encoder bias, subtracted from the input before encoding.
        # The decoder has its own bias, so b_pre is NOT re-added after decode.
        self.b_pre = nn.Parameter(torch.zeros(d_model))
        self.encoder = nn.Linear(d_model, d_sae, bias=True)
        self.decoder = nn.Linear(d_sae, d_model, bias=True)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Return the sparse code for `x`: shape (..., d_sae), at most `k`
        non-zero entries per row (exactly `k` unless a kept value is 0)."""
        z = self.encoder(x - self.b_pre)
        # Keep the k largest pre-activations per row; zero everything else.
        topk_values, topk_indices = torch.topk(z, self.k, dim=-1)
        z_sparse = torch.zeros_like(z)
        z_sparse.scatter_(-1, topk_indices, topk_values)
        return z_sparse

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return (x_hat, z_sparse): the reconstruction and the sparse code."""
        z_sparse = self.encode(x)
        return self.decoder(z_sparse), z_sparse


def load_sae(path: str, device: str = "cpu") -> tuple[TopKSAE, dict]:
    """Load a checkpoint saved by the training pipeline.

    Checkpoint format:
        {"state_dict": ..., "config": {"d_model", "d_sae", "k", "source"}}

    Args:
        path: filesystem path to the `.pt` checkpoint.
        device: target device for the returned module ("cpu", "cuda", ...).

    Returns:
        (sae, cfg): the SAE in eval mode on `device`, and the config dict.

    Raises:
        KeyError: if the checkpoint lacks "config"/"state_dict" entries.
        RuntimeError: if the state dict does not match the model exactly.
    """
    # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects.
    # Only load checkpoints from trusted sources.
    ckpt = torch.load(path, map_location=device, weights_only=False)
    cfg = ckpt["config"]
    sae = TopKSAE(cfg["d_model"], cfg["d_sae"], cfg["k"])
    # strict=True: fail loudly on missing/unexpected keys. The previous
    # strict=False would silently leave randomly-initialized parameters in
    # place when a key was absent, producing a subtly broken SAE.
    sae.load_state_dict(ckpt["state_dict"], strict=True)
    return sae.to(device).eval(), cfg