"""Observation-only tooling for Qwen-Scope SAE features.

Four utilities — none of them steer or modify generation. They exist to let a
reviewer interrogate what the SAE features encode, before any intervention:

  * encode_prompts(model, tokenizer, sae, prompts, layer)
        For each prompt, returns the last-token sparse feature code (shape:
        len(prompts) x n_features).

  * top_features_for_prompt(...)
        The N strongest-firing features for a single prompt, with activation
        values. Equivalent to the read-only path in qwen_scope_steer.

  * differential_features(pos_codes, neg_codes, top_n)
        Given two stacks of feature codes, returns features ranked by
        (mean activation on positive set) - (mean activation on negative set).
        Useful for "which features distinguish concept A from concept B?"

  * scan_prompts_for_feature(codes, feature_id)
        Given a stack of codes and a feature id, returns the per-prompt
        activation values (zero where the feature didn't make TopK).

Nothing here generates steered text. Wire it up to the steering hooks in
qwen_scope_steer.py only when intervention is the explicit experimental goal,
and document the intervention separately. A sketch of the full read-only
workflow lives under the __main__ guard at the bottom of this file.
"""
from __future__ import annotations

from dataclasses import dataclass
from typing import Sequence

import torch

from qwen_scope_steer import SAE, capture_residual


@dataclass
class FeatureRanking:
    feature_id: int
    score: float           # (pos_mean - neg_mean) for differential, or activation for top-features
    pos_mean: float | None = None
    neg_mean: float | None = None
    pos_fire_rate: float | None = None
    neg_fire_rate: float | None = None


def encode_prompts(model, tokenizer, sae: SAE, prompts: Sequence[str],
                   layer_idx: int) -> torch.Tensor:
    """Encode the last-token residual of each prompt through the SAE.

    Returns codes of shape (len(prompts), n_features) on CPU float32 for
    stable downstream stats. No generation is performed.
    """
    codes = []
    for p in prompts:
        inputs = tokenizer(p, return_tensors="pt").to(model.device)
        with torch.no_grad(), capture_residual(model, layer_idx) as bucket:
            model(**inputs)
        h_last = bucket["h"][0, -1].unsqueeze(0)
        z = sae.encode(h_last)[0]
        codes.append(z.detach().to("cpu", torch.float32))
    return torch.stack(codes, dim=0)
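

# Example call (a sketch: `model`, `tokenizer`, and `sae` come from your own
# setup, and layer 12 is only an illustrative choice):
#
#     codes = encode_prompts(model, tokenizer, sae, ["Paris", "Rome"], layer_idx=12)
#     codes.shape  # -> torch.Size([2, n_features]); CPU float32, detached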


def top_features_for_prompt(model, tokenizer, sae: SAE, prompt: str,
                            layer_idx: int, top_n: int = 10) -> list[FeatureRanking]:
    """Return the top_n strongest-firing features for a single prompt (read-only)."""
    codes = encode_prompts(model, tokenizer, sae, [prompt], layer_idx)[0]
    nz = codes.nonzero(as_tuple=False).flatten()
    vals = codes[nz]
    order = vals.argsort(descending=True)[:top_n]
    return [FeatureRanking(feature_id=int(nz[i]), score=float(vals[i]))
            for i in order]


def differential_features(pos_codes: torch.Tensor, neg_codes: torch.Tensor,
                          top_n: int = 20) -> list[FeatureRanking]:
    """Rank features by their differential firing across two prompt sets.

    pos_codes: (P, F) feature codes for the "positive" prompt set
    neg_codes: (N, F) feature codes for the "negative" prompt set

    Returns top_n features by (pos_mean - neg_mean), with both means and
    per-set fire rates (fraction of prompts where the feature fired).
    Read-only — no generation, no steering.
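
    Worked example (a synthetic sketch; these tiny 2-feature codes are made up
    purely to show the ranking arithmetic):

    >>> pos = torch.tensor([[2.0, 0.0], [4.0, 0.0]])  # feature 0 fires on both
    >>> neg = torch.tensor([[0.0, 1.0], [0.0, 3.0]])  # feature 1 fires on both
    >>> r = differential_features(pos, neg, top_n=1)[0]
    >>> (r.feature_id, r.score, r.pos_fire_rate, r.neg_fire_rate)
    (0, 3.0, 1.0, 0.0)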
    """
    if pos_codes.shape[1] != neg_codes.shape[1]:
        raise ValueError(f"feature dim mismatch: {pos_codes.shape} vs {neg_codes.shape}")
    pos_mean = pos_codes.mean(dim=0)
    neg_mean = neg_codes.mean(dim=0)
    diff = pos_mean - neg_mean
    pos_fire = (pos_codes != 0).float().mean(dim=0)
    neg_fire = (neg_codes != 0).float().mean(dim=0)
    order = diff.argsort(descending=True)[:top_n]
    return [
        FeatureRanking(
            feature_id=int(i),
            score=float(diff[i]),
            pos_mean=float(pos_mean[i]),
            neg_mean=float(neg_mean[i]),
            pos_fire_rate=float(pos_fire[i]),
            neg_fire_rate=float(neg_fire[i]),
        )
        for i in order
    ]


def scan_prompts_for_feature(codes: torch.Tensor, feature_id: int) -> torch.Tensor:
    """Per-prompt activation vector for a single feature (zero where it didn't make TopK)."""
    return codes[:, feature_id]


def fire_rate(codes: torch.Tensor, feature_id: int) -> float:
    """Fraction of prompts on which the feature fired (was in TopK)."""
    return float((codes[:, feature_id] != 0).float().mean())


def pretty_ranking(rs: list[FeatureRanking]) -> str:
    """Format FeatureRanking rows as an aligned, human-readable block of lines."""
    out = []
    for r in rs:
        if r.pos_mean is not None:
            out.append(
                f"  feat {r.feature_id:>6d}  "
                f"diff={r.score:+8.4f}  "
                f"pos_mean={r.pos_mean:+8.4f}  neg_mean={r.neg_mean:+8.4f}  "
                f"pos_fire={r.pos_fire_rate:.2f}  neg_fire={r.neg_fire_rate:.2f}"
            )
        else:
            out.append(f"  feat {r.feature_id:>6d}  act={r.score:+8.4f}")
    return "\n".join(out)