# qwen-scope-live / qwen_scope_obs.py
# Ex0bit — initial qwen-scope-live deploy (commit f2ae1f5, verified)
"""Observation-only tooling for Qwen-Scope SAE features.
Four utilities — none of them steer or modify generation. They exist to let a
reviewer interrogate what the SAE features encode, before any intervention:
* encode_prompts(model, tokenizer, sae, prompts, layer)
For each prompt, returns the last-token sparse feature code (shape:
len(prompts) x n_features).
* top_features_for_prompt(...)
The N strongest-firing features for a single prompt, with activation
values. Equivalent to the read-only path in qwen_scope_steer.
* differential_features(pos_codes, neg_codes, top_n)
Given two stacks of feature codes, returns features ranked by
(mean activation on positive set) - (mean activation on negative set).
Useful for "which features distinguish concept A from concept B?"
* scan_prompts_for_feature(codes, feature_id)
Given a stack of codes and a feature id, returns the per-prompt
activation values (zero where the feature didn't make TopK).
Nothing here generates steered text. Wire it up to the steering hooks in
qwen_scope_steer.py only when intervention is the explicit experimental goal,
and document the intervention separately.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Sequence
import torch
from qwen_scope_steer import SAE, capture_residual
@dataclass
class FeatureRanking:
    """One ranked SAE feature.

    Produced by top_features_for_prompt() (only ``feature_id`` and ``score``
    are populated) and differential_features() (all fields populated).
    """
    feature_id: int  # index of the feature in the SAE dictionary
    score: float  # (pos_mean - neg_mean) for differential, or activation for top-features
    # The remaining fields are filled only by differential_features(); they
    # stay None for rows produced by top_features_for_prompt().
    pos_mean: float | None = None  # mean activation over the positive prompt set
    neg_mean: float | None = None  # mean activation over the negative prompt set
    pos_fire_rate: float | None = None  # fraction of positive prompts where the feature fired
    neg_fire_rate: float | None = None  # fraction of negative prompts where the feature fired
def encode_prompts(model, tokenizer, sae: SAE, prompts: Sequence[str],
                   layer_idx: int) -> torch.Tensor:
    """Run each prompt through the model once and SAE-encode its last-token
    residual captured at ``layer_idx``.

    Returns a (len(prompts), n_features) float32 tensor on CPU, so downstream
    statistics are stable regardless of model dtype/device. Observation only:
    no text is generated and nothing is steered.
    """
    rows: list[torch.Tensor] = []
    for prompt in prompts:
        batch = tokenizer(prompt, return_tensors="pt").to(model.device)
        # Forward pass purely to populate the residual bucket via the hook.
        with torch.no_grad(), capture_residual(model, layer_idx) as bucket:
            model(**batch)
        last_tok = bucket["h"][0, -1].unsqueeze(0)
        code = sae.encode(last_tok)[0]
        rows.append(code.detach().to("cpu", torch.float32))
    return torch.stack(rows, dim=0)
def top_features_for_prompt(model, tokenizer, sae: SAE, prompt: str,
                            layer_idx: int, top_n: int = 10) -> list[FeatureRanking]:
    """Return the ``top_n`` strongest-firing SAE features for one prompt.

    Only features that actually fired (nonzero code entries) are considered;
    results come back sorted by descending activation. Read-only.
    """
    code = encode_prompts(model, tokenizer, sae, [prompt], layer_idx)[0]
    fired = code.nonzero(as_tuple=True)[0]
    strengths = code[fired]
    ranked = strengths.argsort(descending=True)[:top_n]
    return [
        FeatureRanking(feature_id=int(fired[idx]), score=float(strengths[idx]))
        for idx in ranked
    ]
def differential_features(pos_codes: torch.Tensor, neg_codes: torch.Tensor,
                          top_n: int = 20) -> list[FeatureRanking]:
    """Rank features by their differential firing across two prompt sets.

    pos_codes: (P, F) feature codes for the "positive" prompt set
    neg_codes: (N, F) feature codes for the "negative" prompt set

    Returns top_n features by (pos_mean - neg_mean), with both means and
    per-set fire rates (fraction of prompts where the feature fired).
    Read-only — no generation, no steering.

    Raises ValueError if the feature dimensions disagree, or if either stack
    has zero rows: mean(dim=0) over an empty stack yields NaN, which would
    silently corrupt the ranking instead of failing loudly.
    """
    if pos_codes.shape[1] != neg_codes.shape[1]:
        raise ValueError(f"feature dim mismatch: {pos_codes.shape} vs {neg_codes.shape}")
    if pos_codes.shape[0] == 0 or neg_codes.shape[0] == 0:
        # An empty prompt set would make the means NaN and argsort meaningless.
        raise ValueError(
            f"empty prompt set: pos has {pos_codes.shape[0]} rows, "
            f"neg has {neg_codes.shape[0]} rows"
        )
    pos_mean = pos_codes.mean(dim=0)
    neg_mean = neg_codes.mean(dim=0)
    diff = pos_mean - neg_mean
    # Fire rate = fraction of prompts on which each feature was nonzero.
    pos_fire = (pos_codes != 0).float().mean(dim=0)
    neg_fire = (neg_codes != 0).float().mean(dim=0)
    order = diff.argsort(descending=True)[:top_n]
    return [
        FeatureRanking(
            feature_id=int(i),
            score=float(diff[i]),
            pos_mean=float(pos_mean[i]),
            neg_mean=float(neg_mean[i]),
            pos_fire_rate=float(pos_fire[i]),
            neg_fire_rate=float(neg_fire[i]),
        )
        for i in order
    ]
def scan_prompts_for_feature(codes: torch.Tensor, feature_id: int) -> torch.Tensor:
    """Pull one feature's activation across every prompt in *codes*.

    Zeros mark prompts where the feature fell outside the SAE's TopK.
    """
    # select(1, i) is the functional spelling of codes[:, i] (same view).
    return codes.select(1, feature_id)
def fire_rate(codes: torch.Tensor, feature_id: int) -> float:
    """Share of prompts whose code is nonzero for *feature_id* (i.e. it made TopK)."""
    fired = codes[:, feature_id].ne(0)
    return float(fired.to(torch.float32).mean())
def pretty_ranking(rs: list[FeatureRanking]) -> str:
    """Render FeatureRanking rows as aligned, human-readable lines.

    Rows from differential_features() carry set statistics and get the long
    format; rows from top_features_for_prompt() only have an activation score.
    """
    def _line(r):
        if r.pos_mean is None:
            return f" feat {r.feature_id:>6d} act={r.score:+8.4f}"
        return (
            f" feat {r.feature_id:>6d} "
            f"diff={r.score:+8.4f} "
            f"pos_mean={r.pos_mean:+8.4f} neg_mean={r.neg_mean:+8.4f} "
            f"pos_fire={r.pos_fire_rate:.2f} neg_fire={r.neg_fire_rate:.2f}"
        )
    return "\n".join(_line(r) for r in rs)