"""Qwen-Scope SAE feature reading + steering for transformers. End-to-end demo: 1. Loads a base Qwen3 model and a matching Qwen-Scope TopK SAE checkpoint. 2. Captures the residual-stream output of a chosen decoder layer. 3. Encodes it through the SAE -> top-K firing features. 4. Generates a baseline completion. 5. Re-generates with feature steering: residual h <- h + alpha * W_dec[:, feat] applied via register_forward_hook on every forward pass. Verified against: * Qwen/SAE-Res-Qwen3-1.7B-Base-W32K-L0_50 (W_enc 32768x2048, W_dec 2048x32768, b_enc 32768, b_dec 2048, all float32, K=50) * Qwen/Qwen3-1.7B-Base (28 Qwen3DecoderLayer, hidden_size=2048, layer forward returns bare torch.Tensor under transformers >= 5). """ from __future__ import annotations import argparse import contextlib from dataclasses import dataclass from pathlib import Path import torch import torch.nn.functional as F from huggingface_hub import hf_hub_download from transformers import AutoModelForCausalLM, AutoTokenizer # --------------------------------------------------------------------------- # SAE # --------------------------------------------------------------------------- @dataclass class SAE: W_enc: torch.Tensor # (n_features, d_model) W_dec: torch.Tensor # (d_model, n_features) b_enc: torch.Tensor # (n_features,) b_dec: torch.Tensor # (d_model,) k: int # TopK layer: int # layer index this SAE belongs to @classmethod def from_repo(cls, repo: str, layer: int, k: int, device: str = "cpu", dtype: torch.dtype = torch.float32) -> "SAE": path = hf_hub_download(repo, f"layer{layer}.sae.pt") return cls.from_path(path, layer=layer, k=k, device=device, dtype=dtype) @classmethod def from_path(cls, path: str | Path, layer: int, k: int, device: str = "cpu", dtype: torch.dtype = torch.float32) -> "SAE": sd = torch.load(str(path), map_location=device, weights_only=True) for key in ("W_enc", "W_dec", "b_enc", "b_dec"): if key not in sd: raise KeyError(f"SAE checkpoint at {path} missing key {key!r}; " f"got {list(sd.keys())}") return cls( W_enc=sd["W_enc"].to(device=device, dtype=dtype), W_dec=sd["W_dec"].to(device=device, dtype=dtype), b_enc=sd["b_enc"].to(device=device, dtype=dtype), b_dec=sd["b_dec"].to(device=device, dtype=dtype), k=k, layer=layer, ) @property def n_features(self) -> int: return self.W_enc.shape[0] @property def d_model(self) -> int: return self.W_enc.shape[1] def encode(self, x: torch.Tensor) -> torch.Tensor: """Encode residual stream activations -> sparse feature codes (TopK).""" x = x.to(device=self.W_enc.device, dtype=self.W_enc.dtype) pre = F.linear(x, self.W_enc, self.b_enc) # (..., n_features) topk_vals, topk_idx = pre.topk(self.k, dim=-1) z = torch.zeros_like(pre) z.scatter_(-1, topk_idx, topk_vals) return z def decode(self, z: torch.Tensor) -> torch.Tensor: z = z.to(device=self.W_dec.device, dtype=self.W_dec.dtype) return F.linear(z, self.W_dec, self.b_dec) def steering_vector(self, feature_id: int) -> torch.Tensor: return self.W_dec[:, feature_id].clone() # --------------------------------------------------------------------------- # Hook helpers # --------------------------------------------------------------------------- def _layer_output_to_tensor(out): """Qwen3DecoderLayer returns torch.Tensor in transformers >= 5, a tuple (hidden_states, ...) in transformers < 5. 
# ---------------------------------------------------------------------------
# Hook helpers
# ---------------------------------------------------------------------------


def _layer_output_to_tensor(out):
    """Qwen3DecoderLayer returns a bare torch.Tensor in transformers >= 5 and a
    tuple (hidden_states, ...) in transformers < 5. Handle both."""
    if isinstance(out, tuple):
        return out[0], out
    return out, None


def _rebuild_layer_output(new_h: torch.Tensor, original_out):
    if original_out is None:
        return new_h
    return (new_h, *original_out[1:])


@contextlib.contextmanager
def capture_residual(model, layer_idx: int):
    """Capture the residual-stream output of model.model.layers[layer_idx]."""
    bucket: dict = {}
    layer = model.model.layers[layer_idx]

    def hook(_module, _inp, out):
        h, _ = _layer_output_to_tensor(out)
        bucket["h"] = h.detach()
        return out

    handle = layer.register_forward_hook(hook)
    try:
        yield bucket
    finally:
        handle.remove()


@contextlib.contextmanager
def steer(model, layer_idx: int, direction: torch.Tensor, alpha: float):
    """Add `alpha * direction` to the residual stream output of layer_idx
    on every forward pass while the context is active."""
    layer = model.model.layers[layer_idx]
    direction = direction.detach()

    def hook(_module, _inp, out):
        h, original = _layer_output_to_tensor(out)
        d = direction.to(device=h.device, dtype=h.dtype)
        new_h = h + alpha * d
        return _rebuild_layer_output(new_h, original)

    handle = layer.register_forward_hook(hook)
    try:
        yield
    finally:
        handle.remove()
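# Illustrative combined use of the two context managers (a sketch mirroring
# the verify step in main() below; feature id 123 is a placeholder). Entering
# steer() first registers its hook first, so capture_residual() records the
# already-steered output:
#
#   vec = sae.steering_vector(123)
#   with steer(model, layer_idx=14, direction=vec, alpha=-10.0), \
#           capture_residual(model, 14) as bucket:
#       model(**inputs)
#   steered_h = bucket["h"]   # residual stream with the feature suppressed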
# ---------------------------------------------------------------------------
# Pipeline
# ---------------------------------------------------------------------------


def read_top_features(model, tokenizer, sae: SAE, prompt: str, layer_idx: int,
                      top_n: int = 10):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad(), capture_residual(model, layer_idx) as bucket:
        model(**inputs)
    h = bucket["h"]                  # (1, T, d_model) on model.device
    h_last = h[0, -1].unsqueeze(0)   # (1, d_model); encode() handles device/dtype
    z = sae.encode(h_last)[0]
    nonzero = z.nonzero(as_tuple=False).flatten()
    vals = z[nonzero]
    order = vals.argsort(descending=True)
    top = nonzero[order][:top_n]
    return [(int(f.item()), float(z[f].item())) for f in top]


def generate(model, tokenizer, prompt: str, max_new_tokens: int = 40):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # deterministic for A/B comparison
            pad_token_id=tokenizer.eos_token_id,
        )
    return tokenizer.decode(out[0], skip_special_tokens=True)


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------


def parse_topk_from_repo(repo: str) -> int:
    # e.g. "Qwen/SAE-Res-Qwen3-1.7B-Base-W32K-L0_50" -> 50
    parts = repo.rsplit("L0_", 1)
    if len(parts) == 2 and parts[1].isdigit():
        return int(parts[1])
    return 50


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", default="Qwen/Qwen3-1.7B-Base")
    ap.add_argument("--sae-repo", default="Qwen/SAE-Res-Qwen3-1.7B-Base-W32K-L0_50")
    ap.add_argument("--layer", type=int, default=14)
    ap.add_argument("--prompt", default="The capital of France is")
    ap.add_argument("--max-new-tokens", type=int, default=40)
    ap.add_argument("--alpha", type=float, default=-10.0,
                    help="Steering magnitude. Negative suppresses, positive amplifies.")
    ap.add_argument("--suppress-rank", type=int, default=0,
                    help="Which top-firing feature (0 = strongest) to steer.")
    ap.add_argument("--feature-id", type=int, default=None,
                    help="Override: steer this exact feature instead of a top-rank pick.")
    ap.add_argument("--topk", type=int, default=None,
                    help="Override SAE TopK (auto-detected from repo name).")
    ap.add_argument("--device", default=None, help="cuda | mps | cpu (auto if omitted)")
    ap.add_argument("--dtype", default="bfloat16",
                    choices=["bfloat16", "float16", "float32"])
    args = ap.parse_args()

    if args.device is None:
        if torch.cuda.is_available():
            args.device = "cuda"
        elif torch.backends.mps.is_available():
            args.device = "mps"
        else:
            args.device = "cpu"
    dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16,
             "float32": torch.float32}[args.dtype]

    print(f"[load] model={args.model} device={args.device} dtype={args.dtype}")
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        dtype=dtype,
        device_map=args.device,
    )
    model.eval()

    n_layers = len(model.model.layers)
    if not (0 <= args.layer < n_layers):
        raise ValueError(f"--layer {args.layer} out of range; model has {n_layers} layers")
    hidden = model.config.hidden_size
    print(f"[load] {type(model).__name__}: {n_layers} layers, hidden={hidden}")

    k = args.topk or parse_topk_from_repo(args.sae_repo)
    print(f"[load] SAE repo={args.sae_repo} layer={args.layer} K={k}")
    sae = SAE.from_repo(args.sae_repo, layer=args.layer, k=k,
                        device=args.device, dtype=dtype)
    if sae.d_model != hidden:
        raise ValueError(f"SAE d_model={sae.d_model} != model hidden_size={hidden}; "
                         f"this SAE doesn't match this model.")

    # 1. Top features for the prompt
    print(f"\n[features] top firing at layer {args.layer} for prompt: {args.prompt!r}")
    top = read_top_features(model, tokenizer, sae, args.prompt, args.layer, top_n=10)
    for rank, (fid, act) in enumerate(top):
        print(f"  rank {rank:2d}  feature {fid:>6d}  act={act:+.4f}")

    # Pick steering target
    if args.feature_id is not None:
        target_id = args.feature_id
    else:
        target_id = top[args.suppress_rank][0]

    # 2. Baseline generation
    print("\n[baseline] generating (no steering)...")
    baseline = generate(model, tokenizer, args.prompt, args.max_new_tokens)
    print(f"  >>> {baseline!r}")

    # 3. Steered generation
    print(f"\n[steer] feature {target_id} at layer {args.layer} with alpha={args.alpha}")
    direction = sae.steering_vector(target_id)
    with steer(model, args.layer, direction, args.alpha):
        steered = generate(model, tokenizer, args.prompt, args.max_new_tokens)
    print(f"  >>> {steered!r}")

    # 4. Verify the steering actually moved the feature
    inputs = tokenizer(args.prompt, return_tensors="pt").to(model.device)
    with torch.no_grad(), capture_residual(model, args.layer) as bucket:
        model(**inputs)
    base_act = sae.encode(bucket["h"][0, -1].unsqueeze(0))[0, target_id].item()
    with torch.no_grad(), steer(model, args.layer, direction, args.alpha), \
            capture_residual(model, args.layer) as bucket:
        model(**inputs)
    steered_act = sae.encode(bucket["h"][0, -1].unsqueeze(0))[0, target_id].item()
    print(f"\n[verify] feature {target_id} activation: baseline={base_act:+.4f} "
          f"steered={steered_act:+.4f} delta={steered_act - base_act:+.4f}")
    if args.alpha > 0 and steered_act <= base_act:
        print("  WARN: alpha>0 but activation didn't go up; unexpected.")
    if args.alpha < 0 and steered_act >= base_act:
        print("  WARN: alpha<0 but activation didn't go down; unexpected.")


if __name__ == "__main__":
    main()
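# Example invocation (the script filename is illustrative; greedy decoding
# keeps the baseline/steered comparison deterministic for a given setup):
#
#   python sae_steer_demo.py --layer 14 --alpha=-10.0 \
#       --prompt "The capital of France is"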