"""Observation-only tooling for Qwen-Scope SAE features. Three utilities — none of them steer or modify generation. They exist to let a reviewer interrogate what the SAE features encode, before any intervention: * encode_prompts(model, tokenizer, sae, prompts, layer) For each prompt, returns the last-token sparse feature code (shape: len(prompts) x n_features). * top_features_for_prompt(...) The N strongest-firing features for a single prompt, with activation values. Equivalent to the read-only path in qwen_scope_steer. * differential_features(pos_codes, neg_codes, top_n) Given two stacks of feature codes, returns features ranked by (mean activation on positive set) - (mean activation on negative set). Useful for "which features distinguish concept A from concept B?" * scan_prompts_for_feature(codes, feature_id) Given a stack of codes and a feature id, returns the per-prompt activation values (zero where the feature didn't make TopK). Nothing here generates steered text. Wire it up to the steering hooks in qwen_scope_steer.py only when intervention is the explicit experimental goal, and document the intervention separately. """ from __future__ import annotations from dataclasses import dataclass from typing import Sequence import torch from qwen_scope_steer import SAE, capture_residual @dataclass class FeatureRanking: feature_id: int score: float # (pos_mean - neg_mean) for differential, or activation for top-features pos_mean: float | None = None neg_mean: float | None = None pos_fire_rate: float | None = None neg_fire_rate: float | None = None def encode_prompts(model, tokenizer, sae: SAE, prompts: Sequence[str], layer_idx: int) -> torch.Tensor: """Encode the last-token residual of each prompt through the SAE. Returns codes of shape (len(prompts), n_features) on CPU float32 for stable downstream stats. No generation is performed. """ codes = [] for p in prompts: inputs = tokenizer(p, return_tensors="pt").to(model.device) with torch.no_grad(), capture_residual(model, layer_idx) as bucket: model(**inputs) h_last = bucket["h"][0, -1].unsqueeze(0) z = sae.encode(h_last)[0] codes.append(z.detach().to("cpu", torch.float32)) return torch.stack(codes, dim=0) def top_features_for_prompt(model, tokenizer, sae: SAE, prompt: str, layer_idx: int, top_n: int = 10) -> list[FeatureRanking]: codes = encode_prompts(model, tokenizer, sae, [prompt], layer_idx)[0] nz = codes.nonzero(as_tuple=False).flatten() vals = codes[nz] order = vals.argsort(descending=True)[:top_n] return [FeatureRanking(feature_id=int(nz[i]), score=float(vals[i])) for i in order] def differential_features(pos_codes: torch.Tensor, neg_codes: torch.Tensor, top_n: int = 20) -> list[FeatureRanking]: """Rank features by their differential firing across two prompt sets. pos_codes: (P, F) feature codes for the "positive" prompt set neg_codes: (N, F) feature codes for the "negative" prompt set Returns top_n features by (pos_mean - neg_mean), with both means and per-set fire rates (fraction of prompts where the feature fired). Read-only — no generation, no steering. """ if pos_codes.shape[1] != neg_codes.shape[1]: raise ValueError(f"feature dim mismatch: {pos_codes.shape} vs {neg_codes.shape}") pos_mean = pos_codes.mean(dim=0) neg_mean = neg_codes.mean(dim=0) diff = pos_mean - neg_mean pos_fire = (pos_codes != 0).float().mean(dim=0) neg_fire = (neg_codes != 0).float().mean(dim=0) order = diff.argsort(descending=True)[:top_n] return [ FeatureRanking( feature_id=int(i), score=float(diff[i]), pos_mean=float(pos_mean[i]), neg_mean=float(neg_mean[i]), pos_fire_rate=float(pos_fire[i]), neg_fire_rate=float(neg_fire[i]), ) for i in order ] def scan_prompts_for_feature(codes: torch.Tensor, feature_id: int) -> torch.Tensor: """Per-prompt activation vector for a single feature (zero where it didn't make TopK).""" return codes[:, feature_id] def fire_rate(codes: torch.Tensor, feature_id: int) -> float: """Fraction of prompts on which the feature fired (was in TopK).""" return float((codes[:, feature_id] != 0).float().mean()) def pretty_ranking(rs: list[FeatureRanking]) -> str: out = [] for r in rs: if r.pos_mean is not None: out.append( f" feat {r.feature_id:>6d} " f"diff={r.score:+8.4f} " f"pos_mean={r.pos_mean:+8.4f} neg_mean={r.neg_mean:+8.4f} " f"pos_fire={r.pos_fire_rate:.2f} neg_fire={r.neg_fire_rate:.2f}" ) else: out.append(f" feat {r.feature_id:>6d} act={r.score:+8.4f}") return "\n".join(out)