| """Minimal loader for the TopK SAEs in this repository. |
| |
| Usage: |
| from loader import load_sae |
| sae, cfg = load_sae("layer6/sae_instruct_base_layer6.pt", device="cuda") |
| x_hat, z = sae(x) # x: (N, d_model=896) |
| |
| The `sae(x)` forward returns: |
| x_hat: (N, d_model) reconstruction |
| z_sparse: (N, d_sae) sparse code, exactly `k` non-zeros per row |
| |
| When splicing into the base model's residual stream, only replace |
| real-token positions (see README for the rationale): |
| patched = torch.where(mask.unsqueeze(-1).bool(), sae(h)[0], h) |
| """ |

import torch
import torch.nn as nn


class TopKSAE(nn.Module):
    """TopK sparse autoencoder: keep the k largest pre-activations per row."""

    def __init__(self, d_model: int, d_sae: int, k: int):
        super().__init__()
        self.k = k
        self.d_model = d_model
        self.d_sae = d_sae
        # Pre-encoder bias, subtracted from the input before encoding.
        self.b_pre = nn.Parameter(torch.zeros(d_model))
        self.encoder = nn.Linear(d_model, d_sae, bias=True)
        self.decoder = nn.Linear(d_sae, d_model, bias=True)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        # Dense pre-activations, then keep only the k largest entries per row.
        z = self.encoder(x - self.b_pre)
        topk_values, topk_indices = torch.topk(z, self.k, dim=-1)
        z_sparse = torch.zeros_like(z)
        z_sparse.scatter_(-1, topk_indices, topk_values)
        return z_sparse

    def forward(self, x: torch.Tensor):
        # Returns (x_hat, z_sparse); see the module docstring for shapes.
        z_sparse = self.encode(x)
        return self.decoder(z_sparse), z_sparse
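

# Illustrative sketch, not part of the original API: how the masked splice
# described in the module docstring can be applied to residual-stream
# activations. The names `h` (B, T, d_model) and `mask` (a B, T real-token
# mask) are assumptions for this example.
def splice_reconstruction(sae: TopKSAE, h: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Replace real-token positions of `h` with the SAE reconstruction of `h`."""
    x_hat, _ = sae(h)
    # Keep the original activation wherever the mask is 0 (padding etc.).
    return torch.where(mask.unsqueeze(-1).bool(), x_hat, h)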


def load_sae(path: str, device: str = "cpu"):
    """Load a checkpoint saved by the training pipeline.

    Checkpoint format:
        {"state_dict": ..., "config": {"d_model", "d_sae", "k", "source"}}
    """
    # weights_only=False because the checkpoint stores a plain config dict
    # alongside the tensors.
    ckpt = torch.load(path, map_location=device, weights_only=False)
    cfg = ckpt["config"]
    sae = TopKSAE(cfg["d_model"], cfg["d_sae"], cfg["k"])
    # strict=False tolerates extra keys the training pipeline may have saved
    # alongside the SAE parameters.
    sae.load_state_dict(ckpt["state_dict"], strict=False)
    return sae.to(device).eval(), cfg
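

# A minimal sketch of writing a checkpoint that `load_sae` can read back.
# It follows the format documented in `load_sae`'s docstring; this is an
# assumption, not the actual training pipeline's save code, and the default
# "source" value below is a placeholder.
def save_sae(sae: TopKSAE, path: str, source: str = "unknown") -> None:
    torch.save(
        {
            "state_dict": sae.state_dict(),
            "config": {
                "d_model": sae.d_model,
                "d_sae": sae.d_sae,
                "k": sae.k,
                "source": source,
            },
        },
        path,
    )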
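

if __name__ == "__main__":
    # Quick self-check on random data; no checkpoint required. d_model=896
    # matches the usage example in the module docstring; d_sae and k here are
    # arbitrary illustrative values.
    torch.manual_seed(0)
    sae = TopKSAE(d_model=896, d_sae=8192, k=32)
    x = torch.randn(4, 896)
    x_hat, z_sparse = sae(x)
    assert x_hat.shape == x.shape
    assert z_sparse.shape == (4, 8192)
    # Each row holds at most k non-zeros (exactly k unless a kept value is 0).
    assert int((z_sparse != 0).sum(dim=-1).max()) <= 32
    print("self-check passed:", tuple(x_hat.shape), tuple(z_sparse.shape))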