VoyagerXHF commited on
Commit
66ecb01
·
verified ·
1 Parent(s): a00f24d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +1911 -0
app.py ADDED
@@ -0,0 +1,1911 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ app.py — SAE Feature Explorer for Qwen3 models, whether pretrain (base) or posttrain (thinking/instruct) models.
3
+ """
4
+
5
+ import argparse
6
+ import html as _html
7
+ import json as _json
8
+ import os
9
+ from collections import OrderedDict
10
+
11
+ import gradio as gr
12
+ import spaces
13
+ import torch
14
+ from huggingface_hub import hf_hub_download, snapshot_download
15
+ from transformers import AutoModelForCausalLM, AutoTokenizer
16
+
17
+ # ─── CLI arguments ────────────────────────────────────────────────────────────
18
+ _parser = argparse.ArgumentParser(description="SAE Feature Explorer")
19
+ _parser.add_argument(
20
+ '--model',
21
+ default='Qwen/Qwen3.5-2B',
22
+ help='Path to the base model directory (default: %(default)s)',
23
+ )
24
+ _parser.add_argument(
25
+ '--model-name-sae-trained-from',
26
+ default='qwen3.5-2b-base',
27
+ help='The name of model which present representations for SAE training (default: %(default)s)',
28
+ )
29
+ _parser.add_argument(
30
+ '--model-name-analyzing-now',
31
+ default='qwen3.5-2b',
32
+ help='The name of model which is used for analyzing now (default: %(default)s)',
33
+ )
34
+ _parser.add_argument(
35
+ '--sae-path',
36
+ default='Qwen/SAE-Res-Qwen3.5-2B-Base-W32K-L0_100',
37
+ help='Path or HF Hub repo ID to the directory containing layer*.sae.pt files (default: %(default)s)',
38
+ )
39
+ _parser.add_argument(
40
+ '--top-k',
41
+ type=int,
42
+ default=100,
43
+ help='Number of top features to display (default: %(default)s)',
44
+ )
45
+ _parser.add_argument(
46
+ '--num-layers',
47
+ type=int,
48
+ default=24,
49
+ help='Number of transformer layers in the model (default: %(default)s)',
50
+ )
51
+ _parser.add_argument(
52
+ '--sae-width',
53
+ type=int,
54
+ default=32768,
55
+ help='SAE dictionary width / number of features (default: %(default)s)',
56
+ )
57
+ _parser.add_argument(
58
+ '--d-model',
59
+ type=int,
60
+ default=2048,
61
+ help='Model hidden dimension (default: %(default)s)',
62
+ )
63
+ _parser.add_argument(
64
+ '--sae-cache-max',
65
+ type=int,
66
+ default=8,
67
+ help='Maximum number of SAE layers to keep in memory at once (default: %(default)s)',
68
+ )
69
+ _parser.add_argument(
70
+ '--server-port',
71
+ type=int,
72
+ default=7860,
73
+ help='Port number for server',
74
+ )
75
+ _args, _unknown = _parser.parse_known_args()
76
+
77
+ # ─── Config ──────────────────────────────────────────────────────────────────
78
+ MODEL_PATH = _args.model
79
+ MODEL_NAME_SAE_TRAINED_FROM = _args.model_name_sae_trained_from
80
+ MODEL_NAME_ANALYZING_NOW = _args.model_name_analyzing_now
81
+ SAE_PATH = _args.sae_path
82
+ TOP_K = _args.top_k
83
+ NUM_LAYERS = _args.num_layers
84
+ SAE_WIDTH = _args.sae_width
85
+ D_MODEL = _args.d_model
86
+ SAE_CACHE_MAX = _args.sae_cache_max
87
+ PORT = _args.server_port
88
+
89
+ # ─── Generation defaults (from model's generation_config.json) ────────────────
90
+
91
+ _gen_cfg: dict = {}
92
+ _gen_cfg_path = os.path.join(MODEL_PATH, 'generation_config.json')
93
+ if os.path.exists(_gen_cfg_path):
94
+ with open(_gen_cfg_path) as _f:
95
+ _gen_cfg = _json.load(_f)
96
+ print(f"Loaded generation_config.json from {_gen_cfg_path}")
97
+ else:
98
+ print(f"No generation_config.json found at {_gen_cfg_path}; using built-in defaults.")
99
+
100
+ GEN_DO_SAMPLE = bool(_gen_cfg.get('do_sample', False))
101
+ GEN_TEMPERATURE = float(_gen_cfg.get('temperature', 1.0))
102
+ GEN_TOP_P = float(_gen_cfg.get('top_p', 1.0))
103
+ GEN_TOP_K = int(_gen_cfg.get('top_k', 1))
104
+ GEN_REP_PENALTY = float(_gen_cfg.get('repetition_penalty', 1.0))
105
+ STEER_DISPLAY_K = 10 # top-k candidates shown in the per-token probability panel
106
+
107
+ # ─── Default chat templates (thinking / no-thinking) ─────────────────────────
108
+
109
+ _THINK_TEMPLATE = (
110
+ "<|im_start|>user\n"
111
+ "{content}"
112
+ "<|im_end|>\n"
113
+ "<|im_start|>assistant\n"
114
+ "<think>\n"
115
+ )
116
+
117
+ _NOTHINK_TEMPLATE = (
118
+ "<|im_start|>user\n"
119
+ "{content}"
120
+ "<|im_end|>\n"
121
+ "<|im_start|>assistant\n"
122
+ "<think>\n\n</think>\n\n"
123
+ )
124
+
125
+ def apply_default_template(prompt: str, think: bool) -> str:
126
+ """Wrap *prompt* in the ChatML template for thinking or no-thinking mode."""
127
+ tpl = _THINK_TEMPLATE if think else _NOTHINK_TEMPLATE
128
+ return tpl.format(content=prompt.strip())
129
+
130
+ # ─── Device resolution ───────────────────────────────────────────────────────
131
+
132
+ def _resolve_sae_device() -> torch.device:
133
+ """
134
+ Pick the device for SAE weights and encoder/decoder computations.
135
+
136
+ CUDA_VISIBLE_DEVICES remaps physical GPUs so that the first listed GPU
137
+ always appears as cuda:0 inside this process. We simply use cuda:0
138
+ when any CUDA device is visible; fall back to CPU otherwise.
139
+ """
140
+ if not torch.cuda.is_available():
141
+ print("SAE device: cpu (no CUDA visible)")
142
+ return torch.device('cpu')
143
+ cvd = os.environ.get('CUDA_VISIBLE_DEVICES', '<unset>')
144
+ device = torch.device('cuda:0')
145
+ print(f"SAE device: {device} — {torch.cuda.get_device_name(device)}"
146
+ f" [CUDA_VISIBLE_DEVICES={cvd}]")
147
+ return device
148
+
149
+ SAE_DEVICE = _resolve_sae_device()
150
+
151
+ # ─── Global singletons ───────────────────────────────────────────────────────
152
+ _model = None
153
+ _tokenizer = None
154
+ _sae_lru: OrderedDict = OrderedDict()
155
+ _sae_local_dir: str | None = None # cached local dir for HF Hub downloaded SAEs
156
+ _orig_cache: dict | None = None # cached unsteered generation result
157
+
158
+
159
+ @spaces.GPU(duration=120)
160
+ def get_model():
161
+ global _model, _tokenizer
162
+ if _model is None:
163
+ print("Loading model…")
164
+ _model = AutoModelForCausalLM.from_pretrained(
165
+ MODEL_PATH, device_map='auto', torch_dtype='auto'
166
+ )
167
+ _tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
168
+ _model.eval()
169
+ print("Model ready.")
170
+ return _model, _tokenizer
171
+
172
+
173
+ def get_sae(layer: int) -> dict:
174
+ if layer in _sae_lru:
175
+ _sae_lru.move_to_end(layer)
176
+ return _sae_lru[layer]
177
+ if len(_sae_lru) >= SAE_CACHE_MAX:
178
+ _sae_lru.popitem(last=False)
179
+ # Support both local paths and HF Hub repo IDs
180
+ global _sae_local_dir
181
+ if os.path.isdir(SAE_PATH):
182
+ path = os.path.join(SAE_PATH, f'layer{layer}.sae.pt')
183
+ else:
184
+ # Assume HF Hub repo ID – download once, then read from local cache.
185
+ if _sae_local_dir is None:
186
+ _sae_local_dir = snapshot_download(SAE_PATH, cache_dir='./sae_cache', local_files_only=False)
187
+ path = os.path.join(_sae_local_dir, f'layer{layer}.sae.pt')
188
+ try:
189
+ sae = torch.load(path, map_location=SAE_DEVICE, weights_only=True)
190
+ except TypeError:
191
+ sae = torch.load(path, map_location=SAE_DEVICE)
192
+ # Pre-convert and transpose encoder weights once on load so compute_sae_features
193
+ # never repeats the conversion on every call.
194
+ sae['_W_enc'] = sae['W_enc'].T.to(dtype=torch.float32) # [d_model, sae_width]
195
+ sae['_b_enc'] = sae['b_enc'].to(dtype=torch.float32) # [sae_width]
196
+ _sae_lru[layer] = sae
197
+ return sae
198
+
199
+
200
+ # ─── Core math ───────────────────────────────────────────────────────────────
201
+
202
+ def topk_relu(x: torch.Tensor, k: int = TOP_K) -> torch.Tensor:
203
+ # Scatter top-k ReLU values directly — avoids creating a full-size boolean mask
204
+ # and an element-wise multiply, saving two [seq, SAE_WIDTH] allocations.
205
+ relu_x = torch.relu(x)
206
+ values, indices = torch.topk(relu_x, k, dim=-1)
207
+ out = torch.zeros_like(relu_x)
208
+ out.scatter_(-1, indices, values)
209
+ return out
210
+
211
+
212
+ @torch.no_grad()
213
+ def capture_hidden(model, input_ids: torch.Tensor, layer: int) -> torch.Tensor:
214
+ buf = {}
215
+ def _hook(module, inp, out):
216
+ # Qwen3MoE decoder layers return a plain tensor [batch, seq, hidden].
217
+ # out[0] removes the batch dim → [seq, hidden]; then move to SAE_DEVICE.
218
+ buf['h'] = out[0].detach().to(SAE_DEVICE, dtype=torch.float32)
219
+ handle = model.model.layers[layer].register_forward_hook(_hook)
220
+ model(input_ids)
221
+ handle.remove()
222
+ return buf['h'] # [seq_len, d_model]
223
+
224
+
225
+ @torch.no_grad()
226
+ def capture_all_hiddens(model, input_ids: torch.Tensor, layers: list) -> dict:
227
+ """
228
+ Capture residual-stream hidden states at multiple layers in a single
229
+ forward pass by registering simultaneous hooks. Tensors are stored on
230
+ SAE_DEVICE as float32 so downstream SAE matmuls need no extra transfer.
231
+ """
232
+ buf = {}
233
+ handles = []
234
+ for layer in layers:
235
+ def make_hook(l):
236
+ def _hook(module, inp, out):
237
+ buf[l] = out[0].detach().to(SAE_DEVICE, dtype=torch.float32)
238
+ return _hook
239
+ handles.append(model.model.layers[layer].register_forward_hook(make_hook(layer)))
240
+ model(input_ids)
241
+ for h in handles:
242
+ h.remove()
243
+ return buf # {layer_idx: Tensor[seq, d_model] on SAE_DEVICE}
244
+
245
+
246
+ def compute_sae_features(hidden: torch.Tensor, sae: dict,
247
+ raw: bool = False) -> torch.Tensor:
248
+ # Use pre-converted weights cached on load (avoids .float()/.T on every call)
249
+ W_enc = sae['_W_enc'] # [d_model, sae_width] float32 on SAE_DEVICE
250
+ b_enc = sae['_b_enc'] # [sae_width] float32 on SAE_DEVICE
251
+ pre = hidden @ W_enc + b_enc # [seq, sae_width] — pre-activation on SAE_DEVICE
252
+ if raw:
253
+ return pre # keep negative values intact; caller handles device
254
+ return topk_relu(pre, TOP_K) # stays on SAE_DEVICE; caller calls .tolist() as needed
255
+
256
+
257
+ # ─── UI helpers ──────────────────────────────────────────────────────────────
258
+
259
+
260
+ def parse_positions(s: str):
261
+ """
262
+ Parse a position string into 'all' or a sorted list of int indices.
263
+
264
+ Supported syntax (comma-separated, combinable):
265
+ all → every token position
266
+ 5 → single position
267
+ 3-7 → inclusive range (positions 3, 4, 5, 6, 7)
268
+ 0,2,5-8 → mix of individual positions and ranges
269
+ """
270
+ s = s.strip().lower()
271
+ if s == 'all':
272
+ return 'all'
273
+ try:
274
+ positions: list[int] = []
275
+ for part in s.split(','):
276
+ part = part.strip()
277
+ if not part:
278
+ continue
279
+ if '-' in part:
280
+ lo, hi = part.split('-', 1)
281
+ positions.extend(range(int(lo.strip()), int(hi.strip()) + 1))
282
+ else:
283
+ positions.append(int(part))
284
+ return sorted(set(positions))
285
+ except Exception:
286
+ return 'all'
287
+
288
+
289
+ def feature_heatmap_to_html(tokens: list, features: torch.Tensor, top_k: int, skip_first: bool = False) -> str:
290
+ """
291
+ Build a 2-D HTML heatmap:
292
+ rows = top-k features (ranked by mean activation across all positions)
293
+ cols = token positions
294
+ color = activation value (white → red, normalised per feature row by row max)
295
+ """
296
+
297
+ seq_len, sae_width = features.shape
298
+ top_k = min(int(top_k), sae_width)
299
+
300
+ # ── Optionally exclude the first token ────────────────────────────────────
301
+ if skip_first and seq_len > 1:
302
+ features = features[1:]
303
+ tokens = tokens[1:]
304
+ seq_len -= 1
305
+
306
+ # ── Select top-k features by mean activation across all positions ─────────
307
+ mean_per_feat = features.mean(dim=0) # [sae_width]
308
+ top_vals, top_idx = torch.topk(mean_per_feat, top_k)
309
+ feat_acts = features[:, top_idx] # [seq_len, top_k]
310
+
311
+ # ── Token column headers ──────────────────────────────────────────────────
312
+ TH_STYLE = (
313
+ "min-width:38px;max-width:70px;padding:4px 3px;"
314
+ "text-align:center;font-weight:500;font-size:11px;"
315
+ "color:#444;border-bottom:2px solid #c7d2e8;"
316
+ "overflow:hidden;white-space:nowrap;vertical-align:bottom;"
317
+ )
318
+ tok_headers = []
319
+ for i, tok in enumerate(tokens):
320
+ raw = tok.strip() or f"[{i}]"
321
+ short = _html.escape(raw[:6] + "…" if len(raw) > 6 else raw)
322
+ full = _html.escape(raw)
323
+ tok_headers.append(
324
+ f'<th style="{TH_STYLE}" title="pos {i}: {full}">{short}</th>'
325
+ )
326
+
327
+ # ── Data rows ─────────────────────────────────────────────────────────────
328
+ FEAT_TD = (
329
+ "font-family:ui-monospace,monospace;font-size:11px;"
330
+ "padding:3px 8px;color:#2563eb;white-space:nowrap;"
331
+ "border-right:2px solid #c7d2e8;background:#f8faff;"
332
+ "position:sticky;left:0;z-index:1;"
333
+ )
334
+ AVG_TD = (
335
+ "font-size:10px;padding:3px 6px;color:#777;white-space:nowrap;"
336
+ "border-right:1px solid #e4e7ef;text-align:right;"
337
+ )
338
+ CELL_BASE = (
339
+ "border:1px solid rgba(0,0,0,0.05);min-width:38px;height:30px;"
340
+ "text-align:center;vertical-align:middle;"
341
+ )
342
+
343
+ rows_html = []
344
+ for fi in range(top_k):
345
+ feat_i = int(top_idx[fi])
346
+ avg_val = float(top_vals[fi])
347
+ row_acts = feat_acts[:, fi] # [seq_len]
348
+ row_max = float(row_acts.max())
349
+ norm = row_max if row_max > 0 else 1.0
350
+
351
+ cells = []
352
+ for pos in range(seq_len):
353
+ v = float(row_acts[pos])
354
+ t = max(0.0, min(1.0, v / norm))
355
+ # white → amber → deep red
356
+ r = 255
357
+ g = int(255 * (1 - 0.8 * t))
358
+ b = int(255 * (1 - t))
359
+ cells.append(
360
+ f'<td style="{CELL_BASE}background:rgb({r},{g},{b});"'
361
+ f' title="feat #{feat_i} | pos {pos} | act={v:.4f}">'
362
+ f'</td>'
363
+ )
364
+
365
+ rows_html.append(
366
+ f'<tr>'
367
+ f'<td style="{FEAT_TD}">#{feat_i}</td>'
368
+ f'<td style="{AVG_TD}">{avg_val:.3f}</td>'
369
+ + "".join(cells)
370
+ + "</tr>"
371
+ )
372
+
373
+ # ── Assemble table ────────────────────────────────────────────────────────
374
+ header_row = (
375
+ '<tr>'
376
+ '<th style="padding:4px 8px;text-align:left;font-size:11px;font-weight:700;'
377
+ 'color:#2563eb;border-bottom:2px solid #c7d2e8;border-right:2px solid #c7d2e8;'
378
+ 'background:#f8faff;position:sticky;left:0;z-index:2;">Feature</th>'
379
+ '<th style="padding:4px 6px;font-size:11px;font-weight:700;color:#777;'
380
+ 'border-bottom:2px solid #c7d2e8;border-right:1px solid #e4e7ef;">'
381
+ 'Avg&nbsp;act.</th>'
382
+ + "".join(tok_headers)
383
+ + "</tr>"
384
+ )
385
+
386
+ legend = (
387
+ '<div style="display:flex;align-items:center;gap:10px;margin-top:10px;'
388
+ 'font-size:11px;color:#888;">'
389
+ '<span>0</span>'
390
+ '<div style="width:140px;height:12px;border-radius:6px;'
391
+ 'background:linear-gradient(to right,#fff,#ff6600,#cc0000);'
392
+ 'border:1px solid #ddd;"></div>'
393
+ '<span>peak activation (per-feature row-max scale)</span>'
394
+ '</div>'
395
+ )
396
+
397
+ return (
398
+ '<div style="overflow-x:auto;max-height:520px;overflow-y:auto;">'
399
+ '<table style="border-collapse:collapse;width:100%;'
400
+ 'font-family:ui-monospace,monospace;">'
401
+ f'<thead style="position:sticky;top:0;background:#fff;z-index:3;">'
402
+ f'{header_row}</thead>'
403
+ f'<tbody>{"".join(rows_html)}</tbody>'
404
+ '</table>'
405
+ '</div>'
406
+ + legend
407
+ )
408
+
409
+
410
+ def tokens_with_positions_html(tokens: list, positions) -> str:
411
+ """
412
+ Render tokenized prompt as coloured token chips.
413
+
414
+ Steered positions (amber/gold) are visually distinct from unsteered ones (grey).
415
+ positions: 'all' → every index is highlighted
416
+ list → only those indices
417
+ """
418
+
419
+ if not tokens:
420
+ return (
421
+ '<div style="padding:10px;color:#bbb;font-size:13px;">'
422
+ 'Enter a prompt above to preview token positions.</div>'
423
+ )
424
+
425
+ all_positions = positions if isinstance(positions, list) else []
426
+ pos_set = (
427
+ set(range(len(tokens))) if positions == 'all'
428
+ else {p for p in all_positions if 0 <= p < len(tokens)}
429
+ )
430
+ # Positions beyond the prompt — will be steered in the generated text
431
+ generated_positions = (
432
+ [] if positions == 'all'
433
+ else sorted(p for p in all_positions if p >= len(tokens))
434
+ )
435
+
436
+ parts = []
437
+ for i, tok in enumerate(tokens):
438
+ steered = i in pos_set
439
+ txt = _html.escape(tok)
440
+ title = _html.escape(repr(tok.strip()), quote=True)
441
+
442
+ if steered:
443
+ bg, border, text_color = "#fef3c7", "2px solid #f59e0b", "#92400e"
444
+ else:
445
+ bg, border, text_color = "#f1f5f9", "1px solid #e2e8f0", "#475569"
446
+
447
+ parts.append(
448
+ f'<span style="background:{bg};color:{text_color};'
449
+ f'padding:3px 7px;margin:2px 1px;border-radius:5px;'
450
+ f'display:inline-block;border:{border};'
451
+ f'font-family:ui-monospace,monospace;font-size:12px;" '
452
+ f'title="pos {i}: {title}">'
453
+ f'<sub style="opacity:.55;font-size:9px;margin-right:2px">{i}</sub>'
454
+ f'{txt}</span>'
455
+ )
456
+
457
+ n_steered = len(pos_set)
458
+ summary = (
459
+ f'<div style="margin-top:6px;font-size:11px;color:#888;">'
460
+ f'{len(tokens)}&nbsp;tokens total&nbsp;&nbsp;·&nbsp;&nbsp;'
461
+ f'<span style="color:#92400e;font-weight:600;">{n_steered}&nbsp;steered</span>'
462
+ f'&nbsp;<span style="color:#f59e0b;">■</span>'
463
+ f'</div>'
464
+ )
465
+
466
+ generated_note = ''
467
+ if generated_positions:
468
+ gp_str = ', '.join(str(p) for p in generated_positions)
469
+ generated_note = (
470
+ f'<div style="margin-top:4px;font-size:11px;padding:4px 8px;'
471
+ f'background:#eff6ff;border:1px solid #bfdbfe;border-radius:4px;color:#1d4ed8;">'
472
+ f'Positions {gp_str} are beyond the prompt — they will be steered '
473
+ f'in the <em>generated</em> text during autoregressive decoding.'
474
+ f'</div>'
475
+ )
476
+
477
+ return (
478
+ '<div style="padding:8px 4px;line-height:2.8;">'
479
+ + ' '.join(parts)
480
+ + summary
481
+ + generated_note
482
+ + '</div>'
483
+ )
484
+
485
+
486
+ def cb_feature_heatmap(state, top_k: int, skip_first: bool):
487
+ if state is None:
488
+ return (
489
+ '<div style="min-height:80px;display:flex;align-items:center;'
490
+ 'justify-content:center;color:#bbb;font-size:13px;">'
491
+ 'Run analysis first to see the feature heatmap.</div>'
492
+ )
493
+ tokens, features = state
494
+ return feature_heatmap_to_html(tokens, features, int(top_k), bool(skip_first))
495
+
496
+
497
+ # ─── Gradio callbacks ────────────────────────────────────────────────────────
498
+
499
+ @spaces.GPU(duration=120)
500
+ def cb_analyze(text: str, layer: int):
501
+ try:
502
+ model, tokenizer = get_model()
503
+ input_ids = tokenizer.encode(text, return_tensors='pt').to(
504
+ next(model.parameters()).device
505
+ )
506
+ tokens = [tokenizer.decode([t]) for t in input_ids[0].tolist()]
507
+ hidden = capture_hidden(model, input_ids, int(layer))
508
+ features = compute_sae_features(hidden, get_sae(int(layer)))
509
+ return (tokens, features)
510
+ except Exception as e:
511
+ raise gr.Error(f"Analysis failed: {e}")
512
+
513
+
514
+
515
+ def _steering_strength_from_mode(mode: str, diff_lookup, layer: int, feat_idx: int,
516
+ custom_val: float = 5.0) -> float:
517
+ """Map Light/Medium/Strong/Custom to an actual steering strength.
518
+
519
+ Looks up the feature-specific diff for (layer, feat_idx) from the
520
+ Feature Comparison results. Falls back to the global max across all
521
+ compared features, then to fixed defaults when no data is available.
522
+ """
523
+ if mode == "Custom":
524
+ return float(custom_val)
525
+ d = 0.0
526
+ if diff_lookup and isinstance(diff_lookup, dict):
527
+ key = (int(layer), int(feat_idx))
528
+ if key in diff_lookup:
529
+ d = float(diff_lookup[key])
530
+ else:
531
+ d = float(max(diff_lookup.values(), default=0.0))
532
+ if d <= 0:
533
+ return {"Light": 5.0, "Medium": 20.0, "Strong": 100.0}.get(mode, 5.0)
534
+ return {"Light": round(d * 0.5, 2),
535
+ "Medium": round(d * 2.0, 2),
536
+ "Strong": round(d * 10.0, 2)}.get(mode, round(d, 2))
537
+
538
+
539
+ @spaces.GPU(duration=120)
540
+ def cb_generate(prompt, layer, feat_idx, pos_str, steer_mode, compare_diff,
541
+ steer_output_only, max_tok, greedy, top_k_tok, top_p, rep_penalty, temp,
542
+ custom_strength=5.0, apply_think=False, apply_nothink=False):
543
+ try:
544
+ return _cb_generate_inner(prompt, layer, feat_idx, pos_str, steer_mode, compare_diff,
545
+ steer_output_only, max_tok, greedy, top_k_tok, top_p, rep_penalty, temp,
546
+ custom_strength, apply_think, apply_nothink)
547
+ except gr.Error:
548
+ raise
549
+ except Exception as e:
550
+ raise gr.Error(f"Generation failed: {e}")
551
+
552
+
553
+ def cb_update_steer_preview(prompt: str, pos_str: str,
554
+ apply_think: bool = False, apply_nothink: bool = False):
555
+ """Tokenise the prompt and return an HTML token-position preview."""
556
+ if not prompt.strip():
557
+ return (
558
+ '<div style="padding:10px;color:#bbb;font-size:13px;">'
559
+ 'Enter a prompt above to preview steered positions.</div>'
560
+ )
561
+ try:
562
+ _, tokenizer = get_model()
563
+ if apply_think:
564
+ effective = apply_default_template(prompt, think=True)
565
+ elif apply_nothink:
566
+ effective = apply_default_template(prompt, think=False)
567
+ else:
568
+ effective = prompt
569
+ input_ids = tokenizer.encode(effective)
570
+ tokens = [tokenizer.decode([t]) for t in input_ids]
571
+ positions = parse_positions(pos_str)
572
+ return tokens_with_positions_html(tokens, positions)
573
+ except Exception as e:
574
+ return (
575
+ f'<div style="padding:10px;color:#dc2626;font-size:13px;">'
576
+ f'Preview error: {e}</div>'
577
+ )
578
+
579
+
580
+ def _extract_probs(gen_out, input_len: int, tokenizer, display_k: int):
581
+ """
582
+ Extract per-step token probabilities from a `return_dict_in_generate=True,
583
+ output_scores=True` GenerateOutput.
584
+
585
+ Returns (text, tokens, chosen_probs, topk_data) where:
586
+ tokens : list[str] — decoded token strings
587
+ chosen_probs : list[float] — probability of the chosen token (0-1)
588
+ topk_data : list[list[[str, float, bool]]] — top-k candidates at each step,
589
+ each entry is [token_str, prob, is_chosen]
590
+ """
591
+ new_ids = gen_out.sequences[0][input_len:]
592
+ new_id_list = new_ids.tolist()
593
+
594
+ # Batch-decode chosen tokens and all top-k candidates in two passes
595
+ # instead of O(n * display_k) individual tokenizer.decode() calls.
596
+ all_topk_ids: list[list[int]] = []
597
+ chosen_probs: list[float] = []
598
+ topk_vals_list: list = []
599
+ chosen_in_top_list: list[bool]= []
600
+
601
+ for score_t, tok_id in zip(gen_out.scores, new_id_list):
602
+ probs = torch.softmax(score_t[0].float(), dim=-1)
603
+ chosen_probs.append(float(probs[tok_id]))
604
+ top_vals, top_ids = torch.topk(probs, display_k)
605
+ tid_list = top_ids.tolist()
606
+ chosen_in_top = tok_id in tid_list
607
+ all_topk_ids.append(tid_list)
608
+ topk_vals_list.append(top_vals.tolist())
609
+ chosen_in_top_list.append(chosen_in_top)
610
+
611
+ # Single batch_decode call for all chosen tokens
612
+ tokens: list[str] = tokenizer.batch_decode(
613
+ [[t] for t in new_id_list], skip_special_tokens=False
614
+ )
615
+
616
+ # Single batch_decode call for all top-k candidate tokens
617
+ flat_ids = [tid for ids in all_topk_ids for tid in ids]
618
+ flat_decoded = tokenizer.batch_decode(
619
+ [[t] for t in flat_ids], skip_special_tokens=False
620
+ )
621
+
622
+ topk_data = []
623
+ flat_idx = 0
624
+ for i, (tok_id, ids, vals, chosen_in_top, chosen_prob) in enumerate(
625
+ zip(new_id_list, all_topk_ids, topk_vals_list, chosen_in_top_list, chosen_probs)
626
+ ):
627
+ entry = []
628
+ for tid, tv in zip(ids, vals):
629
+ entry.append([flat_decoded[flat_idx], tv, tid == tok_id])
630
+ flat_idx += 1
631
+ if not chosen_in_top:
632
+ entry.append([tokens[i], chosen_prob, True])
633
+ topk_data.append(entry)
634
+
635
+ text = tokenizer.decode(new_ids, skip_special_tokens=True)
636
+ return text, tokens, chosen_probs, topk_data
637
+
638
+
639
+ def probs_to_html(tokens: list, chosen_probs: list, topk_data: list,
640
+ panel_id: str, theme: str = 'blue') -> str:
641
+ """
642
+ Render per-token generation probabilities as coloured chips.
643
+ Clicking a chip pins its top-k candidate table in the panel below;
644
+ clicking the same chip again or another chip toggles/switches the display.
645
+ Scroll-stable: no hover events that fire on page scroll.
646
+
647
+ theme: 'blue' for original output, 'red' for steered output.
648
+ """
649
+
650
+ if not tokens:
651
+ return ('<div style="padding:10px;color:#bbb;font-size:13px;">'
652
+ 'No tokens generated.</div>')
653
+
654
+ # ── Chip colour (white → saturated) based on probability ─────────────────
655
+ def _colors(prob: float):
656
+ t = max(0.0, min(1.0, prob))
657
+ if theme == 'blue':
658
+ r, g, b = int(255 * (1 - t * 0.85)), int(255 * (1 - t * 0.65)), 255
659
+ txt = '#1e3a8a' if t < 0.55 else '#ffffff'
660
+ else:
661
+ r, g, b = 255, int(255 * (1 - t * 0.82)), int(255 * (1 - t))
662
+ txt = '#7f1d1d' if t < 0.55 else '#ffffff'
663
+ return f'rgb({r},{g},{b})', txt
664
+
665
+ # ── Pre-build the top-k panel HTML in Python ──────────────────────────────
666
+ TH = 'padding:2px 8px;font-size:11px;color:#6b7280;border-bottom:1px solid #e4e7ef;'
667
+
668
+ def _panel_html(entry: list) -> str:
669
+ rows = []
670
+ for rank, (tok_str, prob, is_chosen) in enumerate(entry, 1):
671
+ bg = 'background:#dbeafe;' if is_chosen else ''
672
+ fw = 'font-weight:700;' if is_chosen else ''
673
+ mk = ' ✓' if is_chosen else ''
674
+ rows.append(
675
+ f'<tr style="border-bottom:1px solid #f4f6ff;{bg}">'
676
+ f'<td style="padding:2px 8px;text-align:right;font-size:11px;color:#9ca3af;">{rank}</td>'
677
+ f'<td style="padding:2px 8px;font-family:monospace;font-size:12px;{fw}">{_html.escape(tok_str)}{mk}</td>'
678
+ f'<td style="padding:2px 8px;text-align:right;font-family:monospace;font-size:12px;">{prob:.4f}</td>'
679
+ f'<td style="padding:2px 8px;text-align:right;font-family:monospace;font-size:12px;">{prob * 100:.2f}%</td>'
680
+ f'</tr>'
681
+ )
682
+ return (
683
+ '<table style="border-collapse:collapse;width:100%;font-size:12px;">'
684
+ f'<thead style="background:#f8faff;"><tr>'
685
+ f'<th style="{TH}text-align:right;">Rank</th>'
686
+ f'<th style="{TH}text-align:left;">Token</th>'
687
+ f'<th style="{TH}text-align:right;">Prob</th>'
688
+ f'<th style="{TH}text-align:right;">%</th>'
689
+ f'</tr></thead>'
690
+ f'<tbody>{"".join(rows)}</tbody>'
691
+ '</table>'
692
+ )
693
+
694
+ # ── Inline JS — click to pin, click again to unpin ───────────────────────
695
+ # Uses data-prob-root to scope sibling chips without global IDs.
696
+ # Single-quoted JS string literals are safe inside double-quoted HTML attrs.
697
+ # Non-f-string parts: { } are literal characters (no f-string substitution).
698
+ JS_CLICK = (
699
+ "var root=this.closest('[data-prob-root]');"
700
+ "if(!root)return;"
701
+ "var p=root.querySelector('[data-topk-panel]');"
702
+ "if(!p)return;"
703
+ "var sel=this.dataset.selected==='1';"
704
+ "root.querySelectorAll('[data-chip]').forEach(function(e){"
705
+ "e.dataset.selected='0';e.style.outline='';});"
706
+ "if(sel){"
707
+ "p.innerHTML='';p.style.display='none';"
708
+ "}else{"
709
+ "this.dataset.selected='1';"
710
+ "this.style.outline='2px solid #94a3b8';"
711
+ "this.style.outlineOffset='-1px';"
712
+ "p.innerHTML=this.getAttribute('data-panel');"
713
+ "p.style.display='block';"
714
+ "}"
715
+ )
716
+
717
+ def _tok_disp(s: str) -> str:
718
+ return s.replace('\n', '↵').replace('\r', '↵').replace('\t', '→')
719
+
720
+ # ── Build chips ───────────────────────────────────────────────────────────
721
+ chips = []
722
+ for tok, prob, entry in zip(tokens, chosen_probs, topk_data):
723
+ bg, txt = _colors(prob)
724
+ panel_attr = _html.escape(_panel_html(entry), quote=True)
725
+ chips.append(
726
+ f'<span data-chip data-selected="0" '
727
+ f'style="background:{bg};color:{txt};padding:3px 8px 2px;margin:1px;'
728
+ f'border-radius:5px;display:inline-block;cursor:pointer;white-space:nowrap;'
729
+ f'font-family:ui-monospace,monospace;font-size:12px;" '
730
+ f'data-panel="{panel_attr}" '
731
+ f'onclick="{JS_CLICK}">'
732
+ f'{_html.escape(_tok_disp(tok))}'
733
+ f'<sub style="opacity:.75;font-size:9px;margin-left:3px;">{prob * 100:.1f}%</sub>'
734
+ f'</span>'
735
+ )
736
+
737
+ return (
738
+ '<div data-prob-root style="padding:2px;">'
739
+ '<div style="font-size:11px;color:#888;margin-bottom:6px;font-style:italic;">'
740
+ 'Click a token to pin its top-k candidates &nbsp;·&nbsp; click again to dismiss.</div>'
741
+ '<div style="padding:4px;line-height:2.8;">'
742
+ + ''.join(chips)
743
+ + '</div>'
744
+ + '<div data-topk-panel style="display:none;margin-top:8px;padding:4px;'
745
+ 'background:#f8faff;border:1px solid #e4e7ef;border-radius:6px;'
746
+ 'max-height:220px;overflow-y:auto;"></div>'
747
+ + '</div>'
748
+ )
749
+
750
+
751
+ def _cb_generate_inner(prompt, layer, feat_idx, pos_str, steer_mode, compare_diff,
752
+ steer_output_only, max_tok, greedy, top_k_tok, top_p, rep_penalty, temp,
753
+ custom_strength=5.0, apply_think=False, apply_nothink=False):
754
+ global _orig_cache
755
+ model, tokenizer = get_model()
756
+ layer = int(layer)
757
+ feat_idx = int(feat_idx)
758
+ if not (0 <= feat_idx < SAE_WIDTH):
759
+ raise gr.Error(f"Feature index must be in [0, {SAE_WIDTH - 1}].")
760
+ strength = _steering_strength_from_mode(steer_mode, compare_diff, layer, feat_idx, float(custom_strength))
761
+ positions = parse_positions(pos_str)
762
+
763
+ if apply_think:
764
+ effective_prompt = apply_default_template(prompt, think=True)
765
+ elif apply_nothink:
766
+ effective_prompt = apply_default_template(prompt, think=False)
767
+ else:
768
+ effective_prompt = prompt
769
+ input_ids = tokenizer.encode(effective_prompt, return_tensors='pt').to(
770
+ next(model.parameters()).device
771
+ )
772
+
773
+ # Build generation kwargs shared by both calls
774
+ gen_kwargs: dict = dict(max_new_tokens=int(max_tok),
775
+ return_dict_in_generate=True, output_scores=True)
776
+ if greedy:
777
+ gen_kwargs['do_sample'] = False
778
+ else:
779
+ gen_kwargs['do_sample'] = True
780
+ gen_kwargs['temperature'] = float(temp)
781
+ gen_kwargs['top_k'] = int(top_k_tok)
782
+ gen_kwargs['top_p'] = float(top_p)
783
+ gen_kwargs['repetition_penalty'] = float(rep_penalty)
784
+
785
+ prompt_len = input_ids.shape[1]
786
+
787
+ # ── Original generation (cached) ─────────────────────────────────────────
788
+ # The unsteered output depends only on the prompt and decoding parameters,
789
+ # not on any steering inputs. Reuse the last result when those are unchanged.
790
+ if greedy:
791
+ orig_key = (effective_prompt, int(max_tok), True)
792
+ else:
793
+ orig_key = (effective_prompt, int(max_tok), False,
794
+ int(top_k_tok), float(top_p), float(rep_penalty), float(temp))
795
+
796
+ if _orig_cache is not None and _orig_cache['key'] == orig_key:
797
+ orig_text = _orig_cache['text']
798
+ orig_probs_html = _orig_cache['probs_html']
799
+ else:
800
+ with torch.no_grad():
801
+ orig_out = model.generate(input_ids, **gen_kwargs)
802
+ orig_text, orig_toks, orig_probs, orig_topk = _extract_probs(
803
+ orig_out, prompt_len, tokenizer, STEER_DISPLAY_K
804
+ )
805
+ orig_probs_html = probs_to_html(orig_toks, orig_probs, orig_topk,
806
+ 'topk-panel-orig', theme='blue')
807
+ _orig_cache = dict(key=orig_key, text=orig_text, probs_html=orig_probs_html)
808
+
809
+ sae = get_sae(layer)
810
+ steering_vec = sae['W_dec'][:, feat_idx].float() # [d_model]
811
+ pos_set = None if positions == 'all' else set(positions)
812
+ counter = [0]
813
+
814
+ def _steer_hook(module, inp, out):
815
+ # out: plain tensor [batch, seq, hidden] for Qwen3MoE
816
+ h = out.clone()
817
+ sv = steering_vec.to(device=h.device, dtype=h.dtype) # one fused transfer
818
+ cur_counter = counter[0]
819
+ counter[0] += 1
820
+ if cur_counter == 0:
821
+ # Prefill: apply position-based steering to the prompt
822
+ if positions == 'all':
823
+ h = h + strength * sv
824
+ else:
825
+ for p in positions:
826
+ if 0 <= p < h.shape[1]:
827
+ h[:, p, :] = h[:, p, :] + strength * sv
828
+ else:
829
+ # Decode step (KV-cache): h is [batch, 1, hidden]
830
+ # Steer if: output-only mode is on, positions='all', or this position is listed
831
+ cur_seq_pos = prompt_len + cur_counter - 1
832
+ if steer_output_only or positions == 'all' or cur_seq_pos in pos_set:
833
+ h[:, 0, :] = h[:, 0, :] + strength * sv
834
+ return h
835
+
836
+ handle = model.model.layers[layer].register_forward_hook(_steer_hook)
837
+ with torch.no_grad():
838
+ steer_out = model.generate(input_ids, **gen_kwargs)
839
+ handle.remove()
840
+ steer_text, steer_toks, steer_probs, steer_topk = _extract_probs(
841
+ steer_out, prompt_len, tokenizer, STEER_DISPLAY_K
842
+ )
843
+
844
+ steer_probs_html = probs_to_html(steer_toks, steer_probs, steer_topk,
845
+ 'topk-panel-steer', theme='red')
846
+
847
+ return orig_text, steer_text, orig_probs_html, steer_probs_html
848
+
849
+
850
+
851
+ # ─── Feature Comparison helpers ──────────────────────────────────────────────
852
+
853
+ def compare_to_html(records: list, text1: str, text2: str,
854
+ tokens1: list = None, tokens2: list = None) -> tuple:
855
+ """
856
+ Render comparison results as two HTML strings:
857
+ - tok_display_html: token rows for the left panel (data-tok-display root)
858
+ - feature_table_html: feature table for the right panel
859
+
860
+ Returns (tok_display_html, feature_table_html).
861
+ """
862
+
863
+ _TOK_PLACEHOLDER = (
864
+ '<div style="min-height:60px;display:flex;align-items:center;'
865
+ 'justify-content:center;color:#bbb;font-size:13px;padding:8px;">'
866
+ 'Run Compare to see token activations here.</div>'
867
+ )
868
+
869
+ if not records:
870
+ return (
871
+ _TOK_PLACEHOLDER,
872
+ '<div style="min-height:80px;display:flex;align-items:center;'
873
+ 'justify-content:center;color:#bbb;font-size:13px;">'
874
+ 'No results — try a wider layer range or larger Top-K.</div>',
875
+ )
876
+
877
+ # ── Token display blocks ──────────────────────────────────────────────────
878
+ TOK_SPAN = (
879
+ "display:inline-block;padding:3px 7px;margin:2px 1px;"
880
+ "border-radius:5px;font-family:ui-monospace,monospace;font-size:12px;"
881
+ "background:#eef2ff;color:#374151;cursor:default;"
882
+ "transition:background .1s;border:1px solid rgba(0,0,0,0.06);"
883
+ )
884
+
885
+ def render_tok_row(tokens, seq_id):
886
+ parts = []
887
+ for i, tok in enumerate(tokens):
888
+ txt = _html.escape(tok)
889
+ title = _html.escape(repr(tok.strip()), quote=True)
890
+ parts.append(
891
+ f'<span data-seq={seq_id} data-pos={i} style="{TOK_SPAN}" '
892
+ f'title="pos {i}: {title}">'
893
+ f'<sub style="opacity:.5;font-size:9px;margin-right:2px">{i}</sub>'
894
+ f'{txt}</span>'
895
+ )
896
+ return " ".join(parts)
897
+
898
+ # Build token display HTML for the left panel
899
+ if tokens1 and tokens2:
900
+ tok_inner = (
901
+ '<div style="margin-bottom:10px;color:#1e293b;">'
902
+ '<div style="font-size:11px;font-weight:700;color:#2563eb;'
903
+ 'text-transform:uppercase;letter-spacing:.5px;margin-bottom:5px;">'
904
+ f'Example 1 &nbsp;<span style="font-weight:400;color:#888;">'
905
+ f'({len(tokens1)} tokens)</span></div>'
906
+ '<div style="line-height:2.8;padding:8px 10px;background:#fafbff;'
907
+ 'border-radius:8px;border:1px solid #e4e7ef;overflow-x:auto;">'
908
+ + render_tok_row(tokens1, 1)
909
+ + '</div></div>'
910
+ '<div style="margin-bottom:8px;color:#1e293b;">'
911
+ '<div style="font-size:11px;font-weight:700;color:#dc2626;'
912
+ 'text-transform:uppercase;letter-spacing:.5px;margin-bottom:5px;">'
913
+ f'Example 2 &nbsp;<span style="font-weight:400;color:#888;">'
914
+ f'({len(tokens2)} tokens)</span></div>'
915
+ '<div style="line-height:2.8;padding:8px 10px;background:#fafbff;'
916
+ 'border-radius:8px;border:1px solid #e4e7ef;overflow-x:auto;">'
917
+ + render_tok_row(tokens2, 2)
918
+ + '</div></div>'
919
+ '<div style="font-size:11px;color:#888;font-style:italic;">'
920
+ 'Hover a feature row on the right to highlight activations.</div>'
921
+ )
922
+ else:
923
+ tok_inner = _TOK_PLACEHOLDER
924
+
925
+ # Wrap with data-tok-display so the JS hover handler can find it across columns
926
+ tok_display_html = f'<div data-tok-display style="padding:2px;">{tok_inner}</div>'
927
+
928
+ # ── Per-layer max for bar-width normalization ─────────────────────────────
929
+ _layer_max: dict = {}
930
+ for _d, _l, *_ in records:
931
+ if _d > _layer_max.get(_l, 0.0):
932
+ _layer_max[_l] = _d
933
+
934
+ # ── Inline JS snippets for hover-highlight ────────────────────────────────
935
+ # Uses document.querySelector('[data-tok-display]') so the handler works
936
+ # across Gradio columns (token panel on left, feature table on right).
937
+ _JS_ENTER = (
938
+ "var d=document.querySelector('[data-tok-display]');"
939
+ "if(!d)return;"
940
+ "var a1=JSON.parse(this.getAttribute('data-acts1'));"
941
+ "var a2=JSON.parse(this.getAttribute('data-acts2'));"
942
+ "if(!a1||!a2)return;"
943
+ "var pk=Math.max.apply(null,a1.map(Math.abs).concat(a2.map(Math.abs)))||0.0001;"
944
+ "function c1(v){var t=Math.abs(v)/pk;"
945
+ "return 'rgb('+Math.round(255*(1-t))+','+Math.round(255*(1-.6*t))+',255)'}"
946
+ "function c2(v){var t=Math.abs(v)/pk;"
947
+ "return 'rgb(255,'+Math.round(255*(1-.8*t))+','+Math.round(255*(1-t))+')'}"
948
+ "d.querySelectorAll('[data-seq]').forEach(function(e){"
949
+ "var s=e.dataset.seq,p=parseInt(e.dataset.pos,10);"
950
+ "if(s==='1'&&p<a1.length)e.style.background=c1(a1[p]);"
951
+ "else if(s==='2'&&p<a2.length)e.style.background=c2(a2[p]);});"
952
+ "this.style.outline='2px solid #94a3b8';"
953
+ "this.style.outlineOffset='-1px';"
954
+ )
955
+ _JS_LEAVE = (
956
+ "var d=document.querySelector('[data-tok-display]');"
957
+ "if(!d)return;"
958
+ "d.querySelectorAll('[data-seq]').forEach(function(e){e.style.background='';});"
959
+ "this.style.outline='';"
960
+ )
961
+
962
+ TR_BASE = "border-bottom:1px solid #f0f4ff;"
963
+ TH = (
964
+ "padding:6px 10px;font-size:11px;font-weight:700;text-transform:uppercase;"
965
+ "letter-spacing:.5px;white-space:nowrap;"
966
+ )
967
+
968
+ rows_html = []
969
+ current_layer = None
970
+ layer_rank = 0
971
+ for _rank, record in enumerate(records, 1):
972
+ diff_val, layer, feat_idx, act1, act2 = record[:5]
973
+ acts1_pos = record[5] if len(record) > 5 else None
974
+ acts2_pos = record[6] if len(record) > 6 else None
975
+
976
+ # Insert a layer-group header row whenever the layer changes
977
+ if layer != current_layer:
978
+ current_layer = layer
979
+ layer_rank = 0
980
+ rows_html.append(
981
+ f'<tr style="background:#eef2ff;border-top:2px solid #c7d2e8;">'
982
+ f'<td colspan="6" style="padding:4px 12px;font-size:11px;font-weight:700;'
983
+ f'color:#2563eb;letter-spacing:.5px;">Layer {layer}</td>'
984
+ f'</tr>'
985
+ )
986
+ layer_rank += 1
987
+
988
+ bar_w = max(2, int(120 * diff_val / (_layer_max.get(layer) or 1.0)))
989
+ if act1 >= act2:
990
+ bar_color = "#2563eb"
991
+ dir_label = "Ex&nbsp;1&nbsp;▲"
992
+ dir_color = "#2563eb"
993
+ row_bg = "background:#f5f8ff;"
994
+ else:
995
+ bar_color = "#dc2626"
996
+ dir_label = "Ex&nbsp;2&nbsp;▲"
997
+ dir_color = "#dc2626"
998
+ row_bg = "background:#fff5f5;"
999
+
1000
+ # Embed per-position activation arrays for the hover handler
1001
+ if acts1_pos is not None and acts2_pos is not None:
1002
+ a1_json = _json.dumps(acts1_pos)
1003
+ a2_json = _json.dumps(acts2_pos)
1004
+ tr_open = (
1005
+ f"<tr style='{TR_BASE}{row_bg}cursor:pointer;'"
1006
+ f" data-acts1='{a1_json}'"
1007
+ f" data-acts2='{a2_json}'"
1008
+ f' onmouseenter="{_JS_ENTER}"'
1009
+ f' onmouseleave="{_JS_LEAVE}">'
1010
+ )
1011
+ else:
1012
+ tr_open = f'<tr style="{TR_BASE}{row_bg}">'
1013
+
1014
+ rows_html.append(
1015
+ tr_open
1016
+ + f'<td style="padding:5px 10px;text-align:center;color:#9ca3af;font-size:11px;">{layer_rank}</td>'
1017
+ + f'<td style="padding:5px 10px;font-family:monospace;color:#2563eb;">#{feat_idx}</td>'
1018
+ + f'<td style="padding:5px 8px;text-align:right;font-family:monospace;color:#374151;">{act1:.1%}</td>'
1019
+ + f'<td style="padding:5px 8px;text-align:right;font-family:monospace;color:#374151;">{act2:.1%}</td>'
1020
+ + f'<td style="padding:5px 10px;">'
1021
+ + f' <div style="display:flex;align-items:center;gap:6px;">'
1022
+ + f' <div style="width:{bar_w}px;height:10px;background:{bar_color};'
1023
+ + f' border-radius:3px;flex-shrink:0;"></div>'
1024
+ + f' <span style="font-family:monospace;font-size:12px;color:#374151;">{diff_val:.1%}</span>'
1025
+ + f' </div>'
1026
+ + f'</td>'
1027
+ + f'<td style="padding:5px 10px;font-size:11px;font-weight:700;color:{dir_color};">'
1028
+ + f'{dir_label}</td>'
1029
+ + '</tr>'
1030
+ )
1031
+
1032
+ ex1_short = _html.escape(text1[:50] + "…" if len(text1) > 50 else text1)
1033
+ ex2_short = _html.escape(text2[:50] + "…" if len(text2) > 50 else text2)
1034
+
1035
+ legend = (
1036
+ '<div style="display:flex;flex-wrap:wrap;gap:16px;margin-top:12px;'
1037
+ 'font-size:11px;color:#6b7280;">'
1038
+ f'<span><span style="color:#2563eb;font-weight:700;">■ Ex 1</span>'
1039
+ f' "{ex1_short}"</span>'
1040
+ f'<span><span style="color:#dc2626;font-weight:700;">■ Ex 2</span>'
1041
+ f' "{ex2_short}"</span>'
1042
+ '</div>'
1043
+ )
1044
+
1045
+ table_inner = (
1046
+ '<div style="overflow-x:auto;max-height:560px;overflow-y:auto;color:#1e293b;">'
1047
+ '<table style="border-collapse:collapse;width:100%;color:#1e293b;'
1048
+ 'font-family:ui-monospace,monospace;font-size:13px;">'
1049
+ '<thead style="background:#f8faff;color:#1e293b;border-bottom:2px solid #c7d2e8;'
1050
+ 'position:sticky;top:0;z-index:2;">'
1051
+ '<tr>'
1052
+ f'<th style="{TH}color:#9ca3af;">Rank</th>'
1053
+ f'<th style="{TH}color:#2563eb;">Feature</th>'
1054
+ f'<th style="{TH}color:#2563eb;text-align:right;">Rate&nbsp;Ex&nbsp;1</th>'
1055
+ f'<th style="{TH}color:#dc2626;text-align:right;">Rate&nbsp;Ex&nbsp;2</th>'
1056
+ f'<th style="{TH}color:#6b7280;">|Rate diff|</th>'
1057
+ f'<th style="{TH}color:#6b7280;">Higher</th>'
1058
+ '</tr>'
1059
+ '</thead>'
1060
+ f'<tbody>{"".join(rows_html)}</tbody>'
1061
+ '</table>'
1062
+ '</div>'
1063
+ )
1064
+
1065
+ feature_table_html = (
1066
+ '<div style="padding:2px;">'
1067
+ + table_inner
1068
+ + legend
1069
+ + '</div>'
1070
+ )
1071
+
1072
+ return tok_display_html, feature_table_html
1073
+
1074
+
1075
+ @spaces.GPU(duration=180)
1076
+ def cb_compare(text1: str, text2: str, layer_from: int, layer_to: int,
1077
+ top_k: int, skip_first: bool,
1078
+ remove_common_toks: bool, remove_prefix: bool,
1079
+ raw_acts: bool = False):
1080
+ try:
1081
+ if not text1.strip() or not text2.strip():
1082
+ raise gr.Error("Both examples must be non-empty.")
1083
+
1084
+ model, tokenizer = get_model()
1085
+ layer_from = int(layer_from)
1086
+ layer_to = int(layer_to)
1087
+ top_k = int(top_k)
1088
+ if layer_from > layer_to:
1089
+ layer_from, layer_to = layer_to, layer_from
1090
+ layers = list(range(layer_from, layer_to + 1))
1091
+
1092
+ # ── Tokenise ─────────────────────────────────────────────────────────
1093
+ model_dev = next(model.parameters()).device
1094
+ ids1 = tokenizer.encode(text1, return_tensors='pt').to(model_dev)
1095
+ ids2 = tokenizer.encode(text2, return_tensors='pt').to(model_dev)
1096
+ toks1 = ids1[0].tolist()
1097
+ toks2 = ids2[0].tolist()
1098
+
1099
+ # ── Build per-sequence keep-index lists ───────────────────────────────
1100
+ prefix_len = 0
1101
+ if remove_prefix:
1102
+ for a, b in zip(toks1, toks2):
1103
+ if a == b:
1104
+ prefix_len += 1
1105
+ else:
1106
+ break
1107
+
1108
+ common_tok_ids: set = set()
1109
+ if remove_common_toks:
1110
+ common_tok_ids = set(toks1) & set(toks2)
1111
+
1112
+ def _build_keep(toks: list) -> list:
1113
+ return [
1114
+ i for i, t in enumerate(toks)
1115
+ if not (skip_first and i == 0)
1116
+ and i >= prefix_len
1117
+ and t not in common_tok_ids
1118
+ ]
1119
+
1120
+ keep1 = _build_keep(toks1)
1121
+ keep2 = _build_keep(toks2)
1122
+
1123
+ # ── Capture hidden states for all layers in two forward passes ────────
1124
+ hiddens1 = capture_all_hiddens(model, ids1, layers)
1125
+ hiddens2 = capture_all_hiddens(model, ids2, layers)
1126
+
1127
+ # Decoded token strings for the HTML token display
1128
+ tokens1_str = [tokenizer.decode([t]) for t in toks1]
1129
+ tokens2_str = [tokenizer.decode([t]) for t in toks2]
1130
+
1131
+ # ── Per-layer feature activation-rate difference ──────────────────────
1132
+ # Activation rate = fraction of kept positions where the feature fires
1133
+ # (activation > 0). Ranking by |rate1 − rate2| highlights features
1134
+ # that are selectively active in one example but not the other.
1135
+ # Load one SAE at a time to avoid OOM (each SAE is ~1-2 GB on GPU).
1136
+ candidates = [] # (abs_diff, layer, feat_idx, rate1, rate2,
1137
+ # acts1_per_pos, acts2_per_pos)
1138
+ for layer in layers:
1139
+ sae = get_sae(layer)
1140
+
1141
+ # Full per-position feature activations — stay on SAE_DEVICE for GPU math
1142
+ feats1 = compute_sae_features(hiddens1[layer], sae, raw=raw_acts) # [seq1_len, SAE_WIDTH]
1143
+ feats2 = compute_sae_features(hiddens2[layer], sae, raw=raw_acts) # [seq2_len, SAE_WIDTH]
1144
+
1145
+ # Activation rate = fraction of kept positions where feature fires (> 0)
1146
+ def _rate(feats: torch.Tensor, keep_idx: list) -> torch.Tensor:
1147
+ if not keep_idx:
1148
+ return torch.zeros(feats.shape[1], device=feats.device, dtype=feats.dtype)
1149
+ return (feats[keep_idx] > 0).float().mean(dim=0)
1150
+
1151
+ r1 = _rate(feats1, keep1)
1152
+ r2 = _rate(feats2, keep2)
1153
+ diff = (r1 - r2).abs()
1154
+
1155
+ # Top-k per layer (all kept — no global trim)
1156
+ local_k = min(top_k, SAE_WIDTH)
1157
+ vals, idxs = torch.topk(diff, local_k)
1158
+ for v, fi in zip(vals.tolist(), idxs.tolist()):
1159
+ # Round to 3 dp — enough precision for color interpolation
1160
+ a1_pos = [round(x, 3) for x in feats1[:, fi].tolist()]
1161
+ a2_pos = [round(x, 3) for x in feats2[:, fi].tolist()]
1162
+ candidates.append((v, layer, fi, float(r1[fi]), float(r2[fi]),
1163
+ a1_pos, a2_pos))
1164
+
1165
+ # Free SAE weights and feature tensors before loading the next layer
1166
+ del sae, feats1, feats2, diff
1167
+
1168
+ # Single cache clear after all layers — calling it per-layer is expensive
1169
+ if torch.cuda.is_available():
1170
+ torch.cuda.empty_cache()
1171
+
1172
+ # ── Per-layer sort: group by layer, within each layer sort by diff desc ─
1173
+ candidates.sort(key=lambda x: (x[1], -x[0]))
1174
+ diff_lookup: dict = {}
1175
+ for diff_val, layer, feat_idx, *_ in candidates:
1176
+ key = (layer, feat_idx)
1177
+ if key not in diff_lookup or diff_val > diff_lookup[key]:
1178
+ diff_lookup[key] = diff_val
1179
+ tok_html, table_html = compare_to_html(candidates, text1, text2, tokens1_str, tokens2_str)
1180
+ return tok_html, table_html, diff_lookup
1181
+
1182
+ except gr.Error:
1183
+ raise
1184
+ except Exception as e:
1185
+ raise gr.Error(f"Comparison failed: {e}")
1186
+
1187
+
1188
+ # ─── CSS ─────────────────────────────────────────────────────────────────────
1189
+
1190
+ CSS = """
1191
+ /* ══════════════════════════════════════════════════════════════════
1192
+ Color tokens — single source of truth for light / dark palettes
1193
+ ══════════════════════════════════════════════════════════════════ */
1194
+ :root {
1195
+ --c-page-bg: #f4f6fb;
1196
+ --c-card-bg: #ffffff;
1197
+ --c-card-border: #e4e7ef;
1198
+ --c-card-shadow: 0 1px 4px rgba(0,0,0,0.06), 0 4px 16px rgba(0,0,0,0.04);
1199
+ --c-header-bg: linear-gradient(135deg,#eff6ff 0%,#e0eaff 55%,#ede9fe 100%);
1200
+ --c-header-border:#c7d2fe;
1201
+ --c-header-text: #1e293b;
1202
+ --c-header-h1: #1e3a8a;
1203
+ --c-header-p: #475569;
1204
+ --c-pill-bg: rgba(37,99,235,0.08);
1205
+ --c-pill-border: rgba(37,99,235,0.22);
1206
+ --c-pill-text: #1e3a8a;
1207
+ --c-chip-bg: #eff4ff;
1208
+ --c-chip-text: #2563eb;
1209
+ --c-btn2-bg: #f8faff;
1210
+ --c-btn2-border: #d0d7e8;
1211
+ --c-btn2-text: #374151;
1212
+ --c-outbox-bg: #fafbff;
1213
+ --c-outbox-text: #1e293b;
1214
+ --c-outbox-border:#e4e7ef;
1215
+ --c-tab-text: #374151;
1216
+ --c-tab-sel: #2563eb;
1217
+ --c-divider: #dde3f0;
1218
+ --c-th-bg: #f0f4ff;
1219
+ --c-th-text: #2563eb;
1220
+ }
1221
+
1222
+ /* Dark mode via OS/browser preference */
1223
+ @media (prefers-color-scheme: dark) {
1224
+ :root {
1225
+ --c-page-bg: #0f172a;
1226
+ --c-card-bg: #1e293b;
1227
+ --c-card-border: #334155;
1228
+ --c-card-shadow: 0 1px 4px rgba(0,0,0,0.40), 0 4px 16px rgba(0,0,0,0.25);
1229
+ --c-header-bg: linear-gradient(135deg,#172554 0%,#1e3a8a 55%,#3b0764 100%);
1230
+ --c-header-border:#1e40af;
1231
+ --c-header-text: #e2e8f0;
1232
+ --c-header-h1: #bfdbfe;
1233
+ --c-header-p: #94a3b8;
1234
+ --c-pill-bg: rgba(96,165,250,0.12);
1235
+ --c-pill-border: rgba(96,165,250,0.30);
1236
+ --c-pill-text: #93c5fd;
1237
+ --c-chip-bg: #172554;
1238
+ --c-chip-text: #93c5fd;
1239
+ --c-btn2-bg: #1e293b;
1240
+ --c-btn2-border: #475569;
1241
+ --c-btn2-text: #e2e8f0;
1242
+ --c-outbox-bg: #0f172a;
1243
+ --c-outbox-text: #e2e8f0;
1244
+ --c-outbox-border:#334155;
1245
+ --c-tab-text: #94a3b8;
1246
+ --c-tab-sel: #60a5fa;
1247
+ --c-divider: #334155;
1248
+ --c-th-bg: #172554;
1249
+ --c-th-text: #93c5fd;
1250
+ }
1251
+ }
1252
+
1253
+ /* Dark mode via Gradio's explicit dark-mode class (toggled manually) */
1254
+ .dark {
1255
+ --c-page-bg: #0f172a;
1256
+ --c-card-bg: #1e293b;
1257
+ --c-card-border: #334155;
1258
+ --c-card-shadow: 0 1px 4px rgba(0,0,0,0.40), 0 4px 16px rgba(0,0,0,0.25);
1259
+ --c-header-bg: linear-gradient(135deg,#172554 0%,#1e3a8a 55%,#3b0764 100%);
1260
+ --c-header-border:#1e40af;
1261
+ --c-header-text: #e2e8f0;
1262
+ --c-header-h1: #bfdbfe;
1263
+ --c-header-p: #94a3b8;
1264
+ --c-pill-bg: rgba(96,165,250,0.12);
1265
+ --c-pill-border: rgba(96,165,250,0.30);
1266
+ --c-pill-text: #93c5fd;
1267
+ --c-chip-bg: #172554;
1268
+ --c-chip-text: #93c5fd;
1269
+ --c-btn2-bg: #1e293b;
1270
+ --c-btn2-border: #475569;
1271
+ --c-btn2-text: #e2e8f0;
1272
+ --c-outbox-bg: #0f172a;
1273
+ --c-outbox-text: #e2e8f0;
1274
+ --c-outbox-border:#334155;
1275
+ --c-tab-text: #94a3b8;
1276
+ --c-tab-sel: #60a5fa;
1277
+ --c-divider: #334155;
1278
+ --c-th-bg: #172554;
1279
+ --c-th-text: #93c5fd;
1280
+ }
1281
+
1282
+ /* ── Page background ── */
1283
+ body, .gradio-container { background: var(--c-page-bg) !important; }
1284
+
1285
+ /* ── Header card ── */
1286
+ .header-card {
1287
+ background: var(--c-header-bg);
1288
+ border-radius: 14px;
1289
+ padding: 22px 28px 18px;
1290
+ margin-bottom: 4px;
1291
+ color: var(--c-header-text);
1292
+ box-shadow: 0 4px 20px rgba(37,99,235,0.10);
1293
+ border: 1px solid var(--c-header-border);
1294
+ }
1295
+ .header-card h1 { margin:0 0 6px; font-size:24px; font-weight:700; letter-spacing:-.3px; color:var(--c-header-h1); }
1296
+ .header-card p { margin:0; font-size:13px; color:var(--c-header-p); }
1297
+ .stat-pill {
1298
+ display:inline-block;
1299
+ background:var(--c-pill-bg);
1300
+ border:1px solid var(--c-pill-border);
1301
+ border-radius:20px;
1302
+ padding:3px 13px;
1303
+ font-size:12px;
1304
+ color:var(--c-pill-text);
1305
+ margin:4px 3px 0;
1306
+ }
1307
+
1308
+ /* ── Panel cards ── */
1309
+ .panel-card {
1310
+ background: var(--c-card-bg) !important;
1311
+ border-radius: 12px !important;
1312
+ box-shadow: var(--c-card-shadow) !important;
1313
+ border: 1px solid var(--c-card-border) !important;
1314
+ padding: 18px !important;
1315
+ }
1316
+ .panel-card > .form { gap: 12px !important; }
1317
+
1318
+ /* ── Section label chips ── */
1319
+ .section-chip {
1320
+ font-size: 11px;
1321
+ font-weight: 700;
1322
+ text-transform: uppercase;
1323
+ letter-spacing: .8px;
1324
+ color: var(--c-chip-text);
1325
+ background: var(--c-chip-bg);
1326
+ border-radius: 6px;
1327
+ padding: 2px 10px;
1328
+ display: inline-block;
1329
+ margin-bottom: 10px;
1330
+ }
1331
+
1332
+ /* ── Buttons ── */
1333
+ .btn-primary {
1334
+ background: linear-gradient(135deg, #2563eb, #6d28d9) !important;
1335
+ border: none !important;
1336
+ border-radius: 8px !important;
1337
+ font-weight: 600 !important;
1338
+ font-size: 14px !important;
1339
+ letter-spacing: .2px !important;
1340
+ box-shadow: 0 2px 10px rgba(37,99,235,0.30) !important;
1341
+ transition: all 0.18s ease !important;
1342
+ color: #fff !important;
1343
+ padding: 10px 0 !important;
1344
+ }
1345
+ .btn-primary:hover {
1346
+ transform: translateY(-1px) !important;
1347
+ box-shadow: 0 5px 18px rgba(37,99,235,0.40) !important;
1348
+ }
1349
+ .btn-secondary {
1350
+ border-radius: 8px !important;
1351
+ font-weight: 500 !important;
1352
+ font-size: 13px !important;
1353
+ border: 1px solid var(--c-btn2-border) !important;
1354
+ background: var(--c-btn2-bg) !important;
1355
+ color: var(--c-btn2-text) !important;
1356
+ transition: all 0.15s ease !important;
1357
+ }
1358
+ .btn-secondary:hover {
1359
+ background: var(--c-chip-bg) !important;
1360
+ border-color: var(--c-tab-sel) !important;
1361
+ }
1362
+
1363
+ /* ── Output boxes ── */
1364
+ .output-box textarea {
1365
+ font-family: ui-monospace, monospace !important;
1366
+ font-size: 13px !important;
1367
+ line-height: 1.7 !important;
1368
+ background: var(--c-outbox-bg) !important;
1369
+ color: var(--c-outbox-text) !important;
1370
+ border-color: var(--c-outbox-border) !important;
1371
+ border-radius: 8px !important;
1372
+ }
1373
+
1374
+ /* ── Dataframe ── */
1375
+ .feature-table table { font-family: ui-monospace, monospace; font-size: 13px; }
1376
+ .feature-table th { background: var(--c-th-bg) !important; color: var(--c-th-text) !important;
1377
+ font-weight: 600; font-size: 12px; text-transform: uppercase; }
1378
+
1379
+ /* ── Tab styling ── */
1380
+ .tab-nav button {
1381
+ font-weight: 600 !important;
1382
+ font-size: 14px !important;
1383
+ border-radius: 8px 8px 0 0 !important;
1384
+ color: var(--c-tab-text) !important;
1385
+ }
1386
+ .tab-nav button.selected {
1387
+ color: var(--c-tab-sel) !important;
1388
+ border-bottom: 2px solid var(--c-tab-sel) !important;
1389
+ }
1390
+
1391
+ /* ── Divider ── */
1392
+ .section-divider {
1393
+ border: none;
1394
+ border-top: 1px dashed var(--c-divider);
1395
+ margin: 6px 0 10px;
1396
+ }
1397
+
1398
+ /* ── Slider label ── */
1399
+ label.svelte-1b6s6sv { font-size: 13px !important; font-weight: 500 !important; }
1400
+ """
1401
+
1402
+ # ─── Build the Gradio interface ───────────────────────────────────────────────
1403
+
1404
+ with gr.Blocks(title="Qwen-Scope Feature Explorer", theme=gr.themes.Soft(), css=CSS) as demo:
1405
+
1406
+ # ── Header ────────────────────────────────────────────────────────────────
1407
+ gr.HTML(
1408
+ '<div class="header-card">'
1409
+ ' <div style="display:flex;align-items:center;gap:8px;margin-bottom:6px;">'
1410
+ ' <img src="https://cdn-avatars.huggingface.co/v1/production/uploads/620760a26e3b7210c2ff1943/-s1gyJfvbE1RgO5iBeNOi.png" alt="Qwen Logo" style="height:24px;width:auto;">'
1411
+ ' <h1 style="margin:0;">Qwen-Scope Feature Explorer</h1>'
1412
+ ' </div>'
1413
+ f' <p>Interpret {MODEL_NAME_ANALYZING_NOW} via Sparse Autoencoders trained on each residual-stream layer from {MODEL_NAME_SAE_TRAINED_FROM}.</p>'
1414
+ ' <div style="margin-top:10px;">'
1415
+ f' <span class="stat-pill">Model: {MODEL_NAME_ANALYZING_NOW}</span>'
1416
+ f' <span class="stat-pill">SAE trained from: {MODEL_NAME_SAE_TRAINED_FROM}</span>'
1417
+ f' <span class="stat-pill">Layers: {NUM_LAYERS}</span>'
1418
+ f' <span class="stat-pill">SAE width: {SAE_WIDTH:,}</span>'
1419
+ f' <span class="stat-pill">Top-k: {TOP_K}</span>'
1420
+ f' <span class="stat-pill">d_model: {D_MODEL}</span>'
1421
+ ' </div>'
1422
+ '</div>'
1423
+ )
1424
+
1425
+ analysis_state = gr.State(None) # (list[str] tokens, Tensor[seq, sae_width] features)
1426
+ compare_diff_state = gr.State({})
1427
+
1428
+ with gr.Tabs(elem_classes="tab-nav"):
1429
+
1430
+ # ═════════════════════════════════════════════��════════════════════════
1431
+ # Tab 1 — Feature Comparison
1432
+ # ══════════════════════════════════════════════════════════════════════
1433
+ with gr.Tab("⚖️ Feature Comparison"):
1434
+
1435
+ with gr.Row(equal_height=False):
1436
+
1437
+ # ── Left column: inputs + settings + token preview ─────────────
1438
+ with gr.Column(scale=2, min_width=300):
1439
+
1440
+ with gr.Accordion("Examples", open=True) as t3_examples_accordion:
1441
+ with gr.Group(elem_classes="panel-card"):
1442
+ gr.HTML('<span class="section-chip">Examples</span>')
1443
+ t3_text1 = gr.Textbox(
1444
+ label="Example 1",
1445
+ lines=5,
1446
+ placeholder="Paste first text here…",
1447
+ )
1448
+ t3_text2 = gr.Textbox(
1449
+ label="Example 2",
1450
+ lines=5,
1451
+ placeholder="Paste second text here…",
1452
+ )
1453
+
1454
+ with gr.Accordion("Comparison Settings", open=True) as t3_settings_accordion:
1455
+ with gr.Group(elem_classes="panel-card"):
1456
+ gr.HTML('<span class="section-chip">Comparison Settings</span>')
1457
+ with gr.Row():
1458
+ t3_layer_from = gr.Slider(
1459
+ minimum=0, maximum=NUM_LAYERS - 1,
1460
+ value=0, step=1,
1461
+ label="Layer from",
1462
+ scale=1,
1463
+ )
1464
+ t3_layer_to = gr.Slider(
1465
+ minimum=0, maximum=NUM_LAYERS - 1,
1466
+ value=NUM_LAYERS - 1, step=1,
1467
+ label="Layer to",
1468
+ scale=1,
1469
+ )
1470
+ t3_topk = gr.Number(
1471
+ value=5, precision=0,
1472
+ label="Top-K results",
1473
+ info="Number of (layer, feature) pairs to surface.",
1474
+ )
1475
+ with gr.Accordion("Advanced options", open=False):
1476
+ t3_skip_first = gr.Checkbox(
1477
+ label="Exclude first token",
1478
+ value=False,
1479
+ info="Skip position 0 when computing mean activations.",
1480
+ )
1481
+ t3_remove_common_toks = gr.Checkbox(
1482
+ label="Remove common tokens",
1483
+ value=False,
1484
+ info="Exclude positions whose token ID appears in both examples.",
1485
+ )
1486
+ t3_remove_prefix = gr.Checkbox(
1487
+ label="Remove common prefix",
1488
+ value=False,
1489
+ info="Exclude the longest token-level prefix shared by both examples.",
1490
+ )
1491
+ t3_run = gr.Button(
1492
+ "⚖️ Compare Features",
1493
+ variant="primary",
1494
+ elem_classes="btn-primary",
1495
+ )
1496
+
1497
+ with gr.Accordion("Features", open=True) as t3_features_accordion:
1498
+ with gr.Group(elem_classes="panel-card"):
1499
+ gr.HTML(
1500
+ '<span class="section-chip">Feature Comparison</span>'
1501
+ '<span style="font-size:12px;color:#888;margin-left:8px;">'
1502
+ 'top-K features per layer · ranked by |rate(Ex1) − rate(Ex2)|'
1503
+ ' where rate = fraction of token positions where the feature fires · grouped by layer'
1504
+ '</span>'
1505
+ )
1506
+ t3_out = gr.HTML(
1507
+ value=(
1508
+ '<div style="min-height:80px;display:flex;align-items:center;'
1509
+ 'justify-content:center;color:#bbb;font-size:13px;">'
1510
+ 'Enter two examples and click Compare.</div>'
1511
+ )
1512
+ )
1513
+
1514
+ # ── Right column: token activations ──────────────────────────���
1515
+ with gr.Column(scale=3, min_width=380):
1516
+ with gr.Group(elem_classes="panel-card"):
1517
+ gr.HTML(
1518
+ '<span class="section-chip">Token Activations</span>'
1519
+ '<span style="font-size:12px;color:#888;margin-left:8px;">'
1520
+ 'hover a feature row on the left to highlight activations'
1521
+ '</span>'
1522
+ )
1523
+ t3_tok_html = gr.HTML(
1524
+ value=(
1525
+ '<div style="min-height:60px;display:flex;align-items:center;'
1526
+ 'justify-content:center;color:#bbb;font-size:13px;padding:8px;">'
1527
+ 'Run Compare to see token activations here.</div>'
1528
+ )
1529
+ )
1530
+
1531
+ # ══════════════════════════════════════════════════════════════════════
1532
+ # Tab 2 — Feature Steering
1533
+ # ══════════════════════════════════════════════════════════════════════
1534
+ with gr.Tab("🎛️ Feature Steering"):
1535
+
1536
+ with gr.Row(equal_height=False):
1537
+
1538
+ # ── Left column: prompt + steering controls ────────────────
1539
+ with gr.Column(scale=2, min_width=280):
1540
+ with gr.Group(elem_classes="panel-card"):
1541
+ gr.HTML('<span class="section-chip">Prompt</span>')
1542
+ t2_prompt = gr.Textbox(
1543
+ label=None,
1544
+ lines=5,
1545
+ placeholder="Enter a generation prompt…",
1546
+ show_label=False,
1547
+ )
1548
+
1549
+ t2_apply_think = gr.Checkbox(
1550
+ label="Apply default thinking template",
1551
+ value=False,
1552
+ info=(
1553
+ "Wrap the prompt in the ChatML format with thinking enabled "
1554
+ "(assistant prefill starts with <think>)."
1555
+ ),
1556
+ )
1557
+ t2_apply_nothink = gr.Checkbox(
1558
+ label="Apply default no-thinking template",
1559
+ value=False,
1560
+ info=(
1561
+ "Wrap the prompt in the ChatML format with thinking disabled "
1562
+ "(assistant prefill starts with <think>\\n\\n</think>)."
1563
+ ),
1564
+ )
1565
+ t2_template_info = gr.HTML(visible=False, value="")
1566
+
1567
+ gr.HTML('<span class="section-chip">Token Position Preview</span>'
1568
+ '<span style="font-size:12px;color:#888;margin-left:8px;">'
1569
+ 'amber = steered &nbsp;·&nbsp; updates as you type'
1570
+ '</span>')
1571
+
1572
+ t2_pos_preview = gr.HTML(
1573
+ value=(
1574
+ '<div style="padding:10px;color:#bbb;font-size:13px;">'
1575
+ 'Enter a prompt above to preview steered positions.</div>'
1576
+ )
1577
+ )
1578
+
1579
+ with gr.Group(elem_classes="panel-card"):
1580
+ gr.HTML('<span class="section-chip">Steering Parameters</span>')
1581
+
1582
+ with gr.Row():
1583
+ t2_layer = gr.Slider(
1584
+ minimum=0, maximum=NUM_LAYERS - 1,
1585
+ value=10, step=1,
1586
+ label="Layer",
1587
+ scale=3,
1588
+ )
1589
+ t2_feat = gr.Number(
1590
+ value=0, precision=0,
1591
+ label="Feature index",
1592
+ info=f"0 – {SAE_WIDTH - 1}",
1593
+ scale=2,
1594
+ )
1595
+
1596
+ t2_pos = gr.Textbox(
1597
+ label="Token positions to steer",
1598
+ value="all",
1599
+ placeholder="all | 0,1,5 | 3-7 | 0,2,5-8",
1600
+ info=(
1601
+ "all → every token | "
1602
+ "0,1,5 → individual positions | "
1603
+ "3-7 → inclusive range | "
1604
+ "combinations e.g. 0,2,5-8"
1605
+ ),
1606
+ )
1607
+ t2_steer_output_only = gr.Checkbox(
1608
+ label="Also steer generated tokens",
1609
+ value=True,
1610
+ info=(
1611
+ "When enabled, every generated token is steered in addition to "
1612
+ "whatever the positions field specifies for the prompt."
1613
+ ),
1614
+ )
1615
+
1616
+ gr.HTML('<span class="section-chip">Steering Strength</span>')
1617
+ t2_steer_mode = gr.Radio(
1618
+ choices=["Light", "Medium", "Strong", "Custom"],
1619
+ value="Light",
1620
+ label=None,
1621
+ show_label=False,
1622
+ info=(
1623
+ "Calibrated to the most different feature found in "
1624
+ "Feature Comparison. Run that tab first."
1625
+ ),
1626
+ )
1627
+ t2_custom_strength = gr.Number(
1628
+ value=5.0,
1629
+ label="Custom strength",
1630
+ info="Direct steering magnitude (used when Custom is selected above).",
1631
+ visible=False,
1632
+ precision=2,
1633
+ )
1634
+ t2_steer_info = gr.HTML(
1635
+ value=(
1636
+ '<div style="font-size:11px;color:#888;padding:4px 6px;'
1637
+ 'background:#f8faff;border-radius:5px;">'
1638
+ 'Light ≈ 5.0 · Medium ≈ 20.0 · Strong ≈ 100.0<br>'
1639
+ '<span style="color:#bbb;">Run Feature Comparison to calibrate.</span>'
1640
+ '</div>'
1641
+ )
1642
+ )
1643
+
1644
+ gr.HTML('<hr class="section-divider">')
1645
+ with gr.Accordion("Sampling options", open=False):
1646
+ t2_maxtok = gr.Slider(
1647
+ minimum=20, maximum=300,
1648
+ value=100, step=10,
1649
+ label="Max new tokens",
1650
+ )
1651
+ t2_greedy = gr.Checkbox(
1652
+ label="Greedy decoding",
1653
+ value=True,
1654
+ info="When enabled, all sampling parameters below are ignored.",
1655
+ )
1656
+ with gr.Row():
1657
+ t2_temperature = gr.Slider(
1658
+ minimum=0.01, maximum=2.0,
1659
+ value=GEN_TEMPERATURE, step=0.01,
1660
+ label="Temperature",
1661
+ interactive=GEN_DO_SAMPLE,
1662
+ )
1663
+ t2_top_p = gr.Slider(
1664
+ minimum=0.0, maximum=1.0,
1665
+ value=GEN_TOP_P, step=0.01,
1666
+ label="Top-p (nucleus)",
1667
+ interactive=GEN_DO_SAMPLE,
1668
+ )
1669
+ with gr.Row():
1670
+ t2_top_k_tok = gr.Slider(
1671
+ minimum=0, maximum=200,
1672
+ value=GEN_TOP_K, step=1,
1673
+ label="Top-k (tokens)",
1674
+ info="0 = disabled",
1675
+ interactive=GEN_DO_SAMPLE,
1676
+ )
1677
+ t2_rep_penalty = gr.Slider(
1678
+ minimum=1.0, maximum=3.0,
1679
+ value=GEN_REP_PENALTY, step=0.05,
1680
+ label="Repetition penalty",
1681
+ info="1.0 = no penalty",
1682
+ interactive=GEN_DO_SAMPLE,
1683
+ )
1684
+
1685
+ t2_run = gr.Button(
1686
+ "▶ Generate Both Outputs",
1687
+ variant="primary",
1688
+ elem_classes="btn-primary",
1689
+ )
1690
+
1691
+ # ── Right column: outputs ──────────────────────────────────
1692
+ with gr.Column(scale=3, min_width=380):
1693
+
1694
+ with gr.Group(elem_classes="panel-card"):
1695
+ gr.HTML(
1696
+ '<span class="section-chip">Original Output</span>'
1697
+ '<span style="font-size:12px;color:#888;margin-left:8px;">'
1698
+ 'No steering applied</span>'
1699
+ )
1700
+ t2_orig = gr.Textbox(
1701
+ label=None, lines=7,
1702
+ interactive=False,
1703
+ show_label=False,
1704
+ placeholder="Original generation will appear here…",
1705
+ elem_classes="output-box",
1706
+ )
1707
+ gr.HTML(
1708
+ '<span class="section-chip" style="margin-top:10px;'
1709
+ 'display:inline-block;">Token Probabilities</span>'
1710
+ '<span style="font-size:12px;color:#888;margin-left:8px;">'
1711
+ 'blue intensity = confidence &nbsp;·&nbsp; hover = top-k</span>'
1712
+ )
1713
+ t2_orig_probs = gr.HTML(
1714
+ value='<div style="padding:10px;color:#bbb;font-size:13px;">'
1715
+ 'Run generation to see token probabilities.</div>'
1716
+ )
1717
+
1718
+ with gr.Group(elem_classes="panel-card"):
1719
+ gr.HTML(
1720
+ '<span class="section-chip" style="background:#fef3f2;color:#dc2626;">'
1721
+ 'Steered Output</span>'
1722
+ '<span style="font-size:12px;color:#888;margin-left:8px;">'
1723
+ 'With SAE feature injection</span>'
1724
+ )
1725
+ t2_steered = gr.Textbox(
1726
+ label=None, lines=7,
1727
+ interactive=False,
1728
+ show_label=False,
1729
+ placeholder="Steered generation will appear here…",
1730
+ elem_classes="output-box",
1731
+ )
1732
+ gr.HTML(
1733
+ '<span class="section-chip" style="background:#fef3f2;color:#dc2626;'
1734
+ 'margin-top:10px;display:inline-block;">Token Probabilities</span>'
1735
+ '<span style="font-size:12px;color:#888;margin-left:8px;">'
1736
+ 'red intensity = confidence &nbsp;·&nbsp; hover = top-k</span>'
1737
+ )
1738
+ t2_steer_probs = gr.HTML(
1739
+ value='<div style="padding:10px;color:#bbb;font-size:13px;">'
1740
+ 'Run generation to see token probabilities.</div>'
1741
+ )
1742
+
1743
+ t2_run.click(
1744
+ cb_generate,
1745
+ inputs=[t2_prompt, t2_layer, t2_feat, t2_pos, t2_steer_mode, compare_diff_state,
1746
+ t2_steer_output_only, t2_maxtok,
1747
+ t2_greedy, t2_top_k_tok, t2_top_p, t2_rep_penalty,
1748
+ t2_temperature, t2_custom_strength, t2_apply_think, t2_apply_nothink],
1749
+ outputs=[t2_orig, t2_steered, t2_orig_probs, t2_steer_probs],
1750
+ )
1751
+ t3_run.click(
1752
+ cb_compare,
1753
+ inputs=[t3_text1, t3_text2, t3_layer_from, t3_layer_to, t3_topk,
1754
+ t3_skip_first, t3_remove_common_toks, t3_remove_prefix],
1755
+ outputs=[t3_tok_html, t3_out, compare_diff_state],
1756
+ ).then(
1757
+ fn=lambda: [gr.update(open=False), gr.update(open=False)],
1758
+ inputs=None,
1759
+ outputs=[t3_examples_accordion, t3_settings_accordion],
1760
+ )
1761
+ _sampling_controls = [
1762
+ t2_temperature, t2_top_p, t2_top_k_tok, t2_rep_penalty
1763
+ ]
1764
+ t2_greedy.change(
1765
+ fn=lambda g: [gr.update(interactive=not g)] * 4,
1766
+ inputs=[t2_greedy],
1767
+ outputs=_sampling_controls,
1768
+ )
1769
+ t2_prompt.change(
1770
+ cb_update_steer_preview,
1771
+ inputs=[t2_prompt, t2_pos, t2_apply_think, t2_apply_nothink],
1772
+ outputs=[t2_pos_preview],
1773
+ )
1774
+ t2_pos.change(
1775
+ cb_update_steer_preview,
1776
+ inputs=[t2_prompt, t2_pos, t2_apply_think, t2_apply_nothink],
1777
+ outputs=[t2_pos_preview],
1778
+ )
1779
+
1780
+ def _update_steer_info(mode: str, diff_lookup, layer, feat_idx):
1781
+ if mode == "Custom":
1782
+ return (
1783
+ '<div style="font-size:11px;color:#555;padding:4px 6px;'
1784
+ 'background:#f8faff;border-radius:5px;">'
1785
+ 'Enter a custom steering strength value above.'
1786
+ '</div>'
1787
+ )
1788
+ d = 0.0
1789
+ source_note = '<span style="color:#bbb;">Run Feature Comparison to calibrate.</span>'
1790
+ if diff_lookup and isinstance(diff_lookup, dict):
1791
+ key = (int(layer), int(feat_idx))
1792
+ if key in diff_lookup:
1793
+ d = float(diff_lookup[key])
1794
+ source_note = (
1795
+ f'<span style="color:#16a34a;">feature #{int(feat_idx)} '
1796
+ f'@ layer {int(layer)} · diff = {d:.3f}</span>'
1797
+ )
1798
+ else:
1799
+ d = float(max(diff_lookup.values(), default=0.0))
1800
+ source_note = (
1801
+ f'<span style="color:#64748b;">feature not in compare results — '
1802
+ f'using global max diff = {d:.3f}</span>'
1803
+ )
1804
+ if d <= 0:
1805
+ vals = {"Light": 5.0, "Medium": 20.0, "Strong": 100.0}
1806
+ else:
1807
+ vals = {
1808
+ "Light": round(d * 0.5, 2),
1809
+ "Medium": round(d * 2.0, 2),
1810
+ "Strong": round(d * 10.0, 2),
1811
+ }
1812
+ return (
1813
+ f'<div style="font-size:11px;color:#555;padding:4px 6px;'
1814
+ f'background:#f8faff;border-radius:5px;">'
1815
+ f'Light ≈ {vals["Light"]} · Medium ≈ {vals["Medium"]} · Strong ≈ {vals["Strong"]}<br>'
1816
+ + source_note + '</div>'
1817
+ )
1818
+
1819
+ _steer_info_inputs = [t2_steer_mode, compare_diff_state, t2_layer, t2_feat]
1820
+ for _trigger in [t2_steer_mode.change, compare_diff_state.change,
1821
+ t2_layer.change, t2_feat.change]:
1822
+ _trigger(
1823
+ fn=_update_steer_info,
1824
+ inputs=_steer_info_inputs,
1825
+ outputs=[t2_steer_info],
1826
+ )
1827
+
1828
+ # Show/hide custom strength input depending on radio selection
1829
+ t2_steer_mode.change(
1830
+ fn=lambda m: gr.update(visible=(m == "Custom")),
1831
+ inputs=[t2_steer_mode],
1832
+ outputs=[t2_custom_strength],
1833
+ )
1834
+
1835
+ # ── Template toggle: mutual exclusion + info panel + preview refresh ─
1836
+ _THINK_INFO_HTML = (
1837
+ '<div style="font-size:11px;color:#555;padding:6px 10px;'
1838
+ 'background:#eff6ff;border:1px solid #bfdbfe;border-radius:6px;'
1839
+ 'font-family:ui-monospace,monospace;white-space:pre-wrap;line-height:1.7;">'
1840
+ '&lt;|im_start|&gt;user\n'
1841
+ '&#123;your prompt&#125;&lt;|im_end|&gt;\n'
1842
+ '&lt;|im_start|&gt;assistant\n'
1843
+ '&lt;think&gt;\n'
1844
+ '</div>'
1845
+ )
1846
+ _NOTHINK_INFO_HTML = (
1847
+ '<div style="font-size:11px;color:#555;padding:6px 10px;'
1848
+ 'background:#f0fdf4;border:1px solid #bbf7d0;border-radius:6px;'
1849
+ 'font-family:ui-monospace,monospace;white-space:pre-wrap;line-height:1.7;">'
1850
+ '&lt;|im_start|&gt;user\n'
1851
+ '&#123;your prompt&#125;&lt;|im_end|&gt;\n'
1852
+ '&lt;|im_start|&gt;assistant\n'
1853
+ '&lt;think&gt;\n\n&lt;/think&gt;\n\n'
1854
+ '</div>'
1855
+ )
1856
+
1857
+ def _on_think_change(think_val, nothink_val, prompt, pos_str):
1858
+ if think_val:
1859
+ # Just checked: uncheck nothink, show think format, refresh preview
1860
+ return (gr.update(value=False),
1861
+ gr.update(visible=True, value=_THINK_INFO_HTML),
1862
+ cb_update_steer_preview(prompt, pos_str, True, False))
1863
+ elif nothink_val:
1864
+ # Unchecked by mutual exclusion — nothink is active; leave info+preview alone
1865
+ return gr.update(), gr.update(), gr.update()
1866
+ else:
1867
+ # Manually unchecked with nothing active — reset to raw
1868
+ return (gr.update(),
1869
+ gr.update(visible=False),
1870
+ cb_update_steer_preview(prompt, pos_str, False, False))
1871
+
1872
+ def _on_nothink_change(nothink_val, think_val, prompt, pos_str):
1873
+ if nothink_val:
1874
+ # Just checked: uncheck think, show nothink format, refresh preview
1875
+ return (gr.update(value=False),
1876
+ gr.update(visible=True, value=_NOTHINK_INFO_HTML),
1877
+ cb_update_steer_preview(prompt, pos_str, False, True))
1878
+ elif think_val:
1879
+ # Unchecked by mutual exclusion — think is active; leave info+preview alone
1880
+ return gr.update(), gr.update(), gr.update()
1881
+ else:
1882
+ # Manually unchecked with nothing active — reset to raw
1883
+ return (gr.update(),
1884
+ gr.update(visible=False),
1885
+ cb_update_steer_preview(prompt, pos_str, False, False))
1886
+
1887
+ t2_apply_think.change(
1888
+ fn=_on_think_change,
1889
+ inputs=[t2_apply_think, t2_apply_nothink, t2_prompt, t2_pos],
1890
+ outputs=[t2_apply_nothink, t2_template_info, t2_pos_preview],
1891
+ )
1892
+ t2_apply_nothink.change(
1893
+ fn=_on_nothink_change,
1894
+ inputs=[t2_apply_nothink, t2_apply_think, t2_prompt, t2_pos],
1895
+ outputs=[t2_apply_think, t2_template_info, t2_pos_preview],
1896
+ )
1897
+
1898
+
1899
+ if __name__ == '__main__':
1900
+ # ZeroGPU: model is loaded lazily inside @spaces.GPU decorated functions,
1901
+ # so we do NOT pre-load here (Space boots on CPU, GPU is allocated on demand).
1902
+ print("Starting Gradio server on ZeroGPU Space…")
1903
+ demo.queue(max_size=4)
1904
+ demo.launch(
1905
+ server_name="0.0.0.0",
1906
+ server_port=int(os.environ.get("PORT", PORT)),
1907
+ share=False,
1908
+ strict_cors=False,
1909
+ show_error=True,
1910
+ ssr_mode=False,
1911
+ )