| """FastAPI backend for the Qwen-Scope HF Space deployment. | |
| Locked to Qwen3-1.7B-Base + the W32K-L0_50 SAE so it fits inside a | |
| free-tier HF Space (CPU, ~16GB RAM). Layer is still selectable. | |
| """ | |
| from __future__ import annotations | |
| import gc | |
| import json | |
| import os | |
| import threading | |
| from contextlib import asynccontextmanager | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import FileResponse | |
| from pydantic import BaseModel | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from qwen_scope_steer import SAE, capture_residual, steer | |
| DEVICE = "cuda" if torch.cuda.is_available() else "cpu" | |
| DTYPE = torch.float32 # bf16 on CPU is slow + flaky on free-tier hardware | |
| POSITIONS_DIR = Path(os.environ.get( | |
| "POSITIONS_DIR", | |
| str(Path(__file__).parent / "feature_positions"), | |
| )) | |
| POSITIONS_DIR.mkdir(exist_ok=True, parents=True) | |

# ---------------------------------------------------------------------------
# Catalog of supported model + SAE pairs.
# Verified against the Qwen org HF listing. For Qwen3.6 (no native SAE yet)
# we point at the Qwen3.5 SAE that matches dimensions; this is a best-effort
# fallback flagged as transferred=True in the response.
# ---------------------------------------------------------------------------
MODEL_CATALOG = [
    {
        "model": "Qwen/Qwen3-1.7B-Base",
        "sae_repo": "Qwen/SAE-Res-Qwen3-1.7B-Base-W32K-L0_50",
        "default_layer": 14, "n_layers": 28, "n_features": 32768,
        "approx_size_gb": 3.4, "k": 50, "transferred": False,
    },
]

# ---------------------------------------------------------------------------
# State and locks
# ---------------------------------------------------------------------------
state: dict = {}
load_lock = threading.Lock()

def _find_decoder_layers(model):
    """Return (layers_module_list, dotted_path) for any qwen3 / qwen3_5 model.

    Handles:
      * model.model.layers (standard Qwen3*ForCausalLM)
      * model.language_model.model.layers (multimodal Qwen3_5ForConditionalGeneration)
    """
    for path in (("model", "model", "layers"),
                 ("model", "layers"),
                 ("language_model", "model", "layers"),
                 ("model", "language_model", "model", "layers")):
        obj = model
        ok = True
        for p in path:
            if not hasattr(obj, p):
                ok = False
                break
            obj = getattr(obj, p)
        if ok and hasattr(obj, "__len__") and len(obj) > 0:
            return obj, ".".join(path)
    raise RuntimeError(f"could not locate decoder layers on "
                       f"{type(model).__name__}")

# ---------------------------------------------------------------------------
# Position computation (cached per SAE).
# Uses a randomized power-iteration SVD for very large SAEs (e.g. the 80K
# feature ones), economy SVD for smaller ones. Good enough for visualization
# layout.
# ---------------------------------------------------------------------------
def _safe_filename(s: str) -> str:
    return s.replace("/", "__")

def _positions_path(sae_repo: str, layer: int | None = None) -> Path:
    if layer is None:
        return POSITIONS_DIR / f"{_safe_filename(sae_repo)}.json"
    return POSITIONS_DIR / f"{_safe_filename(sae_repo)}__L{layer}.json"

def compute_positions(W_enc: torch.Tensor) -> list[list[float]]:
    X = W_enc.detach().to("cpu", torch.float32).numpy()  # (n_features, d_model)
    X = X - X.mean(axis=0, keepdims=True)
    n, d = X.shape
    if n * d <= 32768 * 4096:
        # Economy SVD is fine for the smaller SAEs.
        _, _, Vt = np.linalg.svd(X, full_matrices=False)
        pos = X @ Vt[:3].T
    else:
        # Randomized SVD for very large SAEs (e.g. 80K * 5120).
        rng = np.random.default_rng(0)
        Q = rng.standard_normal((d, 8)).astype(np.float32)
        for _ in range(3):  # power iterations
            Q = X.T @ (X @ Q)
            Q, _ = np.linalg.qr(Q)
        Y = X @ Q  # (n, 8)
        _, _, Vt2 = np.linalg.svd(Y, full_matrices=False)
        pos = Y @ Vt2[:3].T
    pos = pos / max(abs(pos.min()), abs(pos.max()))
    return pos.tolist()
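
# Illustrative only (shapes are hypothetical): for a 32768 x 2048 encoder matrix
# compute_positions returns 32768 [x, y, z] triples scaled into [-1, 1], e.g.
#   pos = compute_positions(torch.randn(32768, 2048))
#   len(pos) == 32768 and len(pos[0]) == 3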

def load_or_compute_positions(W_enc: torch.Tensor, sae_repo: str,
                              layer: int | None = None) -> list[list[float]]:
    # Try layer-specific cache first; fall back to legacy SAE-repo-only cache
    # so existing files don't go stale.
    p_layer = _positions_path(sae_repo, layer)
    p_legacy = _positions_path(sae_repo)
    for p in (p_layer, p_legacy):
        if p.exists():
            try:
                return json.loads(p.read_text())["positions"]
            except Exception:
                pass
    pos = compute_positions(W_enc)
    p_layer.write_text(json.dumps({"positions": pos}))
    return pos

# ---------------------------------------------------------------------------
# Model + SAE loading (called both at startup and on /load_model)
# ---------------------------------------------------------------------------
def _free_current_state():
    """Release the currently loaded model + SAE so a new one can fit."""
    for k in ("model", "tokenizer", "sae", "layers"):
        if k in state:
            del state[k]
    gc.collect()
    if hasattr(torch, "mps") and torch.backends.mps.is_available():
        try:
            torch.mps.empty_cache()
        except Exception:
            pass
    if torch.cuda.is_available():
        try:
            torch.cuda.empty_cache()
        except Exception:
            pass

def _catalog_entry(model_name: str, sae_repo: str | None) -> dict:
    """Find the catalog row that matches model_name (and optionally sae_repo)."""
    for row in MODEL_CATALOG:
        if row["model"] == model_name and (sae_repo is None or row["sae_repo"] == sae_repo):
            return row
    raise HTTPException(status_code=400,
                        detail=f"unknown model/sae combination: {model_name} / {sae_repo}")

def load_state(model_name: str, sae_repo: str | None = None,
               layer: int | None = None, k: int = 50) -> dict:
    """Replace the loaded model+SAE+layer with the requested one."""
    entry = _catalog_entry(model_name, sae_repo)
    sae_repo = entry["sae_repo"]
    layer = entry["default_layer"] if layer is None else int(layer)
    k = entry.get("k", k)
    print(f"[load] {model_name} ({entry['approx_size_gb']:.1f}GB) "
          f"+ SAE {sae_repo} layer {layer} on {DEVICE}")
    _free_current_state()
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name, dtype=DTYPE, device_map=DEVICE,
    )
    model.eval()
    layers, layers_path = _find_decoder_layers(model)
    n_layers = len(layers)
    if not (0 <= layer < n_layers):
        layer = min(max(0, layer), n_layers - 1)
    print(f"[load] model loaded: {type(model).__name__}, layers at "
          f"'{layers_path}', n={n_layers}")
    sae = SAE.from_repo(sae_repo, layer=layer, k=k, device=DEVICE, dtype=DTYPE)
    print(f"[load] SAE loaded: n_features={sae.n_features}, d_model={sae.d_model}")
    print("[load] computing/loading 3D feature positions")
    positions = load_or_compute_positions(sae.W_enc, sae_repo, layer)
    _sae_cache_put(sae_repo, layer, sae)
    state.update(
        model=model, tokenizer=tokenizer, sae=sae,
        layers=layers, layers_path=layers_path,
        positions=positions, n_layers=n_layers,
        current_model=model_name, current_sae=sae_repo,
        current_layer=layer, current_k=k,
        catalog_entry=entry,
    )
    print("[load] ready")
    return state

# ---------------------------------------------------------------------------
# Hook helpers — work against state["layers"] not model.model.layers
# ---------------------------------------------------------------------------
import contextlib

@contextlib.contextmanager
def _capture_at(layer_module):
    """Context manager: capture the layer's output residual into bucket["h"]."""
    bucket = {}
    def hook(_m, _i, out):
        h = out[0] if isinstance(out, tuple) else out
        bucket["h"] = h.detach()
        return out
    handle = layer_module.register_forward_hook(hook)
    try:
        yield bucket
    finally:
        handle.remove()

@contextlib.contextmanager
def _steer_at(layer_module, direction, alpha, *,
              positions=None, output_only=False, prompt_len=None):
    """Hook adds α·direction to layer residual, with position/decode controls.

    positions   : None or "all" → every token; list[int] → only those absolute
                  token indices (works across prefill + decode).
    output_only : if True, only steer during decode (skip prefill entirely).
    prompt_len  : length of the prompt; needed to map decode-step counter
                  to absolute position when positions is a list.
    """
    direction = direction.detach()
    counter = [0]
    pos_set = set(positions) if isinstance(positions, (list, set)) else None
    def hook(_m, _i, out):
        h = out[0] if isinstance(out, tuple) else out
        d = direction.to(device=h.device, dtype=h.dtype)
        cur = counter[0]
        counter[0] += 1
        new_h = h
        is_prefill = (cur == 0)
        if is_prefill:
            seq = h.shape[1]
            if output_only:
                pass  # leave prompt untouched
            elif pos_set is None:
                new_h = h + alpha * d
            else:
                new_h = h.clone()
                for p in pos_set:
                    if 0 <= p < seq:
                        new_h[:, p, :] = new_h[:, p, :] + alpha * d
        else:
            # Decode step — h is [batch, 1, hidden] (one new token)
            cur_pos = (prompt_len or 0) + cur - 1
            if pos_set is None or output_only or (cur_pos in pos_set):
                new_h = h + alpha * d
        return (new_h, *out[1:]) if isinstance(out, tuple) else new_h
    handle = layer_module.register_forward_hook(hook)
    try:
        yield
    finally:
        handle.remove()
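
# Worked example for _steer_at's position bookkeeping (values illustrative):
# with prompt_len=5 and positions=[4, 6], the prefill call (counter 0) shifts
# only prompt token 4; during decode, step i lands at absolute position
# prompt_len + i - 1, so the 2nd generated token (position 5 + 2 - 1 = 6) is
# the only decode step that gets shifted.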

def _parse_positions(s: str | None):
    """Parse '3', '3-7', '0,2,5-8', 'all', or None into a position spec.

    Returns 'all' or a list[int] (or None if input is empty/None)."""
    if s is None or not str(s).strip():
        return None
    s = str(s).strip().lower()
    if s == "all":
        return "all"
    out: list[int] = []
    for part in s.split(","):
        part = part.strip()
        if not part:
            continue
        if "-" in part:
            try:
                lo, hi = part.split("-", 1)
                out.extend(range(int(lo), int(hi) + 1))
            except ValueError:
                continue
        else:
            try:
                out.append(int(part))
            except ValueError:
                continue
    return sorted(set(out)) if out else None
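
# Reference behavior of _parse_positions (derived from the code above):
#   _parse_positions("3")        -> [3]
#   _parse_positions("3-7")      -> [3, 4, 5, 6, 7]
#   _parse_positions("0,2,5-8")  -> [0, 2, 5, 6, 7, 8]
#   _parse_positions("all")      -> "all"
#   _parse_positions("")         -> None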

def _hook_stack(layer_module, sae, specs, prompt_len=None):
    from contextlib import ExitStack
    stack = ExitStack()
    for s in specs:
        d = sae.steering_vector(s.id)
        positions = _parse_positions(getattr(s, "positions", None))
        output_only = bool(getattr(s, "output_only", False))
        # "all" or None both mean "every position" inside _steer_at — pass None.
        eff_positions = None if (positions is None or positions == "all") else positions
        stack.enter_context(_steer_at(
            layer_module, d, s.alpha,
            positions=eff_positions,
            output_only=output_only,
            prompt_len=prompt_len,
        ))
    return stack

# ---------------------------------------------------------------------------
# Lifespan + app
# ---------------------------------------------------------------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Default startup: small model so the demo is interactive immediately.
    load_state("Qwen/Qwen3-1.7B-Base")
    yield
    state.clear()

app = FastAPI(lifespan=lifespan)
app.add_middleware(CORSMiddleware, allow_origins=["*"],
                   allow_methods=["*"], allow_headers=["*"])

# ---------------------------------------------------------------------------
# Request models
# ---------------------------------------------------------------------------
class EncodeRequest(BaseModel):
    prompt: str
    top_n: int = 20

class SteerSpec(BaseModel):
    id: int
    alpha: float
    positions: str | None = None  # "all" | "3-7" | "0,2,5" | None (= all)
    output_only: bool = False     # if True, steer only during decode, not prompt

class GenerateRequest(BaseModel):
    prompt: str
    steering: list[SteerSpec] = []
    max_new_tokens: int = 40
    return_probs: bool = False  # if True, return per-token softmax + top-K candidates
    topk_display: int = 8       # number of candidate tokens to expose per step

class LoadModelRequest(BaseModel):
    model: str
    sae_repo: str | None = None
    layer: int | None = None

class SetLayerRequest(BaseModel):
    layer: int
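
# Example generation payload (field names follow GenerateRequest above; the
# /generate path is an assumption based on the handler name below):
#   {"prompt": "The capital of France is",
#    "steering": [{"id": 123, "alpha": 8.0, "positions": "all", "output_only": false}],
#    "max_new_tokens": 40, "return_probs": true, "topk_display": 8}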

# ---------------------------------------------------------------------------
# Routes
# NOTE: the route paths below mirror the handler names (only /load_model is
# named in the comments above); adjust if the frontend expects different paths.
# ---------------------------------------------------------------------------
@app.get("/")
def index():
    return FileResponse(Path(__file__).parent / "index.html")

@app.get("/health")
def health():
    sae = state.get("sae")
    return {
        "ok": True,
        "model": state.get("current_model"),
        "sae": state.get("current_sae"),
        "layer": state.get("current_layer"),
        "device": DEVICE,
        "dtype": str(DTYPE).replace("torch.", ""),
        "n_features": sae.n_features if sae else None,
        "n_layers": state.get("n_layers"),
        "transferred": state.get("catalog_entry", {}).get("transferred", False),
        "note": state.get("catalog_entry", {}).get("note", ""),
    }

@app.get("/models")
def list_models():
    return {"models": MODEL_CATALOG}

@app.post("/load_model")
def load_model(req: LoadModelRequest):
    with load_lock:
        try:
            load_state(req.model, req.sae_repo, req.layer)
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"load failed: {e}")
    sae = state["sae"]
    return {
        "ok": True,
        "model": state["current_model"],
        "sae": state["current_sae"],
        "layer": state["current_layer"],
        "n_features": sae.n_features,
        "n_layers": state["n_layers"],
        "transferred": state["catalog_entry"].get("transferred", False),
        "note": state["catalog_entry"].get("note", ""),
        "positions": state["positions"],
    }

# In-memory LRU cache of recently-used SAE checkpoints, keyed by
# (sae_repo, layer). Each SAE for the 1.7B model is ~537 MB on disk and
# similar in RAM at fp32; for the 27B SAE it's ~3.3 GB. Cap conservatively.
_sae_lru: "OrderedDict[tuple[str, int], SAE] | None" = None  # initialized lazily
SAE_LRU_MAX = 6

def _sae_cache_get(sae_repo: str, layer: int):
    global _sae_lru
    if _sae_lru is None:
        from collections import OrderedDict
        _sae_lru = OrderedDict()
    key = (sae_repo, layer)
    if key in _sae_lru:
        _sae_lru.move_to_end(key)
        return _sae_lru[key]
    return None

def _sae_cache_put(sae_repo: str, layer: int, sae: SAE):
    global _sae_lru
    if _sae_lru is None:
        from collections import OrderedDict
        _sae_lru = OrderedDict()
    key = (sae_repo, layer)
    _sae_lru[key] = sae
    _sae_lru.move_to_end(key)
    while len(_sae_lru) > SAE_LRU_MAX:
        _sae_lru.popitem(last=False)

@app.post("/set_layer")
def set_layer(req: SetLayerRequest):
    """Hot-swap the active SAE to a different layer of the same SAE repo.

    Keeps the model loaded; just downloads (or fetches from cache) the new
    layer's SAE checkpoint. Recomputes 3D positions for the new SAE
    (cached on disk per SAE-repo+layer).
    """
    if "model" not in state:
        raise HTTPException(status_code=400, detail="no model loaded")
    n_layers = state["n_layers"]
    layer = int(req.layer)
    if not (0 <= layer < n_layers):
        raise HTTPException(status_code=400,
                            detail=f"layer must be in [0, {n_layers-1}]")
    sae_repo = state["current_sae"]
    if layer == state["current_layer"]:
        return {"ok": True, "unchanged": True,
                "layer": layer, "n_features": state["sae"].n_features,
                "positions": state["positions"]}
    with load_lock:
        # 1. SAE itself — try LRU first
        cached = _sae_cache_get(sae_repo, layer)
        if cached is not None:
            sae = cached
            print(f"[layer-swap] SAE {sae_repo} layer {layer} from LRU cache")
        else:
            print(f"[layer-swap] loading SAE {sae_repo} layer {layer}")
            k = state["catalog_entry"].get("k", 50)
            sae = SAE.from_repo(sae_repo, layer=layer, k=k,
                                device=DEVICE, dtype=DTYPE)
            _sae_cache_put(sae_repo, layer, sae)
        # 2. Positions — per-layer cache file on disk
        p = _positions_path(sae_repo, layer)
        if p.exists():
            try:
                positions = json.loads(p.read_text())["positions"]
            except Exception:
                positions = compute_positions(sae.W_enc)
                p.write_text(json.dumps({"positions": positions}))
        else:
            print(f"[layer-swap] computing positions for layer {layer}")
            positions = compute_positions(sae.W_enc)
            p.write_text(json.dumps({"positions": positions}))
        state["sae"] = sae
        state["current_layer"] = layer
        state["positions"] = positions
    return {
        "ok": True,
        "layer": layer,
        "n_features": sae.n_features,
        "positions": positions,
        "from_cache": cached is not None,
    }

@app.get("/positions")
def positions():
    return {"positions": state["positions"]}

@app.post("/encode")
def encode(req: EncodeRequest):
    model, tokenizer, sae = state["model"], state["tokenizer"], state["sae"]
    layer_module = state["layers"][state["current_layer"]]
    inputs = tokenizer(req.prompt, return_tensors="pt").to(model.device)
    with torch.no_grad(), _capture_at(layer_module) as bucket:
        model(**inputs)
    h_last = bucket["h"][0, -1].unsqueeze(0)
    z = sae.encode(h_last)[0]
    nz = z.nonzero(as_tuple=False).flatten()
    vals = z[nz]
    order = vals.argsort(descending=True)[:req.top_n]
    top = [{"id": int(nz[i].item()), "act": float(vals[i].item())} for i in order]
    return {"top": top, "n_features": sae.n_features}

class EncodeFullRequest(BaseModel):
    prompt: str
    top_n: int = 16  # number of feature ROWS to return in the heatmap

@app.post("/encode_full")
def encode_full(req: EncodeFullRequest):
    """Return a per-token feature activation grid for a single prompt.

    Picks the top_n features ranked by *mean activation across all token
    positions* (matches the official app.py heatmap definition), then returns
    each feature's activation at every token position. Activations that
    didn't make TopK at a given position are zero.
    """
    model, tokenizer, sae = state["model"], state["tokenizer"], state["sae"]
    layer_module = state["layers"][state["current_layer"]]
    inputs = tokenizer(req.prompt, return_tensors="pt").to(model.device)
    with torch.no_grad(), _capture_at(layer_module) as bucket:
        model(**inputs)
    h = bucket["h"][0]   # (seq_len, d_model)
    z = sae.encode(h)    # (seq_len, n_features) sparse TopK
    seq_len = z.shape[0]
    # Token strings for column headers
    ids = inputs["input_ids"][0].tolist()
    tokens = [tokenizer.decode([t], skip_special_tokens=False) for t in ids]
    # Rank features by mean activation across all positions
    mean_per_feat = z.mean(dim=0)
    top_vals, top_idx = mean_per_feat.topk(min(int(req.top_n), sae.n_features))
    grid = z[:, top_idx]  # (seq_len, top_n)
    return {
        "tokens": tokens,
        "feature_ids": [int(i.item()) for i in top_idx],
        "mean_acts": [float(v.item()) for v in top_vals],
        # grid: outer list = features, inner list = positions (transposed for
        # natural row-per-feature rendering in the UI)
        "grid": [[float(grid[p, f].item()) for p in range(seq_len)]
                 for f in range(grid.shape[1])],
        "seq_len": seq_len,
        "n_features": sae.n_features,
    }
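
# Shape sanity check for the response above (illustrative numbers): a 4-token
# prompt with top_n=2 yields "grid" as 2 rows (features) x 4 columns (tokens),
# i.e. grid[f][p] is feature f's activation at token position p.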

class EncodeBatchRequest(BaseModel):
    prompts: list[str]
    top_n: int = 20  # top features per prompt to return individually

@app.post("/encode_batch")
def encode_batch(req: EncodeBatchRequest):
    """Encode N prompts and return per-sample top features + corpus-level stats.

    For each prompt: encode the last-token residual through the SAE, return
    its top_n firing features. Corpus-level: union of features that fired
    at all, with per-feature firing rate (fraction of prompts where it
    appeared) and mean activation.
    """
    if not req.prompts:
        return {"per_sample": [], "corpus_features": [], "n_features": state["sae"].n_features}
    model, tokenizer, sae = state["model"], state["tokenizer"], state["sae"]
    layer_module = state["layers"][state["current_layer"]]
    per_sample = []
    union_act_sum: dict[int, float] = {}
    union_count: dict[int, int] = {}
    for idx, p in enumerate(req.prompts):
        inputs = tokenizer(p, return_tensors="pt").to(model.device)
        with torch.no_grad(), _capture_at(layer_module) as bucket:
            model(**inputs)
        h_last = bucket["h"][0, -1].unsqueeze(0)
        z = sae.encode(h_last)[0]
        nz = z.nonzero(as_tuple=False).flatten()
        vals = z[nz]
        order = vals.argsort(descending=True)
        top_idx = nz[order][:req.top_n]
        top = [{"id": int(top_idx[i].item()), "act": float(z[top_idx[i]].item())}
               for i in range(len(top_idx))]
        per_sample.append({
            "i": idx,
            "preview": p[:80] + ("…" if len(p) > 80 else ""),
            "len": len(p),
            "top": top,
            "n_active": int(len(nz)),
        })
        # Union stats over ALL nonzero features, not just top
        for fid, v in zip(nz.tolist(), vals.tolist()):
            union_count[fid] = union_count.get(fid, 0) + 1
            union_act_sum[fid] = union_act_sum.get(fid, 0.0) + float(v)
    n = len(req.prompts)
    corpus = []
    for fid, cnt in union_count.items():
        corpus.append({
            "id": fid,
            "fire_rate": cnt / n,
            "mean_act": union_act_sum[fid] / cnt,
            "n_samples": cnt,
        })
    # Sort by fire_rate desc then mean_act desc
    corpus.sort(key=lambda r: (-r["fire_rate"], -r["mean_act"]))
    return {
        "per_sample": per_sample,
        "corpus_features": corpus[:200],  # cap to 200 most frequent
        "n_features": sae.n_features,
        "n_samples": n,
    }
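
# Worked example for the corpus stats above (numbers illustrative): a feature
# that fires in 3 of 5 prompts with activations 2.0, 4.0 and 6.0 gets
# fire_rate = 3/5 = 0.6, mean_act = 12.0/3 = 4.0, n_samples = 3.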

class CompareBatchRequest(BaseModel):
    prompts_a: list[str]
    prompts_b: list[str]
    top_n: int = 30

@app.post("/compare_batch")
def compare_batch(req: CompareBatchRequest):
    """Differential feature mining between two prompt sets.

    For each set: encode all prompts, compute per-feature firing rate
    (fraction of prompts where the feature fires) and mean activation.
    Rank features by |fire_rate_A − fire_rate_B|.
    Returns top features that distinguish A from B.
    """
    model, tokenizer, sae = state["model"], state["tokenizer"], state["sae"]
    layer_module = state["layers"][state["current_layer"]]

    def _encode_set(prompts):
        n_feats = sae.n_features
        rate = torch.zeros(n_feats, dtype=torch.float32)
        acts = torch.zeros(n_feats, dtype=torch.float32)
        for p in prompts:
            inputs = tokenizer(p, return_tensors="pt").to(model.device)
            with torch.no_grad(), _capture_at(layer_module) as bucket:
                model(**inputs)
            h_last = bucket["h"][0, -1].unsqueeze(0)
            z = sae.encode(h_last)[0].detach().to("cpu", torch.float32)
            rate += (z != 0).float()
            acts += z
        if prompts:
            rate /= len(prompts)
            acts /= len(prompts)
        return rate, acts

    rate_a, acts_a = _encode_set(req.prompts_a)
    rate_b, acts_b = _encode_set(req.prompts_b)
    diff = (rate_a - rate_b).abs()
    top_vals, top_idx = diff.topk(min(int(req.top_n), sae.n_features))
    rows = []
    for v, fid in zip(top_vals.tolist(), top_idx.tolist()):
        rows.append({
            "id": int(fid),
            "diff": float(v),
            "rate_a": float(rate_a[fid]),
            "rate_b": float(rate_b[fid]),
            "act_a": float(acts_a[fid]),
            "act_b": float(acts_b[fid]),
            "winner": "a" if rate_a[fid] >= rate_b[fid] else "b",
        })
    return {"top_diff": rows, "n_a": len(req.prompts_a), "n_b": len(req.prompts_b),
            "n_features": sae.n_features}

class SynthRequest(BaseModel):
    seed_prompts: list[str]
    steering: list[SteerSpec] = []
    max_new_tokens: int = 40

@app.post("/synth_batch")
def synth_batch(req: SynthRequest):
    """Bulk steered synthesis: run steered generate over N seed prompts.

    Useful for the data-centric synthesis workflow: produce K examples
    that fire feature F at strength α, for downstream training data.
    """
    if not req.seed_prompts:
        return {"results": []}
    model, tokenizer, sae = state["model"], state["tokenizer"], state["sae"]
    layer_module = state["layers"][state["current_layer"]]
    results = []
    for seed in req.seed_prompts:
        inputs = tokenizer(seed, return_tensors="pt").to(model.device)
        with _hook_stack(layer_module, sae, req.steering):
            with torch.no_grad():
                out = model.generate(
                    **inputs,
                    max_new_tokens=req.max_new_tokens,
                    do_sample=False,
                    pad_token_id=tokenizer.eos_token_id,
                )
        text = tokenizer.decode(out[0], skip_special_tokens=True)
        results.append({"seed": seed, "text": text})
    return {"results": results}

def _extract_per_token_probs(gen_out, prompt_len, tokenizer, top_k):
    """Build per-step probabilities + top-K candidate strings."""
    new_ids = gen_out.sequences[0][prompt_len:].tolist()
    if not new_ids:
        return []
    rows = []
    for step, score_t in enumerate(gen_out.scores):
        probs = torch.softmax(score_t[0].float(), dim=-1)
        chosen_id = new_ids[step]
        chosen_prob = float(probs[chosen_id].item())
        top_vals, top_ids = probs.topk(min(top_k, probs.shape[0]))
        top_ids_list = top_ids.tolist()
        # Decode the chosen token plus the top-K candidates for this step
        decoded_chosen = tokenizer.decode([chosen_id], skip_special_tokens=False)
        decoded_top = tokenizer.batch_decode([[t] for t in top_ids_list], skip_special_tokens=False)
        topk = []
        for tid, tv, ts in zip(top_ids_list, top_vals.tolist(), decoded_top):
            topk.append({"tok": ts, "prob": float(tv), "is_chosen": tid == chosen_id})
        # If the chosen token wasn't in top-K, append it explicitly
        if chosen_id not in top_ids_list:
            topk.append({"tok": decoded_chosen, "prob": chosen_prob, "is_chosen": True})
        rows.append({"tok": decoded_chosen, "prob": chosen_prob, "topk": topk})
    return rows

@app.post("/generate")
def generate(req: GenerateRequest):
    model, tokenizer, sae = state["model"], state["tokenizer"], state["sae"]
    layer_module = state["layers"][state["current_layer"]]
    inputs = tokenizer(req.prompt, return_tensors="pt").to(model.device)
    prompt_len = int(inputs["input_ids"].shape[1])
    base_acts = {}
    if req.steering:
        with torch.no_grad(), _capture_at(layer_module) as bucket:
            model(**inputs)
        z_base = sae.encode(bucket["h"][0, -1].unsqueeze(0))[0]
        for s in req.steering:
            base_acts[s.id] = float(z_base[s.id].item())
    gen_kwargs = dict(
        **inputs,
        max_new_tokens=req.max_new_tokens,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
    )
    if req.return_probs:
        gen_kwargs["return_dict_in_generate"] = True
        gen_kwargs["output_scores"] = True
    with _hook_stack(layer_module, sae, req.steering, prompt_len=prompt_len):
        with torch.no_grad():
            out = model.generate(**gen_kwargs)
    seq = out.sequences[0] if req.return_probs else out[0]
    text = tokenizer.decode(seq, skip_special_tokens=True)
    per_token_probs = (_extract_per_token_probs(out, prompt_len, tokenizer, req.topk_display)
                       if req.return_probs else None)
    steered_acts = {}
    if req.steering:
        # Re-run the prompt with the steering hooks active so the verifier
        # reports the actually-steered activation (without the hooks this pass
        # would just reproduce the base activations).
        with _hook_stack(layer_module, sae, req.steering, prompt_len=prompt_len):
            with torch.no_grad(), _capture_at(layer_module) as bucket:
                model(**inputs)
        z_steered = sae.encode(bucket["h"][0, -1].unsqueeze(0))[0]
        for s in req.steering:
            steered_acts[s.id] = float(z_steered[s.id].item())
    verifier = [
        {"id": s.id, "alpha": s.alpha,
         "positions": s.positions,
         "output_only": s.output_only,
         "base": base_acts.get(s.id, 0.0),
         "steered": steered_acts.get(s.id, 0.0)}
        for s in req.steering
    ]
    resp = {"text": text, "verifier": verifier, "prompt_len": prompt_len}
    if per_token_probs is not None:
        resp["tokens"] = per_token_probs
    return resp

if __name__ == "__main__":
    import uvicorn
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)
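
# Local smoke test (sketch; assumes the route paths chosen above and a module
# file named app.py — adjust both to match the actual deployment):
#   python app.py
#   curl -s localhost:7860/health
#   curl -s -X POST localhost:7860/encode \
#        -H 'content-type: application/json' \
#        -d '{"prompt": "Hello world", "top_n": 10}'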