# baguettotron-sae-L48-8x-k16-774m
A TopK sparse autoencoder trained on layer 48 of Baguettotron. 4,608 features (8x expansion from d_model=576), k=16, trained on 773M tokens. Ships with 4,602 autointerp labels.
SAE explorer: https://lyramakesmusic.github.io/bread-slicer/
## Architecture

| Parameter | Value |
|---|---|
| d_in | 576 |
| d_sae | 4,608 |
| expansion | 8x |
| k | 16 |
| tied weights | no |
| hook layer | 48 of 80 |
| activation | TopK + ReLU |
| dtype | float32 |
The SAE encodes residual-stream activations at layer 48. Forward pass: `x @ W_enc + b_enc` -> keep the top 16 values (ReLU'd) as `f` -> `f @ W_dec + b_dec`.
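The TopK step can be illustrated on a toy vector (k=3 over 8 dims here for brevity; the real SAE uses k=16 over 4,608 features):

```python
import torch
import torch.nn.functional as F

# Toy TopK activation: keep the k largest pre-activations, zero the rest.
pre_acts = torch.tensor([0.5, -2.0, 3.0, 1.0, -0.1, 2.5, 0.2, 4.0])
k = 3
vals, idx = torch.topk(pre_acts, k)  # k largest values and their indices
vals = F.relu(vals)                  # any negative survivors are zeroed
f = torch.zeros_like(pre_acts).scatter_(-1, idx, vals)
print(f)  # nonzero only at indices 7, 2, 5 (values 4.0, 3.0, 2.5)
```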
## Training

- Data: 773M tokens from SYNTH
- Steps: 94,554
- LR: 5e-5
- Final loss: 105.4 (down from 11,629 at the start of training)
- Dead features: 239/4,608 (5.2%)
- Training time: ~4 hours on an RTX 4090
- Batch size: 8,192 tokens
## Files

| File | Size | Description |
|---|---|---|
| `sae_weights.pt` | 20 MB | SAE weights + config (no optimizer state) |
| `feature_labels.jsonl` | 651 KB | 4,602 autointerp labels with confidence tags |
| `feature_data.json` | 628 KB | UMAP coords + activation stats per feature |
| `train_summary.json` | 1 KB | Training config and final metrics |
## Loading

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TopKSAE(nn.Module):
    def __init__(self, d_in, d_sae, k):
        super().__init__()
        self.d_in, self.d_sae, self.k = d_in, d_sae, k
        self.W_enc = nn.Parameter(torch.empty(d_in, d_sae))
        self.W_dec = nn.Parameter(torch.empty(d_sae, d_in))
        self.b_enc = nn.Parameter(torch.zeros(d_sae))
        self.b_dec = nn.Parameter(torch.zeros(d_in))
        # firing-statistics buffers from training; kept so the state_dict loads cleanly
        self.register_buffer("feature_activations", torch.zeros(d_sae, dtype=torch.long))
        self.register_buffer("steps_since_active", torch.zeros(d_sae, dtype=torch.long))

    def forward(self, x):
        pre_acts = x @ self.W_enc + self.b_enc
        topk_vals, topk_idx = torch.topk(pre_acts, self.k, dim=-1)
        topk_vals = F.relu(topk_vals)
        f = torch.zeros_like(pre_acts)
        f.scatter_(-1, topk_idx, topk_vals)
        x_hat = f @ self.W_dec + self.b_dec
        return x_hat, f, topk_idx, topk_vals

# load
ckpt = torch.load("sae_weights.pt", map_location="cpu", weights_only=False)
cfg = ckpt["config"]
sae = TopKSAE(cfg["d_in"], cfg["d_sae"], cfg["k"])
sae.load_state_dict(ckpt["model_state_dict"])
sae.eval().cuda()
```
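As a quick sanity check after loading, you can verify the documented sparsity and reconstruction behavior. The helper below is a sketch (not part of the checkpoint); random stand-in tensors replace real SAE outputs:

```python
import torch

# Report reconstruction MSE and L0 sparsity for a batch of SAE outputs.
def sae_stats(x, x_hat, f):
    mse = (x - x_hat).pow(2).sum(-1).mean().item()
    l0 = (f != 0).float().sum(-1).mean().item()  # should equal k for a TopK SAE
    return mse, l0

# stand-in tensors with the documented shapes (batch of 4 tokens, d_in=576, d_sae=4608)
x = torch.randn(4, 576)
x_hat = x + 0.1 * torch.randn(4, 576)  # pretend reconstruction
idx = torch.topk(torch.randn(4, 4608), 16).indices
f = torch.zeros(4, 4608).scatter_(-1, idx, 1.0)
mse, l0 = sae_stats(x, x_hat, f)
print(l0)  # 16.0 by construction
```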
## Extracting activations from Baguettotron

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "PleIAs/Baguettotron", torch_dtype=torch.bfloat16,
    device_map="cuda", trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained("PleIAs/Baguettotron")

# hook layer 48 to grab the residual stream
activations = {}
def hook_fn(module, input, output):
    # output is a tuple; the first element is the hidden state
    activations["resid"] = output[0].detach().float()

handle = model.model.layers[48].register_forward_hook(hook_fn)

inputs = tokenizer("The cat sat on the", return_tensors="pt").to("cuda")
with torch.no_grad():
    model(**inputs)
x = activations["resid"]  # (batch, seq, 576)
x_hat, f, topk_idx, topk_vals = sae(x)
handle.remove()
```
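From there, a common next step is to read off the most active features for a token and look up their labels. A minimal sketch with placeholder values (toy k=3 indices and a toy label dict stand in for real SAE outputs and `feature_labels.jsonl`):

```python
import torch

# Placeholder (batch, seq, k) outputs; in practice use topk_idx / topk_vals
# from the SAE forward pass above.
topk_idx = torch.tensor([[[7, 2, 5]]])
topk_vals = torch.tensor([[[4.0, 3.0, 2.5]]])
feature_labels = {7: "toy label A", 2: "toy label B", 5: "toy label C"}

# features firing on the final token, strongest first
last = sorted(zip(topk_idx[0, -1].tolist(), topk_vals[0, -1].tolist()),
              key=lambda t: -t[1])
for fid, val in last:
    print(fid, feature_labels.get(fid, "?"), round(val, 2))
```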
## Autointerp labels

`feature_labels.jsonl` has one JSON object per line:

```json
{"feature": 0, "interp": "", "autointerp": "military command and leadership", "confidence": "confident"}
```

Labels were generated by a finetuned Trinity Nano labeler (mechinterp-v1) and validated against activation data. Confidence is one of `confident`, `tentative`, or `dead`.
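A minimal sketch of filtering the labels by confidence; the sample lines below stand in for the real file contents (normally read with `open("feature_labels.jsonl")`):

```python
import json

# Sample JSONL lines in the documented schema.
sample = "\n".join([
    '{"feature": 0, "interp": "", "autointerp": "military command and leadership", "confidence": "confident"}',
    '{"feature": 1, "interp": "", "autointerp": "", "confidence": "dead"}',
])
labels = [json.loads(line) for line in sample.splitlines()]
confident = {r["feature"]: r["autointerp"]
             for r in labels if r["confidence"] == "confident"}
print(confident)  # {0: 'military command and leadership'}
```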
## Feature data

`feature_data.json` contains per-feature metadata for building explorers or doing analysis:

```json
{
  "features": [
    {"i": 0, "x": 3.08, "y": 1.94, "x3": 3.20, "y3": 1.55, "z3": 1.68,
     "d": 0.000712, "mx": 296.19, "mn": 74.80, "c": 1178}
  ],
  "meta": { ... }
}
```
Fields: `i` = feature index, `x`/`y` = 2D UMAP coords, `x3`/`y3`/`z3` = 3D UMAP coords, `d` = density, `mx` = max activation, `mn` = mean activation, `c` = fire count.
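A small sketch of consuming these records; the single entry below is the example from this README (normally loaded with `json.load(open("feature_data.json"))["features"]`):

```python
# One record in the documented schema.
features = [{"i": 0, "x": 3.08, "y": 1.94, "x3": 3.20, "y3": 1.55, "z3": 1.68,
             "d": 0.000712, "mx": 296.19, "mn": 74.80, "c": 1178}]

# 2D UMAP coordinates, e.g. for a scatter-plot explorer
coords_2d = [(r["x"], r["y"]) for r in features]
# frequently firing features (threshold is an arbitrary illustration)
hot = [r["i"] for r in features if r["c"] > 1000]
print(coords_2d, hot)
```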
Base model: PleIAs/Baguettotron