| """Qwen-Scope SAE feature reading + steering for transformers. | |
| End-to-end demo: | |
| 1. Loads a base Qwen3 model and a matching Qwen-Scope TopK SAE checkpoint. | |
| 2. Captures the residual-stream output of a chosen decoder layer. | |
| 3. Encodes it through the SAE -> top-K firing features. | |
| 4. Generates a baseline completion. | |
| 5. Re-generates with feature steering: residual h <- h + alpha * W_dec[:, feat] | |
| applied via register_forward_hook on every forward pass. | |
| Verified against: | |
| * Qwen/SAE-Res-Qwen3-1.7B-Base-W32K-L0_50 (W_enc 32768x2048, W_dec 2048x32768, | |
| b_enc 32768, b_dec 2048, all float32, K=50) | |
| * Qwen/Qwen3-1.7B-Base (28 Qwen3DecoderLayer, hidden_size=2048, layer forward | |
| returns bare torch.Tensor under transformers >= 5). | |
| """ | |
from __future__ import annotations

import argparse
import contextlib
from dataclasses import dataclass
from pathlib import Path

import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from transformers import AutoModelForCausalLM, AutoTokenizer
# ---------------------------------------------------------------------------
# SAE
# ---------------------------------------------------------------------------
@dataclass
class SAE:
    W_enc: torch.Tensor  # (n_features, d_model)
    W_dec: torch.Tensor  # (d_model, n_features)
    b_enc: torch.Tensor  # (n_features,)
    b_dec: torch.Tensor  # (d_model,)
    k: int               # TopK
    layer: int           # layer index this SAE belongs to
    @classmethod
    def from_repo(cls, repo: str, layer: int, k: int, device: str = "cpu",
                  dtype: torch.dtype = torch.float32) -> "SAE":
        path = hf_hub_download(repo, f"layer{layer}.sae.pt")
        return cls.from_path(path, layer=layer, k=k, device=device, dtype=dtype)
    @classmethod
    def from_path(cls, path: str | Path, layer: int, k: int,
                  device: str = "cpu", dtype: torch.dtype = torch.float32) -> "SAE":
        sd = torch.load(str(path), map_location=device, weights_only=True)
        for key in ("W_enc", "W_dec", "b_enc", "b_dec"):
            if key not in sd:
                raise KeyError(f"SAE checkpoint at {path} missing key {key!r}; "
                               f"got {list(sd.keys())}")
        return cls(
            W_enc=sd["W_enc"].to(device=device, dtype=dtype),
            W_dec=sd["W_dec"].to(device=device, dtype=dtype),
            b_enc=sd["b_enc"].to(device=device, dtype=dtype),
            b_dec=sd["b_dec"].to(device=device, dtype=dtype),
            k=k, layer=layer,
        )
    @property
    def n_features(self) -> int:
        return self.W_enc.shape[0]

    @property
    def d_model(self) -> int:
        return self.W_enc.shape[1]
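    # Shape walk-through for encode() below, using the verified checkpoint:
    # x (1, 2048) -> pre (1, 32768) via W_enc; topk keeps the 50 largest
    # pre-activations and scatter_ writes them into a zero tensor, so z has
    # at most K = 50 nonzero entries.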
    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode residual stream activations -> sparse feature codes (TopK)."""
        x = x.to(device=self.W_enc.device, dtype=self.W_enc.dtype)
        pre = F.linear(x, self.W_enc, self.b_enc)  # (..., n_features)
        topk_vals, topk_idx = pre.topk(self.k, dim=-1)
        z = torch.zeros_like(pre)
        z.scatter_(-1, topk_idx, topk_vals)
        return z
    def decode(self, z: torch.Tensor) -> torch.Tensor:
        z = z.to(device=self.W_dec.device, dtype=self.W_dec.dtype)
        return F.linear(z, self.W_dec, self.b_dec)
    def steering_vector(self, feature_id: int) -> torch.Tensor:
        """Decoder column for `feature_id`: the direction it writes into the residual stream."""
        return self.W_dec[:, feature_id].clone()
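    # Assumption: steering uses the raw decoder column. If decoder column norms
    # vary widely across features, comparable alphas may require normalizing the
    # direction first (d / d.norm()); this demo does not.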
# ---------------------------------------------------------------------------
# Hook helpers
# ---------------------------------------------------------------------------
def _layer_output_to_tensor(out):
    """Qwen3DecoderLayer returns torch.Tensor in transformers >= 5,
    a tuple (hidden_states, ...) in transformers < 5. Handle both."""
    if isinstance(out, tuple):
        return out[0], out
    return out, None


def _rebuild_layer_output(new_h: torch.Tensor, original_out):
    if original_out is None:
        return new_h
    return (new_h, *original_out[1:])
@contextlib.contextmanager
def capture_residual(model, layer_idx: int):
    """Capture the residual-stream output of model.model.layers[layer_idx]."""
    bucket: dict = {}
    layer = model.model.layers[layer_idx]

    def hook(_module, _inp, out):
        h, _ = _layer_output_to_tensor(out)
        bucket["h"] = h.detach()
        return out

    handle = layer.register_forward_hook(hook)
    try:
        yield bucket
    finally:
        handle.remove()
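# Usage sketch (mirrors read_top_features below):
#     with torch.no_grad(), capture_residual(model, layer_idx) as bucket:
#         model(**inputs)
#     h = bucket["h"]  # (batch, seq, hidden) residual after the hooked layer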
@contextlib.contextmanager
def steer(model, layer_idx: int, direction: torch.Tensor, alpha: float):
    """Add `alpha * direction` to the residual stream output of layer_idx
    on every forward pass while the context is active."""
    layer = model.model.layers[layer_idx]
    direction = direction.detach()

    def hook(_module, _inp, out):
        h, original = _layer_output_to_tensor(out)
        d = direction.to(device=h.device, dtype=h.dtype)
        new_h = h + alpha * d
        return _rebuild_layer_output(new_h, original)

    handle = layer.register_forward_hook(hook)
    try:
        yield
    finally:
        handle.remove()
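# Usage sketch (alpha < 0 suppresses the feature, alpha > 0 amplifies it):
#     direction = sae.steering_vector(feature_id)
#     with steer(model, layer_idx, direction, alpha=-10.0):
#         model.generate(**inputs)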
# ---------------------------------------------------------------------------
# Pipeline
# ---------------------------------------------------------------------------
def read_top_features(model, tokenizer, sae: SAE, prompt: str,
                      layer_idx: int, top_n: int = 10):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad(), capture_residual(model, layer_idx) as bucket:
        model(**inputs)
    h = bucket["h"]                 # (1, T, d_model) on model.device
    h_last = h[0, -1].unsqueeze(0)  # (1, d_model) — encode() handles device/dtype
    z = sae.encode(h_last)[0]
    nonzero = z.nonzero(as_tuple=False).flatten()
    vals = z[nonzero]
    order = vals.argsort(descending=True)
    top = nonzero[order][:top_n]
    return [(int(f.item()), float(z[f].item())) for f in top]
def generate(model, tokenizer, prompt: str, max_new_tokens: int = 40):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # deterministic for A/B comparison
            pad_token_id=tokenizer.eos_token_id,
        )
    return tokenizer.decode(out[0], skip_special_tokens=True)
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def parse_topk_from_repo(repo: str) -> int:
    # e.g. "Qwen/SAE-Res-Qwen3-1.7B-Base-W32K-L0_50" -> 50
    parts = repo.rsplit("L0_", 1)
    if len(parts) == 2 and parts[1].isdigit():
        return int(parts[1])
    return 50
def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", default="Qwen/Qwen3-1.7B-Base")
    ap.add_argument("--sae-repo", default="Qwen/SAE-Res-Qwen3-1.7B-Base-W32K-L0_50")
    ap.add_argument("--layer", type=int, default=14)
    ap.add_argument("--prompt", default="The capital of France is")
    ap.add_argument("--max-new-tokens", type=int, default=40)
    ap.add_argument("--alpha", type=float, default=-10.0,
                    help="Steering magnitude. Negative suppresses, positive amplifies.")
    ap.add_argument("--suppress-rank", type=int, default=0,
                    help="Which top-firing feature (0 = strongest) to steer.")
    ap.add_argument("--feature-id", type=int, default=None,
                    help="Override: steer this exact feature instead of a top-rank pick.")
    ap.add_argument("--topk", type=int, default=None,
                    help="Override SAE TopK (auto-detected from repo name).")
    ap.add_argument("--device", default=None,
                    help="cuda | mps | cpu (auto if omitted)")
    ap.add_argument("--dtype", default="bfloat16",
                    choices=["bfloat16", "float16", "float32"])
    args = ap.parse_args()

    if args.device is None:
        if torch.cuda.is_available():
            args.device = "cuda"
        elif torch.backends.mps.is_available():
            args.device = "mps"
        else:
            args.device = "cpu"
    dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16,
             "float32": torch.float32}[args.dtype]
| print(f"[load] model={args.model} device={args.device} dtype={args.dtype}") | |
| tokenizer = AutoTokenizer.from_pretrained(args.model) | |
| model = AutoModelForCausalLM.from_pretrained( | |
| args.model, dtype=dtype, device_map=args.device, | |
| ) | |
| model.eval() | |
| n_layers = len(model.model.layers) | |
| if not (0 <= args.layer < n_layers): | |
| raise ValueError(f"--layer {args.layer} out of range; model has {n_layers} layers") | |
| hidden = model.config.hidden_size | |
| print(f"[load] {type(model).__name__}: {n_layers} layers, hidden={hidden}") | |
| k = args.topk or parse_topk_from_repo(args.sae_repo) | |
| print(f"[load] SAE repo={args.sae_repo} layer={args.layer} K={k}") | |
| sae = SAE.from_repo(args.sae_repo, layer=args.layer, k=k, | |
| device=args.device, dtype=dtype) | |
| if sae.d_model != hidden: | |
| raise ValueError(f"SAE d_model={sae.d_model} != model hidden_size={hidden}; " | |
| f"this SAE doesn't match this model.") | |
    # 1. Top features for the prompt
    print(f"\n[features] top firing at layer {args.layer} for prompt: {args.prompt!r}")
    top = read_top_features(model, tokenizer, sae, args.prompt, args.layer, top_n=10)
    for rank, (fid, act) in enumerate(top):
        print(f"  rank {rank:2d} feature {fid:>6d} act={act:+.4f}")

    # Pick the steering target
    if args.feature_id is not None:
        target_id = args.feature_id
    else:
        target_id = top[args.suppress_rank][0]
    # 2. Baseline generation
    print("\n[baseline] generating (no steering)...")
    baseline = generate(model, tokenizer, args.prompt, args.max_new_tokens)
    print(f"  >>> {baseline!r}")

    # 3. Steered generation
    print(f"\n[steer] feature {target_id} at layer {args.layer} with alpha={args.alpha}")
    direction = sae.steering_vector(target_id)
    with steer(model, args.layer, direction, args.alpha):
        steered = generate(model, tokenizer, args.prompt, args.max_new_tokens)
    print(f"  >>> {steered!r}")
    # 4. Verify the steering actually moved the feature
    inputs = tokenizer(args.prompt, return_tensors="pt").to(model.device)
    with torch.no_grad(), capture_residual(model, args.layer) as bucket:
        model(**inputs)
    base_act = sae.encode(bucket["h"][0, -1].unsqueeze(0))[0, target_id].item()
    with torch.no_grad(), steer(model, args.layer, direction, args.alpha), \
         capture_residual(model, args.layer) as bucket:
        model(**inputs)
    steered_act = sae.encode(bucket["h"][0, -1].unsqueeze(0))[0, target_id].item()
    print(f"\n[verify] feature {target_id} activation: baseline={base_act:+.4f} "
          f"steered={steered_act:+.4f} delta={steered_act - base_act:+.4f}")
    if args.alpha > 0 and steered_act <= base_act:
        print("  WARN: alpha>0 but activation didn't go up — unexpected.")
    if args.alpha < 0 and steered_act >= base_act:
        print("  WARN: alpha<0 but activation didn't go down — unexpected.")
if __name__ == "__main__":
    main()