# qwen-scope-live/qwen_scope_steer.py
"""Qwen-Scope SAE feature reading + steering for transformers.
End-to-end demo:
1. Loads a base Qwen3 model and a matching Qwen-Scope TopK SAE checkpoint.
2. Captures the residual-stream output of a chosen decoder layer.
3. Encodes it through the SAE -> top-K firing features.
4. Generates a baseline completion.
5. Re-generates with feature steering: residual h <- h + alpha * W_dec[:, feat]
applied via register_forward_hook on every forward pass.
Verified against:
* Qwen/SAE-Res-Qwen3-1.7B-Base-W32K-L0_50 (W_enc 32768x2048, W_dec 2048x32768,
b_enc 32768, b_dec 2048, all float32, K=50)
* Qwen/Qwen3-1.7B-Base (28 Qwen3DecoderLayer blocks, hidden_size=2048; the
  layer forward returns a bare torch.Tensor under transformers >= 5).
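Example invocation (defaults shown; every flag is defined in the CLI below):
    python qwen_scope_steer.py --layer 14 --alpha -10.0 \
        --prompt "The capital of France is"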
"""
from __future__ import annotations
import argparse
import contextlib
from dataclasses import dataclass
from pathlib import Path
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from transformers import AutoModelForCausalLM, AutoTokenizer
# ---------------------------------------------------------------------------
# SAE
# ---------------------------------------------------------------------------
@dataclass
class SAE:
W_enc: torch.Tensor # (n_features, d_model)
W_dec: torch.Tensor # (d_model, n_features)
b_enc: torch.Tensor # (n_features,)
b_dec: torch.Tensor # (d_model,)
k: int # TopK
layer: int # layer index this SAE belongs to
@classmethod
def from_repo(cls, repo: str, layer: int, k: int, device: str = "cpu",
dtype: torch.dtype = torch.float32) -> "SAE":
path = hf_hub_download(repo, f"layer{layer}.sae.pt")
return cls.from_path(path, layer=layer, k=k, device=device, dtype=dtype)
@classmethod
def from_path(cls, path: str | Path, layer: int, k: int,
device: str = "cpu", dtype: torch.dtype = torch.float32) -> "SAE":
sd = torch.load(str(path), map_location=device, weights_only=True)
for key in ("W_enc", "W_dec", "b_enc", "b_dec"):
if key not in sd:
raise KeyError(f"SAE checkpoint at {path} missing key {key!r}; "
f"got {list(sd.keys())}")
return cls(
W_enc=sd["W_enc"].to(device=device, dtype=dtype),
W_dec=sd["W_dec"].to(device=device, dtype=dtype),
b_enc=sd["b_enc"].to(device=device, dtype=dtype),
b_dec=sd["b_dec"].to(device=device, dtype=dtype),
k=k, layer=layer,
)
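    # A local checkpoint loads the same way; the path below is illustrative:
    #     sae = SAE.from_path("checkpoints/layer14.sae.pt", layer=14, k=50)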
@property
def n_features(self) -> int:
return self.W_enc.shape[0]
@property
def d_model(self) -> int:
return self.W_enc.shape[1]
    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode residual-stream activations -> sparse feature codes (TopK).
        Note: the encoder here is a plain affine map; it does not subtract
        b_dec from x first, as some other SAE conventions do."""
        x = x.to(device=self.W_enc.device, dtype=self.W_enc.dtype)
        pre = F.linear(x, self.W_enc, self.b_enc)  # (..., n_features)
        # Keep only the k largest pre-activations; all other features stay 0.
        topk_vals, topk_idx = pre.topk(self.k, dim=-1)
        z = torch.zeros_like(pre)
        z.scatter_(-1, topk_idx, topk_vals)
        return z
def decode(self, z: torch.Tensor) -> torch.Tensor:
z = z.to(device=self.W_dec.device, dtype=self.W_dec.dtype)
return F.linear(z, self.W_dec, self.b_dec)
def steering_vector(self, feature_id: int) -> torch.Tensor:
return self.W_dec[:, feature_id].clone()
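    def unit_steering_vector(self, feature_id: int) -> torch.Tensor:
        """Optional variant, not used by the CLI below (an addition, hedged):
        some steering setups normalize the decoder column to unit L2 norm so
        that alpha is expressed directly in residual-stream norm units."""
        d = self.W_dec[:, feature_id]
        return (d / d.norm()).clone()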
# ---------------------------------------------------------------------------
# Hook helpers
# ---------------------------------------------------------------------------
def _layer_output_to_tensor(out):
"""Qwen3DecoderLayer returns torch.Tensor in transformers >= 5,
a tuple (hidden_states, ...) in transformers < 5. Handle both."""
if isinstance(out, tuple):
return out[0], out
return out, None
def _rebuild_layer_output(new_h: torch.Tensor, original_out):
if original_out is None:
return new_h
return (new_h, *original_out[1:])
@contextlib.contextmanager
def capture_residual(model, layer_idx: int):
"""Capture the residual-stream output of model.model.layers[layer_idx]."""
bucket: dict = {}
layer = model.model.layers[layer_idx]
    def hook(_module, _inp, out):
        h, _ = _layer_output_to_tensor(out)
        bucket["h"] = h.detach()
        # Read-only hook: returning None leaves the layer output untouched.
handle = layer.register_forward_hook(hook)
try:
yield bucket
finally:
handle.remove()
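# Usage sketch for capture_residual (the same pattern read_top_features uses):
#     with torch.no_grad(), capture_residual(model, 14) as bucket:
#         model(**tokenizer("hi", return_tensors="pt").to(model.device))
#     h = bucket["h"]  # (batch, seq_len, hidden_size), detached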
@contextlib.contextmanager
def steer(model, layer_idx: int, direction: torch.Tensor, alpha: float):
"""Add `alpha * direction` to the residual stream output of layer_idx
on every forward pass while the context is active."""
layer = model.model.layers[layer_idx]
direction = direction.detach()
def hook(_module, _inp, out):
h, original = _layer_output_to_tensor(out)
d = direction.to(device=h.device, dtype=h.dtype)
new_h = h + alpha * d
return _rebuild_layer_output(new_h, original)
handle = layer.register_forward_hook(hook)
try:
yield
finally:
handle.remove()
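@contextlib.contextmanager
def clamp_feature(model, layer_idx: int, sae: SAE, feature_id: int, target: float):
    """Hedged sketch of an alternative to additive steering (not wired into the
    CLI below): re-encode the residual on every forward pass and shift it along
    the feature's decoder direction so the feature reads approximately a fixed
    target activation. Only approximate, since W_enc[f] . W_dec[:, f] != 1 and
    the TopK mask may drop the feature entirely."""
    layer = model.model.layers[layer_idx]
    direction = sae.steering_vector(feature_id)
    def hook(_module, _inp, out):
        h, original = _layer_output_to_tensor(out)
        d = direction.to(device=h.device, dtype=h.dtype)
        # Current activation of the feature at every position: (batch, seq_len)
        cur = sae.encode(h)[..., feature_id].to(device=h.device, dtype=h.dtype)
        new_h = h + (target - cur).unsqueeze(-1) * d
        return _rebuild_layer_output(new_h, original)
    handle = layer.register_forward_hook(hook)
    try:
        yield
    finally:
        handle.remove()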
# ---------------------------------------------------------------------------
# Pipeline
# ---------------------------------------------------------------------------
def read_top_features(model, tokenizer, sae: SAE, prompt: str,
layer_idx: int, top_n: int = 10):
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad(), capture_residual(model, layer_idx) as bucket:
model(**inputs)
h = bucket["h"] # (1, T, d_model) on model.device
    h_last = h[0, -1].unsqueeze(0)  # (1, d_model); encode() handles device/dtype
z = sae.encode(h_last)[0]
nonzero = z.nonzero(as_tuple=False).flatten()
vals = z[nonzero]
order = vals.argsort(descending=True)
top = nonzero[order][:top_n]
return [(int(f.item()), float(z[f].item())) for f in top]
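def feature_activation_trace(model, tokenizer, sae: SAE, prompt: str,
                             layer_idx: int, feature_id: int):
    """Hedged helper, not called by the CLI below: activation of one feature at
    every prompt position, since read_top_features only inspects the last token."""
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad(), capture_residual(model, layer_idx) as bucket:
        model(**inputs)
    z = sae.encode(bucket["h"][0])  # (seq_len, n_features)
    acts = z[:, feature_id].tolist()
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist())
    return list(zip(tokens, acts))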
def generate(model, tokenizer, prompt: str, max_new_tokens: int = 40):
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
out = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
do_sample=False, # deterministic for A/B comparison
pad_token_id=tokenizer.eos_token_id,
)
return tokenizer.decode(out[0], skip_special_tokens=True)
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def parse_topk_from_repo(repo: str) -> int:
    # e.g. "Qwen/SAE-Res-Qwen3-1.7B-Base-W32K-L0_50" -> 50
    parts = repo.rsplit("L0_", 1)
    if len(parts) == 2 and parts[1].isdigit():
        return int(parts[1])
    return 50  # fallback: K of the verified checkpoint
def main():
ap = argparse.ArgumentParser()
ap.add_argument("--model", default="Qwen/Qwen3-1.7B-Base")
ap.add_argument("--sae-repo", default="Qwen/SAE-Res-Qwen3-1.7B-Base-W32K-L0_50")
ap.add_argument("--layer", type=int, default=14)
ap.add_argument("--prompt", default="The capital of France is")
ap.add_argument("--max-new-tokens", type=int, default=40)
ap.add_argument("--alpha", type=float, default=-10.0,
help="Steering magnitude. Negative suppresses, positive amplifies.")
ap.add_argument("--suppress-rank", type=int, default=0,
help="Which top-firing feature (0 = strongest) to steer.")
ap.add_argument("--feature-id", type=int, default=None,
help="Override: steer this exact feature instead of a top-rank pick.")
ap.add_argument("--topk", type=int, default=None,
help="Override SAE TopK (auto-detected from repo name).")
ap.add_argument("--device", default=None,
help="cuda | mps | cpu (auto if omitted)")
ap.add_argument("--dtype", default="bfloat16",
choices=["bfloat16", "float16", "float32"])
args = ap.parse_args()
if args.device is None:
if torch.cuda.is_available():
args.device = "cuda"
elif torch.backends.mps.is_available():
args.device = "mps"
else:
args.device = "cpu"
dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16,
"float32": torch.float32}[args.dtype]
print(f"[load] model={args.model} device={args.device} dtype={args.dtype}")
tokenizer = AutoTokenizer.from_pretrained(args.model)
model = AutoModelForCausalLM.from_pretrained(
args.model, dtype=dtype, device_map=args.device,
)
model.eval()
n_layers = len(model.model.layers)
if not (0 <= args.layer < n_layers):
raise ValueError(f"--layer {args.layer} out of range; model has {n_layers} layers")
hidden = model.config.hidden_size
print(f"[load] {type(model).__name__}: {n_layers} layers, hidden={hidden}")
    k = args.topk if args.topk is not None else parse_topk_from_repo(args.sae_repo)
print(f"[load] SAE repo={args.sae_repo} layer={args.layer} K={k}")
sae = SAE.from_repo(args.sae_repo, layer=args.layer, k=k,
device=args.device, dtype=dtype)
if sae.d_model != hidden:
raise ValueError(f"SAE d_model={sae.d_model} != model hidden_size={hidden}; "
f"this SAE doesn't match this model.")
# 1. Top features for the prompt
print(f"\n[features] top firing at layer {args.layer} for prompt: {args.prompt!r}")
top = read_top_features(model, tokenizer, sae, args.prompt, args.layer, top_n=10)
for rank, (fid, act) in enumerate(top):
print(f" rank {rank:2d} feature {fid:>6d} act={act:+.4f}")
# Pick steering target
if args.feature_id is not None:
target_id = args.feature_id
else:
target_id = top[args.suppress_rank][0]
# 2. Baseline generation
    print("\n[baseline] generating (no steering)...")
baseline = generate(model, tokenizer, args.prompt, args.max_new_tokens)
print(f" >>> {baseline!r}")
# 3. Steered generation
print(f"\n[steer] feature {target_id} at layer {args.layer} with alpha={args.alpha}")
direction = sae.steering_vector(target_id)
with steer(model, args.layer, direction, args.alpha):
steered = generate(model, tokenizer, args.prompt, args.max_new_tokens)
print(f" >>> {steered!r}")
# 4. Verify the steering actually moved the feature
inputs = tokenizer(args.prompt, return_tensors="pt").to(model.device)
with torch.no_grad(), capture_residual(model, args.layer) as bucket:
model(**inputs)
base_act = sae.encode(bucket["h"][0, -1].unsqueeze(0))[0, target_id].item()
with torch.no_grad(), steer(model, args.layer, direction, args.alpha), \
capture_residual(model, args.layer) as bucket:
model(**inputs)
steered_act = sae.encode(bucket["h"][0, -1].unsqueeze(0))[0, target_id].item()
print(f"\n[verify] feature {target_id} activation: baseline={base_act:+.4f} "
f"steered={steered_act:+.4f} delta={steered_act - base_act:+.4f}")
    if args.alpha > 0 and steered_act <= base_act:
        print(" WARN: alpha>0 but activation didn't go up; unexpected.")
    if args.alpha < 0 and steered_act >= base_act:
        print(" WARN: alpha<0 but activation didn't go down; unexpected.")
if __name__ == "__main__":
main()