"""Observation-only tooling for Qwen-Scope SAE features.

Four utilities — none of them steer or modify generation. They exist to let a
reviewer interrogate what the SAE features encode, before any intervention:

* encode_prompts(model, tokenizer, sae, prompts, layer)
    For each prompt, returns the last-token sparse feature code (shape:
    len(prompts) x n_features).
* top_features_for_prompt(...)
    The N strongest-firing features for a single prompt, with activation
    values. Equivalent to the read-only path in qwen_scope_steer.
* differential_features(pos_codes, neg_codes, top_n)
    Given two stacks of feature codes, returns features ranked by
    (mean activation on positive set) - (mean activation on negative set).
    Useful for "which features distinguish concept A from concept B?"
* scan_prompts_for_feature(codes, feature_id)
    Given a stack of codes and a feature id, returns the per-prompt
    activation values (zero where the feature didn't make TopK).

Nothing here generates steered text. Wire it up to the steering hooks in
qwen_scope_steer.py only when intervention is the explicit experimental goal,
and document the intervention separately.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Sequence
import torch
from qwen_scope_steer import SAE, capture_residual
@dataclass
class FeatureRanking:
    """One SAE feature together with the statistic(s) it was ranked by.

    Produced by both top_features_for_prompt (score = raw activation,
    optional fields left None) and differential_features (score =
    pos_mean - neg_mean, all optional fields populated).
    """
    feature_id: int
    score: float  # (pos_mean - neg_mean) for differential, or activation for top-features
    # The fields below are filled only by differential_features.
    pos_mean: float | None = None  # mean activation over the positive prompt set
    neg_mean: float | None = None  # mean activation over the negative prompt set
    pos_fire_rate: float | None = None  # fraction of positive prompts where the feature fired
    neg_fire_rate: float | None = None  # fraction of negative prompts where the feature fired
def encode_prompts(model, tokenizer, sae: SAE, prompts: Sequence[str],
                   layer_idx: int) -> torch.Tensor:
    """Encode the last-token residual of each prompt through the SAE.

    Returns codes of shape (len(prompts), n_features) on CPU float32 for
    stable downstream stats. No generation is performed.
    """
    def _encode_one(text: str) -> torch.Tensor:
        # Single forward pass; capture_residual stashes the layer's residual.
        batch = tokenizer(text, return_tensors="pt").to(model.device)
        with torch.no_grad(), capture_residual(model, layer_idx) as grabbed:
            model(**batch)
        # Last-token residual, kept 2D (1, d_model) for the SAE encoder.
        last_hidden = grabbed["h"][0, -1].unsqueeze(0)
        code = sae.encode(last_hidden)[0]
        return code.detach().to("cpu", torch.float32)

    return torch.stack([_encode_one(p) for p in prompts], dim=0)
def top_features_for_prompt(model, tokenizer, sae: SAE, prompt: str,
                            layer_idx: int, top_n: int = 10) -> list[FeatureRanking]:
    """Return the top_n strongest-firing SAE features for a single prompt."""
    code = encode_prompts(model, tokenizer, sae, [prompt], layer_idx)[0]
    # Only features that survived TopK (nonzero) can rank.
    active = code.nonzero(as_tuple=False).flatten()
    strengths = code[active]
    ranked = strengths.argsort(descending=True)[:top_n]
    return [
        FeatureRanking(feature_id=int(active[idx]), score=float(strengths[idx]))
        for idx in ranked
    ]
def differential_features(pos_codes: torch.Tensor, neg_codes: torch.Tensor,
                          top_n: int = 20) -> list[FeatureRanking]:
    """Rank features by their differential firing across two prompt sets.

    Args:
        pos_codes: (P, F) feature codes for the "positive" prompt set.
        neg_codes: (N, F) feature codes for the "negative" prompt set.
        top_n: number of top-ranked features to return.

    Returns:
        top_n features sorted descending by (pos_mean - neg_mean), each with
        both means and per-set fire rates (fraction of prompts where the
        feature fired, i.e. was nonzero / in TopK).

    Raises:
        ValueError: if the feature dimensions disagree, or if either set is
            empty — a mean over zero prompts would silently yield all-NaN
            scores and a meaningless ranking.

    Read-only — no generation, no steering.
    """
    if pos_codes.shape[1] != neg_codes.shape[1]:
        raise ValueError(f"feature dim mismatch: {pos_codes.shape} vs {neg_codes.shape}")
    # Guard against (0, F) inputs: mean(dim=0) over zero rows is NaN, which
    # would propagate into every score instead of failing loudly.
    if pos_codes.shape[0] == 0 or neg_codes.shape[0] == 0:
        raise ValueError(
            f"both prompt sets must be non-empty, got {pos_codes.shape[0]} pos / "
            f"{neg_codes.shape[0]} neg"
        )
    pos_mean = pos_codes.mean(dim=0)
    neg_mean = neg_codes.mean(dim=0)
    diff = pos_mean - neg_mean
    # Fire rate = fraction of prompts where the feature was nonzero.
    pos_fire = (pos_codes != 0).float().mean(dim=0)
    neg_fire = (neg_codes != 0).float().mean(dim=0)
    order = diff.argsort(descending=True)[:top_n]
    return [
        FeatureRanking(
            feature_id=int(i),
            score=float(diff[i]),
            pos_mean=float(pos_mean[i]),
            neg_mean=float(neg_mean[i]),
            pos_fire_rate=float(pos_fire[i]),
            neg_fire_rate=float(neg_fire[i]),
        )
        for i in order
    ]
def scan_prompts_for_feature(codes: torch.Tensor, feature_id: int) -> torch.Tensor:
    """Per-prompt activation vector for a single feature (zero where it didn't make TopK)."""
    # Column view over the (P, F) code stack — equivalent to codes[:, feature_id].
    return codes.select(1, feature_id)
def fire_rate(codes: torch.Tensor, feature_id: int) -> float:
    """Fraction of prompts on which the feature fired (was in TopK)."""
    fired = codes[:, feature_id] != 0
    return fired.float().mean().item()
def pretty_ranking(rs: list[FeatureRanking]) -> str:
    """Render rankings as one aligned text line per feature.

    Differential rows (pos_mean populated) show the diff, both means, and
    fire rates; plain top-feature rows show only the activation.
    """
    def _row(r: FeatureRanking) -> str:
        if r.pos_mean is None:
            return f" feat {r.feature_id:>6d} act={r.score:+8.4f}"
        return (
            f" feat {r.feature_id:>6d} "
            f"diff={r.score:+8.4f} "
            f"pos_mean={r.pos_mean:+8.4f} neg_mean={r.neg_mean:+8.4f} "
            f"pos_fire={r.pos_fire_rate:.2f} neg_fire={r.neg_fire_rate:.2f}"
        )

    return "\n".join(_row(r) for r in rs)