baguettotron-sae-L48-8x-k16-774m

A TopK sparse autoencoder trained on layer 48 of Baguettotron. 4,608 features (8x expansion from d_model=576), k=16, trained on 773M tokens. Ships with 4,602 autointerp labels.

SAE explorer: https://lyramakesmusic.github.io/bread-slicer/

Architecture

d_in          576
d_sae         4,608
expansion     8x
k             16
tied weights  no
hook          layer 48 of 80
activation    TopK + ReLU
dtype         float32

The SAE encodes residual stream activations at layer 48. Forward pass: x @ W_enc + b_enc -> keep the top 16 pre-activations per token (ReLU'd, all others zeroed) -> f @ W_dec + b_dec.
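The TopK step can be illustrated on a toy tensor (a minimal sketch, independent of the trained weights):

```python
import torch
import torch.nn.functional as F

pre_acts = torch.tensor([3.0, -1.0, 0.5, 2.0, -4.0])
k = 2
vals, idx = torch.topk(pre_acts, k)   # keep the k largest pre-activations
vals = F.relu(vals)                   # negative survivors are zeroed
f = torch.zeros_like(pre_acts)
f.scatter_(-1, idx, vals)             # sparse feature vector
print(f)  # tensor([3., 0., 0., 2., 0.])
```

Only k entries of f are ever nonzero, which is what makes the feature vector sparse by construction.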

Training

  • Data: 773M tokens from SYNTH
  • Steps: 94,554
  • LR: 5e-5
  • Final loss: 105.4 (down from 11,629 at the start)
  • Dead features: 239/4,608 (5.2%)
  • Training time: ~4 hours on RTX 4090
  • Batch size: 8,192 tokens
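No training code ships with the checkpoint; the following is a minimal single-step sketch of how a TopK SAE is typically trained. Only the LR and batch size come from the run above; the MSE reconstruction loss, Adam optimizer, and init scale are assumptions:

```python
import torch
import torch.nn.functional as F

d_in, d_sae, k, lr = 576, 4608, 16, 5e-5
W_enc = torch.empty(d_in, d_sae).uniform_(-0.01, 0.01).requires_grad_()
W_dec = torch.empty(d_sae, d_in).uniform_(-0.01, 0.01).requires_grad_()
b_enc = torch.zeros(d_sae, requires_grad=True)
b_dec = torch.zeros(d_in, requires_grad=True)
opt = torch.optim.Adam([W_enc, W_dec, b_enc, b_dec], lr=lr)

x = torch.randn(32, d_in)  # stand-in batch (the actual run used 8,192-token batches)
pre = x @ W_enc + b_enc
vals, idx = torch.topk(pre, k, dim=-1)
f = torch.zeros_like(pre).scatter_(-1, idx, F.relu(vals))
x_hat = f @ W_dec + b_dec
loss = F.mse_loss(x_hat, x, reduction="sum") / x.shape[0]  # per-token summed MSE (assumed normalization)
opt.zero_grad()
loss.backward()
opt.step()
```

TopK needs no sparsity penalty in the loss: sparsity is enforced structurally by the top-k selection itself.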

Files

File                  Size    Description
sae_weights.pt        20 MB   SAE weights + config (no optimizer state)
feature_labels.jsonl  651 KB  4,602 autointerp labels with confidence tags
feature_data.json     628 KB  UMAP coords + activation stats per feature
train_summary.json    1 KB    Training config and final metrics

Loading

import torch
import torch.nn as nn
import torch.nn.functional as F

class TopKSAE(nn.Module):
    def __init__(self, d_in, d_sae, k):
        super().__init__()
        self.d_in, self.d_sae, self.k = d_in, d_sae, k
        self.W_enc = nn.Parameter(torch.empty(d_in, d_sae))
        self.W_dec = nn.Parameter(torch.empty(d_sae, d_in))
        self.b_enc = nn.Parameter(torch.zeros(d_sae))
        self.b_dec = nn.Parameter(torch.zeros(d_in))
        # training-time bookkeeping buffers; kept so load_state_dict matches the checkpoint
        self.register_buffer("feature_activations", torch.zeros(d_sae, dtype=torch.long))
        self.register_buffer("steps_since_active", torch.zeros(d_sae, dtype=torch.long))

    def forward(self, x):
        pre_acts = x @ self.W_enc + self.b_enc
        topk_vals, topk_idx = torch.topk(pre_acts, self.k, dim=-1)
        topk_vals = F.relu(topk_vals)
        f = torch.zeros_like(pre_acts)
        f.scatter_(-1, topk_idx, topk_vals)
        x_hat = f @ self.W_dec + self.b_dec
        return x_hat, f, topk_idx, topk_vals

# load (weights_only=False because the checkpoint bundles a plain config dict alongside the weights)
ckpt = torch.load("sae_weights.pt", map_location="cpu", weights_only=False)
cfg = ckpt["config"]
sae = TopKSAE(cfg["d_in"], cfg["d_sae"], cfg["k"])
sae.load_state_dict(ckpt["model_state_dict"])
sae.eval().cuda()

Extracting activations from Baguettotron

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "PleIAs/Baguettotron", torch_dtype=torch.bfloat16,
    device_map="cuda", trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained("PleIAs/Baguettotron")

# hook layer 48 to grab residual stream
activations = {}
def hook_fn(module, input, output):
    # decoder layers may return a tuple; the hidden state comes first
    out = output[0] if isinstance(output, tuple) else output
    activations["resid"] = out.detach().float()

handle = model.model.layers[48].register_forward_hook(hook_fn)

inputs = tokenizer("The cat sat on the", return_tensors="pt").to("cuda")
with torch.no_grad():
    model(**inputs)

x = activations["resid"]  # (batch, seq, 576)
x_hat, f, topk_idx, topk_vals = sae(x)

handle.remove()
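With x and x_hat in hand, reconstruction quality is commonly summarized as the fraction of variance unexplained (FVU). A self-contained sketch with random stand-ins for the tensors above (shapes match the layer-48 activations):

```python
import torch

# stand-ins for the tensors produced above
x = torch.randn(1, 6, 576)
x_hat = x + 0.1 * torch.randn_like(x)   # pretend reconstruction

# FVU: residual variance over total variance; 0 = perfect, 1 = no better than the mean
resid_var = (x - x_hat).pow(2).sum()
total_var = (x - x.mean(dim=(0, 1))).pow(2).sum()
fvu = (resid_var / total_var).item()
print(f"FVU: {fvu:.3f}")
```

Run the same computation on real activations to sanity-check the checkpoint after loading.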

Autointerp labels

feature_labels.jsonl has one JSON object per line:

{"feature": 0, "interp": "", "autointerp": "military command and leadership", "confidence": "confident"}

Labels were generated by a finetuned Trinity Nano labeler (mechinterp-v1) and validated against activation data. Confidence is one of confident, tentative, or dead.
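A sketch for loading the labels into a dict keyed by feature index (the snippet writes one sample line first so it is self-contained; in practice point it at the downloaded feature_labels.jsonl):

```python
import json

sample = ('{"feature": 0, "interp": "", '
          '"autointerp": "military command and leadership", '
          '"confidence": "confident"}\n')
with open("feature_labels.jsonl", "w") as fh:
    fh.write(sample)

# one JSON object per line -> dict keyed by feature index
labels = {}
with open("feature_labels.jsonl") as fh:
    for line in fh:
        rec = json.loads(line)
        labels[rec["feature"]] = rec

print(labels[0]["autointerp"])  # military command and leadership
```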

Feature data

feature_data.json contains per-feature metadata for building explorers or doing analysis:

{
  "features": [
    {"i": 0, "x": 3.08, "y": 1.94, "x3": 3.20, "y3": 1.55, "z3": 1.68,
     "d": 0.000712, "mx": 296.19, "mn": 74.80, "c": 1178}
  ],
  "meta": { ... }
}

Fields: i = feature index, x/y = 2D UMAP coords, x3/y3/z3 = 3D UMAP coords, d = density, mx = max activation, mn = mean activation, c = fire count.
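A typical analysis pass filters out dead features and ranks the rest. A sketch using the sample record from above (in practice, json.load the downloaded feature_data.json instead):

```python
import json

# sample record from this README; replace with json.load(open("feature_data.json"))
data = {
    "features": [
        {"i": 0, "x": 3.08, "y": 1.94, "x3": 3.20, "y3": 1.55, "z3": 1.68,
         "d": 0.000712, "mx": 296.19, "mn": 74.80, "c": 1178}
    ]
}

# keep features that fired at least once, sorted by max activation
alive = sorted((ft for ft in data["features"] if ft["c"] > 0),
               key=lambda ft: ft["mx"], reverse=True)
print(alive[0]["i"], alive[0]["mx"])
```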
