""" neuron_steer.py - Neuron Circuit Discovery and Steering for Language Models LRP rules for linearized backward attribution: 1. LN-rule: RMSNorm coefficient (weight * rsqrt) detached but preserved in backward 2. AH-rule: Eager attention (no SDPA/Flash) for full autograd through Q/K/V/O 3. Half-rule: Shapley attribution for gate*up elementwise multiply in MLP Core insight: ~0.1% of MLP neurons form complete circuits. No SAE needed. Attribution via single forward+backward pass. Usage: steerer = NeuronSteerer("meta-llama/Llama-3.1-8B-Instruct") circuit = steerer.discover_circuit("What is the capital of Texas?", " Austin") steered = steerer.steer_and_generate("What is the capital of Texas?", circuit, multiplier=0.0) """ import torch import torch.nn as nn import torch.nn.functional as F from typing import List, Tuple, Optional, Dict, NamedTuple, Set from contextlib import contextmanager, nullcontext from dataclasses import dataclass, field from collections import defaultdict # ============================================================ # Data Structures # ============================================================ class NeuronIdx(NamedTuple): """Identifies a specific MLP neuron activation.""" layer: int position: int neuron: int # ============================================================ # Universal Neuron Blacklists # Hard-coded from TransluceAI circuits repo (jvp.py) for Llama-3.1-8B. # These fire universally across tasks, not task-specific. # Format: (layer, neuron) - position-independent. # ============================================================ def _get_model_layers(model): """Get decoder layers from any model architecture (Llama, Qwen, Gemma4, etc.).""" if hasattr(model.model, 'layers'): return model.model.layers elif hasattr(model.model, 'language_model') and hasattr(model.model.language_model, 'layers'): return model.model.language_model.layers else: raise AttributeError( f"Cannot find layers in model architecture: {type(model.model).__name__}. " f"Supported: .model.layers or .model.language_model.layers" ) BLACKLIST_LLAMA3_8B = { (23, 306), (20, 3972), (18, 7417), (16, 1241), (13, 4208), (11, 11321), (10, 11570), (9, 4255), (7, 6673), (6, 5866), (5, 7012), (2, 4786), } def detect_universal_neurons( model, tokenizer, device="cuda", n_prompts: int = 20, top_k: int = 50, threshold_fraction: float = 0.8, ): """Auto-detect universal neurons by finding neurons that appear in top-k attribution across diverse prompts. Returns set of (layer, neuron) tuples. """ diverse_prompts = [ "The capital of France is", "Once upon a time there was a", "The best programming language is", "In the year 2024, the world", "The key to the cabinets", "How do I bake a cake?", "What is photosynthesis?", "The CEO of Apple is", "My favorite color is", "The largest ocean on Earth is", "Yesterday I went to the", "The speed of light is approximately", "In machine learning, a neural network", "The president of the United States", "Water freezes at a temperature of", "The meaning of life is", "To solve this math problem,", "The Great Wall of China was", "An electron has a charge of", "The chemical formula for water is", ][:n_prompts] # Count how many prompts each neuron appears in from collections import Counter neuron_counts: Counter = Counter() for prompt in diverse_prompts: input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device) # Collect activations via hooks (no linearization needed) layer_acts = {} hooks = [] for i, layer in enumerate(_get_model_layers(model)): def make_hook(layer_idx): def hook_fn(module, args): layer_acts[layer_idx] = args[0][0, -1].detach() return hook_fn h = layer.mlp.down_proj.register_forward_pre_hook(make_hook(i)) hooks.append(h) try: with torch.no_grad(): model(input_ids) finally: for h in hooks: h.remove() # Find top-k activated neurons per layer for layer_idx, act in layer_acts.items(): top_vals, top_idxs = act.abs().topk(min(top_k, act.shape[0])) for idx in top_idxs: neuron_counts[(layer_idx, idx.item())] += 1 # Neurons that appear in >= threshold_fraction of prompts are universal threshold = int(n_prompts * threshold_fraction) universal = {ln for ln, count in neuron_counts.items() if count >= threshold} return universal @dataclass class Circuit: """A set of neurons with their attributions.""" neurons: Dict[NeuronIdx, float] prompt: str target_token: str total_logit_diff: float def top(self, k: int = 20) -> List[Tuple[NeuronIdx, float]]: """Return top-k neurons by attribution magnitude.""" return sorted(self.neurons.items(), key=lambda x: abs(x[1]), reverse=True)[:k] def by_layer(self) -> Dict[int, List[Tuple[NeuronIdx, float]]]: """Group neurons by layer.""" result: Dict[int, list] = {} for nidx, attr in self.neurons.items(): result.setdefault(nidx.layer, []).append((nidx, attr)) for l in result: result[l].sort(key=lambda x: abs(x[1]), reverse=True) return result def unique_neurons(self) -> Dict[int, Set[int]]: """Get unique neuron indices per layer (collapsing positions).""" result: Dict[int, Set[int]] = {} for nidx in self.neurons: result.setdefault(nidx.layer, set()).add(nidx.neuron) return result def save(self, path: str): """Save circuit to file.""" import json data = { "neurons": {f"{n.layer},{n.position},{n.neuron}": v for n, v in self.neurons.items()}, "prompt": self.prompt, "target_token": self.target_token, "total_logit_diff": self.total_logit_diff, } with open(path, "w") as f: json.dump(data, f, indent=2) @classmethod def load(cls, path: str) -> "Circuit": """Load circuit from file.""" import json with open(path) as f: data = json.load(f) neurons = {} for key, val in data["neurons"].items(): l, p, n = key.split(",") neurons[NeuronIdx(int(l), int(p), int(n))] = val return cls(neurons=neurons, prompt=data["prompt"], target_token=data["target_token"], total_logit_diff=data["total_logit_diff"]) def summary(self) -> str: lines = [ f"Circuit: {len(self.neurons)} neurons, logit_diff={self.total_logit_diff:.4f}", f"Prompt: {self.prompt[:80]}", f"Target: {self.target_token}", f"Layers touched: {sorted(set(n.layer for n in self.neurons))}", ] by_layer = self.by_layer() for l in sorted(by_layer.keys()): neurons = by_layer[l] lines.append(f" L{l}: {len(neurons)} neurons, top={neurons[0][1]:.6f}") return "\n".join(lines) # ============================================================ # LRP Rule 1: LN-rule (RMSNorm) # Forward = real RMSNorm, Backward = identity through normalization # ============================================================ class LinearizedRMSNorm(nn.Module): """Wraps RMSNorm with LN-rule. Forward = real RMSNorm value. Backward = grad * (weight * rsqrt(mean(x²) + eps)), where the coefficient is DETACHED (treated as constant, but its per-token value is preserved). coeff = weight * rsqrt(mean(x²) + eps), detached, then y = x * coeff. Since coeff is detached, backward = grad * coeff. """ def __init__(self, original): super().__init__() self.weight = original.weight # Llama uses variance_epsilon, Qwen/Gemma/Mistral use eps if hasattr(original, 'variance_epsilon'): self.eps = original.variance_epsilon elif hasattr(original, 'eps'): self.eps = original.eps else: self.eps = 1e-6 # safe default self._original = original def forward(self, x): # Compute normalization coefficient: weight * rsqrt(mean(x²) + eps) # DETACH it so backward treats it as constant (LN-rule) input_dtype = x.dtype variance = x.float().pow(2).mean(-1, keepdim=True) coeff = self.weight.float() * torch.rsqrt(variance + self.eps) coeff = coeff.detach().to(input_dtype) # key: treat as constant in backward return x * coeff # ============================================================ # LRP Rule 3: Half-rule (Gated MLP) # Shapley attribution for elementwise multiply: each factor gets 50% gradient # ============================================================ class _HalfRuleMultiply(torch.autograd.Function): @staticmethod def forward(ctx, a, b): ctx.save_for_backward(a, b) return a * b @staticmethod def backward(ctx, grad_output): a, b = ctx.saved_tensors # Shapley value: each factor gets half credit return grad_output * b * 0.5, grad_output * a * 0.5 class LinearizedMLP(nn.Module): """Wraps LlamaMLP with detached sigmoid + half-rule. only the sigmoid is detached while the linear component of SiLU (x * sigmoid(x)) keeps gradient flow, then the half-rule distributes credit evenly between gate and up projections. Standard Llama MLP: hidden = SiLU(gate_proj(x)) * up_proj(x) output = down_proj(hidden) Linearized version: gate = gate_proj(x) sigmoid_gate = sigmoid(gate).detach() # treat sigmoid as constant gate_act = gate * sigmoid_gate # linearized SiLU hidden = HalfRule(gate_act, up_proj(x)) # Shapley attribution output = down_proj(hidden) The `hidden` tensor (input to down_proj) = "neuron activation". This is what we attribute and steer. """ def __init__(self, original): super().__init__() self.gate_proj = original.gate_proj self.up_proj = original.up_proj self.down_proj = original.down_proj self._original = original self.neuron_act = None # saved during forward for attribution def forward(self, x): gate = self.gate_proj(x) up = self.up_proj(x) # Linearized SiLU: detach the sigmoid coefficient sigmoid_gate = torch.sigmoid(gate).detach() gate_act = gate * sigmoid_gate # Half-rule on elementwise multiply hidden = _HalfRuleMultiply.apply(gate_act, up) # Save neuron activation for attribution hidden.retain_grad() self.neuron_act = hidden return self.down_proj(hidden) # ============================================================ # Model Linearization # ============================================================ def _linearize_model(model): """Apply LRP rules to a Llama-family model. Rule 1 (LN): Replace RMSNorm with LinearizedRMSNorm (detached coeff) Rule 2 (AH): Force eager attention (no SDPA/Flash) for autograd compatibility Rule 3 (Half): Replace MLP with LinearizedMLP (linearized SiLU + half-rule) Returns dict for restoration. """ originals = {"modules": {}, "hooks": []} # Rule 2: Force eager attention # This ensures gradient flows through all attention paths including Q/K. originals["attn_impl"] = model.config._attn_implementation model.config._attn_implementation = "eager" # Rule 1: RMSNorm → LinearizedRMSNorm originals["modules"]["model.norm"] = model.model.norm model.model.norm = LinearizedRMSNorm(model.model.norm) for i, layer in enumerate(_get_model_layers(model)): # Input layernorm originals["modules"][f"layer.{i}.input_layernorm"] = layer.input_layernorm layer.input_layernorm = LinearizedRMSNorm(layer.input_layernorm) # Post-attention layernorm originals["modules"][f"layer.{i}.post_attention_layernorm"] = layer.post_attention_layernorm layer.post_attention_layernorm = LinearizedRMSNorm(layer.post_attention_layernorm) # Rule 2: AH-rule — replace SDPA/Flash with eager attention # (not fused SDPA/Flash) for autograd compatibility. Gradient DOES flow # through Q/K — the class name is misleading. We match their behavior. # Rule 3: MLP → LinearizedMLP originals["modules"][f"layer.{i}.mlp"] = layer.mlp layer.mlp = LinearizedMLP(layer.mlp) return originals def _restore_model(model, originals): """Restore original model modules.""" # Remove backward hooks for hook in originals["hooks"]: hook.remove() # Restore attention implementation if "attn_impl" in originals: model.config._attn_implementation = originals["attn_impl"] # Restore modules model.model.norm = originals["modules"]["model.norm"] for i, layer in enumerate(_get_model_layers(model)): layer.input_layernorm = originals["modules"][f"layer.{i}.input_layernorm"] layer.post_attention_layernorm = originals["modules"][f"layer.{i}.post_attention_layernorm"] layer.mlp = originals["modules"][f"layer.{i}.mlp"] @contextmanager def linearized(model): """Context manager: apply LRP rules for attribution, restore after.""" originals = _linearize_model(model) try: yield model finally: _restore_model(model, originals) # ============================================================ # Attribution # ============================================================ def compute_attribution( model, input_ids: torch.Tensor, target_token_id: int, counterfactual_token_id: Optional[int] = None, position: int = -1, top_k_per_layer: int = 200, filter_bos: bool = True, last_n_positions: Optional[int] = None, blacklist_layers: Optional[Set[int]] = None, blacklist_neurons: Optional[Set[Tuple[int, int]]] = None, target_only: bool = False, verbose: bool = False, ) -> Tuple[Dict[NeuronIdx, float], float]: """Compute per-neuron attribution. Args: model: Linearized Llama model (inside `linearized()` context) input_ids: [1, T] input token ids target_token_id: Token to attribute toward counterfactual_token_id: Alternative token for logit diff (None = auto or target_only) position: Token position for logit measurement (default: last) top_k_per_layer: Keep top-k neurons per layer per position (sparsification) filter_bos: If True, exclude position 0 (BOS) neurons last_n_positions: If set, only keep neurons from the last N token positions. blacklist_layers: Set of layer indices to exclude entirely blacklist_neurons: Set of (layer, neuron) tuples to exclude target_only: If True, backward from target logit alone. If False and no counterfactual given, auto-detects 2nd highest logit. Use target_only=True for percentage_threshold selection. verbose: Print diagnostic info about attribution distribution Returns: (attributions dict, metric_value scalar) metric_value is target_logit when target_only=True, else logit_diff """ blacklist_layers = blacklist_layers or set() blacklist_neurons = blacklist_neurons or set() model.eval() model.zero_grad() # Clear any saved neuron activations for layer in _get_model_layers(model): if hasattr(layer.mlp, "neuron_act"): layer.mlp.neuron_act = None with torch.enable_grad(): outputs = model(input_ids) logits = outputs.logits[0, position] # [vocab_size] target_logit = logits[target_token_id] if target_only: metric = target_logit elif counterfactual_token_id is None: sorted_logits, sorted_ids = logits.sort(descending=True) if sorted_ids[0].item() == target_token_id: counterfactual_logit = sorted_logits[1] else: counterfactual_logit = sorted_logits[0] metric = target_logit - counterfactual_logit else: counterfactual_logit = logits[counterfactual_token_id] metric = target_logit - counterfactual_logit # Backward through linearized model metric.backward() # Collect attributions from saved neuron activations attributions = {} layer_stats = {} # diagnostic info for i, layer in enumerate(_get_model_layers(model)): if i in blacklist_layers: continue mlp = layer.mlp if not hasattr(mlp, "neuron_act") or mlp.neuron_act is None: continue if mlp.neuron_act.grad is None: continue act = mlp.neuron_act.detach() # [1, T, intermediate_size] grad = mlp.neuron_act.grad # [1, T, intermediate_size] # Attribution = gradient * activation (element-wise) attr = (grad * act)[0] # [T, intermediate_size] T = attr.shape[0] # NaN-safe statistics (exclude NaN from sums) valid_mask = ~torch.isnan(attr) valid_attr = attr[valid_mask] if valid_attr.numel() > 0: layer_total = valid_attr.abs().sum().item() layer_max = valid_attr.abs().max().item() nan_frac = 1.0 - valid_mask.float().mean().item() else: layer_total = 0.0 layer_max = 0.0 nan_frac = 1.0 layer_stats[i] = {"total": layer_total, "max": layer_max, "nan_frac": nan_frac} if last_n_positions is not None: start_pos = max(0, T - last_n_positions) elif filter_bos: start_pos = 1 else: start_pos = 0 for p in range(start_pos, T): pos_attr = attr[p] abs_attr = pos_attr.abs() # NaN-safe topk: replace NaN with 0 so they don't crowd out valid values nan_mask = torch.isnan(abs_attr) if nan_mask.any(): abs_attr = abs_attr.clone() abs_attr[nan_mask] = 0.0 # Keep top-k neurons at this position k = min(top_k_per_layer, abs_attr.shape[0]) top_vals, top_idxs = abs_attr.topk(k) for val, idx in zip(top_vals, top_idxs): if val.item() > 1e-8: n = idx.item() if (i, n) in blacklist_neurons: continue nidx = NeuronIdx(layer=i, position=p, neuron=n) attributions[nidx] = pos_attr[idx].item() # Free GPU memory - clear saved activations after collection for layer in _get_model_layers(model): if hasattr(layer.mlp, "neuron_act"): layer.mlp.neuron_act = None if verbose: print(f" Attribution distribution by layer:") has_nan = False for l in sorted(layer_stats.keys()): s = layer_stats[l] nan_str = f" [NaN: {s['nan_frac']:.1%}]" if s['nan_frac'] > 0.01 else "" print(f" L{l:2d}: total={s['total']:.4f}, max={s['max']:.4f}{nan_str}") if s['nan_frac'] > 0.01: has_nan = True total_attr = sum(abs(v) for v in attributions.values()) print(f" Total (filtered): {total_attr:.4f}, {len(attributions)} neurons") if has_nan: print(f" WARNING: NaN in gradients detected. LRP rules may not be compatible with this model.") return attributions, metric.item() def select_circuit( attributions: Dict[NeuronIdx, float], method: str = "threshold", threshold: float = 0.005, top_k: Optional[int] = None, per_layer_topk: Optional[int] = None, reference_value: Optional[float] = None, ) -> Dict[NeuronIdx, float]: """Select circuit neurons from attributions. Methods: 'threshold': Select neurons until cumulative |attribution| >= threshold * total 'topk': Select top-k neurons by |attribution| (globally) 'percentage': Keep neurons with INDIVIDUAL |attribution| >= threshold * |reference_value|. When percentage_threshold=0.005, keeps neurons contributing >= 0.5% of the logit diff. This filters noise while preserving all significant neurons. Requires reference_value (typically logit_diff). 'per_layer_topk': Select top-N from EACH layer, then take global top_k. Essential for models like Qwen where early layers dominate by 10^10. """ if not attributions: return {} total = sum(abs(v) for v in attributions.values()) if total < 1e-10: return {} if method == "per_layer_topk" and per_layer_topk is not None: # Group by layer, take top-N from each, then global top_k by_layer: Dict[int, List] = {} for nidx, attr in attributions.items(): by_layer.setdefault(nidx.layer, []).append((nidx, attr)) selected = {} for layer_idx, neurons in by_layer.items(): neurons.sort(key=lambda x: abs(x[1]), reverse=True) for nidx, attr in neurons[:per_layer_topk]: selected[nidx] = attr # If also top_k specified, trim globally if top_k is not None and len(selected) > top_k: sorted_sel = sorted(selected.items(), key=lambda x: abs(x[1]), reverse=True) selected = dict(sorted_sel[:top_k]) return selected sorted_attrs = sorted(attributions.items(), key=lambda x: abs(x[1]), reverse=True) if method == "topk" and top_k is not None: return dict(sorted_attrs[:top_k]) if method == "percentage" and reference_value is not None: abs_threshold = threshold * abs(reference_value) selected = {nidx: attr for nidx, attr in attributions.items() if abs(attr) >= abs_threshold} return selected # Default: cumulative threshold selected = {} cumulative = 0.0 for nidx, attr in sorted_attrs: selected[nidx] = attr cumulative += abs(attr) if cumulative >= threshold * total: break return selected # ============================================================ # Steering # ============================================================ @contextmanager def steer_neurons( model, neurons: Dict[NeuronIdx, float], multiplier: float = 0.0, all_positions: bool = True, ): """Apply steering hooks to specific neurons during forward pass. Modifies neuron activations (input to down_proj) by multiplying with `multiplier`. multiplier=0.0 → ablate, 1.0 → no change, 2.0 → amplify If all_positions=True, steers the neuron at ALL positions (for generation). If False, only steers at the specific positions from the circuit. """ hooks = [] if all_positions: # Group by layer, collect unique neuron indices by_layer: Dict[int, List[int]] = {} for nidx in neurons: by_layer.setdefault(nidx.layer, set()).add(nidx.neuron) by_layer = {l: sorted(ns) for l, ns in by_layer.items()} for layer_idx, neuron_indices in by_layer.items(): idx_tensor = torch.tensor(neuron_indices, dtype=torch.long) def make_hook(idx_t): def pre_hook(module, args): x = args[0].clone() device_idx = idx_t.to(x.device) x[:, :, device_idx] *= multiplier return (x,) return pre_hook hook = _get_model_layers(model)[layer_idx].mlp.down_proj.register_forward_pre_hook( make_hook(idx_tensor) ) hooks.append(hook) else: # Group by (layer, position) by_layer_pos: Dict[Tuple[int, int], List[int]] = {} for nidx in neurons: key = (nidx.layer, nidx.position) by_layer_pos.setdefault(key, []).append(nidx.neuron) # Group by layer for efficient hooking layer_pos_map: Dict[int, Dict[int, List[int]]] = {} for (l, p), ns in by_layer_pos.items(): layer_pos_map.setdefault(l, {})[p] = ns for layer_idx, pos_map in layer_pos_map.items(): def make_hook(pm): def pre_hook(module, args): x = args[0].clone() for pos, neuron_indices in pm.items(): idx_t = torch.tensor(neuron_indices, dtype=torch.long, device=x.device) x[:, pos, idx_t] *= multiplier return (x,) return pre_hook hook = _get_model_layers(model)[layer_idx].mlp.down_proj.register_forward_pre_hook( make_hook(pos_map) ) hooks.append(hook) try: yield model finally: for hook in hooks: hook.remove() # ============================================================ # High-Level API # ============================================================ class NeuronSteerer: """End-to-end neuron circuit discovery and steering. Pipeline: 1. Load model (eager attention for compatibility) 2. discover_circuit(): linearize → forward → backward → select neurons 3. steer_and_generate(): hook neurons → generate with modified activations """ def __init__(self, model_name: str, device: str = "cuda", dtype=torch.bfloat16, auto_blacklist: bool = True, max_memory: dict = None): from transformers import AutoModelForCausalLM, AutoTokenizer print(f"Loading {model_name}...") self.tokenizer = AutoTokenizer.from_pretrained(model_name) if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token self.model = AutoModelForCausalLM.from_pretrained( model_name, device_map="auto", attn_implementation="eager", dtype=dtype, **({"max_memory": max_memory} if max_memory else {}), ) self.model.eval() self.device = device self.model_name = model_name self.is_instruct = "instruct" in model_name.lower() or "chat" in model_name.lower() print(f"Loaded {model_name} on {device} (instruct={self.is_instruct})") # Auto-detect layer path for different architectures (Llama, Qwen, Gemma4, etc.) if hasattr(self.model.model, 'layers'): self._layers_ref = self.model.model.layers elif hasattr(self.model.model, 'language_model') and hasattr(self.model.model.language_model, 'layers'): self._layers_ref = self.model.model.language_model.layers else: raise AttributeError( f"Cannot find layers in model architecture: {type(self.model.model).__name__}. " f"Supported: .model.layers or .model.language_model.layers" ) print(f" Layers: {len(self._layers_ref)} (via {'model.layers' if hasattr(self.model.model, 'layers') else 'model.language_model.layers'})") # Auto-detect config path for multimodal models (Gemma4, etc.) if hasattr(self.model.config, 'text_config'): self._text_config = self.model.config.text_config else: self._text_config = self.model.config # Feature cache: name -> Circuit for reuse across steer() calls self._feature_cache: Dict[str, Circuit] = {} # Universal neuron blacklist (model-conditional) is_llama_8b = "llama" in model_name.lower() and ("8b" in model_name.lower() or "8B" in model_name) if is_llama_8b: self.blacklist: Set[Tuple[int, int]] = set(BLACKLIST_LLAMA3_8B) known_str = f"{len(BLACKLIST_LLAMA3_8B)} from TransluceAI" else: self.blacklist: Set[Tuple[int, int]] = set() known_str = "0 known (non-Llama-8B model)" if auto_blacklist: print("Detecting universal neurons...") detected = detect_universal_neurons( self.model, self.tokenizer, device, n_prompts=10, top_k=20, threshold_fraction=0.8, ) new_detected = detected - self.blacklist self.blacklist |= detected print(f" Blacklist: {len(self.blacklist)} universal neurons " f"({known_str} + {len(new_detected)} new auto-detected)") def _format_prompt(self, prompt: str, seed_response: str = "") -> str: """Format prompt for instruct models using chat template.""" if self.is_instruct and hasattr(self.tokenizer, "apply_chat_template"): messages = [{"role": "user", "content": prompt}] formatted = self.tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) return formatted + seed_response return prompt + seed_response def discover_circuit( self, prompt: str, target_token: str, counterfactual_token: Optional[str] = None, threshold: float = 0.005, top_k: Optional[int] = None, selection_method: Optional[str] = None, seed_response: str = "", filter_bos: bool = True, filter_infrastructure: bool = False, last_n_positions: Optional[int] = None, blacklist_neurons: Optional[Set[Tuple[int, int]]] = None, use_chat_template: bool = True, verbose: bool = False, ) -> Circuit: """Discover the neuron circuit for predicting target_token. Selection methods: top_k=N: Select exactly N neurons by |attribution| selection_method='percentage': keep neurons with |attribution| >= threshold * |logit_diff|. Default threshold=0.005. Default: cumulative threshold Args: prompt: Input text (auto-formatted for instruct models) target_token: Target output token (e.g., " Austin") counterfactual_token: Alternative token (auto if None) threshold: Attribution threshold (meaning depends on selection_method) top_k: Select exactly top_k neurons (overrides threshold) selection_method: 'percentage' for individual neuron threshold seed_response: Text to append before target (e.g., "Answer:") filter_bos: Filter out BOS position neurons filter_infrastructure: Filter out L0-L1, or pass set of layer indices use_chat_template: Use chat template for instruct models (False for raw completion like SVA) verbose: Print attribution diagnostics """ if use_chat_template: formatted = self._format_prompt(prompt, seed_response) else: formatted = prompt + seed_response input_ids = self.tokenizer(formatted, return_tensors="pt").input_ids.to(self.device) # Tokenization validation target_ids = self.tokenizer.encode(target_token, add_special_tokens=False) target_id = target_ids[-1] if len(target_ids) > 1 and verbose: print(f" WARNING: target '{target_token}' encodes to {len(target_ids)} tokens " f"{target_ids}. Using last token ({target_id}) for attribution.") cf_id = None if counterfactual_token: cf_ids = self.tokenizer.encode(counterfactual_token, add_special_tokens=False) cf_id = cf_ids[-1] if len(cf_ids) > 1 and verbose: print(f" WARNING: counterfactual '{counterfactual_token}' encodes to {len(cf_ids)} tokens " f"{cf_ids}. Using last token ({cf_id}).") if cf_id == target_id: print(f" ERROR: target and counterfactual share first token ({target_id})! " f"Logit diff will be 0. Fix your token strings.") bl_layers = filter_infrastructure if isinstance(filter_infrastructure, set) else ({0, 1} if filter_infrastructure else set()) bl_neurons = blacklist_neurons if blacklist_neurons is not None else self.blacklist # Use target_only when doing percentage selection use_target_only = (selection_method == "percentage") with linearized(self.model): attributions, metric_value = compute_attribution( self.model, input_ids, target_id, cf_id, filter_bos=filter_bos, verbose=verbose, last_n_positions=last_n_positions, blacklist_layers=bl_layers, blacklist_neurons=bl_neurons, target_only=use_target_only, ) # Select circuit if top_k: method = "topk" elif selection_method == "percentage": method = "percentage" else: method = "threshold" circuit_neurons = select_circuit( attributions, method=method, threshold=threshold, top_k=top_k, reference_value=metric_value, ) return Circuit( neurons=circuit_neurons, prompt=formatted, target_token=target_token, total_logit_diff=metric_value, ) def discover_circuit_multi( self, prompts: List[str], target_tokens: List[str], counterfactual_tokens: Optional[List[str]] = None, threshold: float = 0.005, top_k: Optional[int] = None, selection_method: Optional[str] = None, seed_response: str = "", filter_bos: bool = True, filter_infrastructure: bool = False, last_n_positions: Optional[int] = None, blacklist_neurons: Optional[Set[Tuple[int, int]]] = None, batch_aggregation: str = "mean", use_chat_template: bool = True, verbose: bool = False, precomputed_attributions: Optional[Tuple[Dict, float]] = None, return_raw_attributions: bool = False, target_only: Optional[bool] = None, ) -> "Circuit | Tuple[Circuit, Dict, float]": """Discover circuit over multiple prompts. batch_aggregation modes: 'mean': Average attributions across all prompts (default) 'any': Keep neuron if it's important in ANY prompt (union) Preserves prompt-specific neurons selection_method='percentage' + threshold=0.005: keep neurons with |attr| >= 0.5% of |reference_value|. When no counterfactual tokens given, auto-enables target_only (backprop from target logit alone, not logit_diff) precomputed_attributions: (aggregated_dict, avg_ld) from a prior call with return_raw_attributions=True. Skips all LRP computation — only does selection. return_raw_attributions: if True, returns (Circuit, aggregated_dict, avg_ld) so you can call again with different top_k without recomputing LRP. """ # Fast path: use precomputed attributions (skip all LRP) if precomputed_attributions is not None: aggregated, avg_ld = precomputed_attributions if top_k: circuit_neurons = select_circuit(aggregated, method="topk", top_k=top_k) elif selection_method == "percentage": circuit_neurons = select_circuit( aggregated, method="percentage", threshold=threshold, reference_value=avg_ld) else: circuit_neurons = select_circuit( aggregated, method="threshold", threshold=threshold) circuit = Circuit( neurons=circuit_neurons, prompt=f"[{len(prompts)} prompts, agg={batch_aggregation}]", target_token=str(target_tokens[:3]), total_logit_diff=avg_ld, ) if return_raw_attributions: return circuit, aggregated, avg_ld return circuit all_attributions: Dict[NeuronIdx, List[float]] = defaultdict(list) logit_diffs = [] bl_layers = filter_infrastructure if isinstance(filter_infrastructure, set) else ({0, 1} if filter_infrastructure else set()) bl_neurons = blacklist_neurons if blacklist_neurons is not None else self.blacklist # Use target_only when doing percentage selection or explicitly requested use_target_only = target_only if target_only is not None else (selection_method == "percentage") # For percentage + "any": apply threshold PER PROMPT # Each prompt gets its own threshold = percentage * that_prompt's_target_logit # Then union across prompts per_prompt_filter = (selection_method == "percentage" and batch_aggregation == "any") for i, (prompt, target) in enumerate(zip(prompts, target_tokens)): cf = counterfactual_tokens[i] if counterfactual_tokens else None if use_chat_template: formatted = self._format_prompt(prompt, seed_response) else: formatted = prompt + seed_response input_ids = self.tokenizer(formatted, return_tensors="pt").input_ids.to(self.device) target_id = self.tokenizer.encode(target, add_special_tokens=False)[-1] cf_id = self.tokenizer.encode(cf, add_special_tokens=False)[-1] if cf else None with linearized(self.model): attrs, ld = compute_attribution( self.model, input_ids, target_id, cf_id, filter_bos=filter_bos, verbose=False, last_n_positions=last_n_positions, blacklist_layers=bl_layers, blacklist_neurons=bl_neurons, target_only=use_target_only, ) if per_prompt_filter: abs_thresh = threshold * abs(ld) filtered = {nidx: attr for nidx, attr in attrs.items() if abs(attr) >= abs_thresh} for nidx, attr in filtered.items(): all_attributions[nidx].append(attr) if verbose: print(f" Prompt {i+1}/{len(prompts)}: {len(filtered)} neurons " f"(of {len(attrs)} raw), ld={ld:.4f}, thresh={abs_thresh:.6f}") else: for nidx, attr in attrs.items(): all_attributions[nidx].append(attr) if verbose: print(f" Prompt {i+1}/{len(prompts)}: {len(attrs)} neurons, ld={ld:.4f}") logit_diffs.append(ld) # Aggregate attributions if batch_aggregation == "any": aggregated = {} for nidx, attr_list in all_attributions.items(): aggregated[nidx] = max(attr_list, key=abs) else: aggregated = {} for nidx, attr_list in all_attributions.items(): aggregated[nidx] = sum(attr_list) / len(prompts) avg_ld = sum(logit_diffs) / len(logit_diffs) # Select circuit if per_prompt_filter: # Already filtered per-prompt, just use what we have circuit_neurons = aggregated elif top_k: circuit_neurons = select_circuit( aggregated, method="topk", top_k=top_k) elif selection_method == "percentage": circuit_neurons = select_circuit( aggregated, method="percentage", threshold=threshold, reference_value=avg_ld) else: circuit_neurons = select_circuit( aggregated, method="threshold", threshold=threshold) circuit = Circuit( neurons=circuit_neurons, prompt=f"[{len(prompts)} prompts, agg={batch_aggregation}]", target_token=str(target_tokens[:3]), total_logit_diff=avg_ld, ) if return_raw_attributions: return circuit, aggregated, avg_ld return circuit def discover_contrastive( self, positive_prompts: List[str], negative_prompts: List[str], top_k: int = 200, filter_infrastructure: bool = True, verbose: bool = False, ) -> Circuit: """Discover neurons by contrasting activations between two prompt sets. This is better for behavioral steering (refusal, tone, style) where there's no clean target/counterfactual token pair. Runs all prompts through the model, collects MLP neuron activations at the last token position, then finds neurons with largest activation difference between positive and negative sets. Args: positive_prompts: Prompts exhibiting the target behavior (e.g., harmful prompts that get refused) negative_prompts: Prompts NOT exhibiting it (e.g., benign prompts that get answered) top_k: Number of neurons to select filter_infrastructure: Exclude L0-L1 """ # filter_infrastructure: True={0,1}, or pass a set like {0,1,2,3,4} for Qwen bl_layers = filter_infrastructure if isinstance(filter_infrastructure, set) else ({0, 1} if filter_infrastructure else set()) def collect_activations(prompts): """Run prompts and collect last-position neuron activations per layer. Uses forward pre-hooks on down_proj to capture the input (= neuron activations) without requiring linearization or gradients. """ all_acts = [] for prompt in prompts: formatted = self._format_prompt(prompt) input_ids = self.tokenizer(formatted, return_tensors="pt").input_ids.to(self.device) # Hook into down_proj to capture neuron activations layer_acts = {} hooks = [] for i, layer in enumerate(self._layers_ref): if i in bl_layers: continue def make_hook(layer_idx): def hook_fn(module, args): layer_acts[layer_idx] = args[0][0, -1].detach().cpu() return hook_fn h = layer.mlp.down_proj.register_forward_pre_hook(make_hook(i)) hooks.append(h) try: with torch.no_grad(): self.model(input_ids) finally: for h in hooks: h.remove() all_acts.append(layer_acts) return all_acts print(f" Collecting activations for {len(positive_prompts)} positive prompts...") pos_acts = collect_activations(positive_prompts) print(f" Collecting activations for {len(negative_prompts)} negative prompts...") neg_acts = collect_activations(negative_prompts) # Compute mean activation per neuron for each set all_layers = set() for acts in pos_acts + neg_acts: all_layers.update(acts.keys()) neurons_with_diff = {} for layer_idx in sorted(all_layers): pos_mean = torch.stack([a[layer_idx] for a in pos_acts if layer_idx in a]).mean(0) neg_mean = torch.stack([a[layer_idx] for a in neg_acts if layer_idx in a]).mean(0) diff = pos_mean - neg_mean # positive = more active in positive set for n in range(diff.shape[0]): d = diff[n].item() if abs(d) > 1e-6: nidx = NeuronIdx(layer=layer_idx, position=-1, neuron=n) neurons_with_diff[nidx] = d # Select top-k by absolute difference sorted_neurons = sorted(neurons_with_diff.items(), key=lambda x: abs(x[1]), reverse=True) circuit_neurons = dict(sorted_neurons[:top_k]) if verbose: print(f" Found {len(neurons_with_diff)} neurons with nonzero difference") by_layer_count = defaultdict(int) for nidx in circuit_neurons: by_layer_count[nidx.layer] += 1 for l in sorted(by_layer_count.keys()): print(f" L{l:2d}: {by_layer_count[l]} neurons") return Circuit( neurons=circuit_neurons, prompt=f"[contrastive: {len(positive_prompts)} pos vs {len(negative_prompts)} neg]", target_token="[contrastive]", total_logit_diff=0.0, ) # ============================================================ # CAA ↔ Neuron Circuit Connection (Novel) # ============================================================ def compute_control_vector( self, positive_prompts: List[str], negative_prompts: List[str], layer_idx: Optional[int] = None, seed_response: str = "", use_chat_template: bool = True, ) -> Dict[int, torch.Tensor]: """Compute a Contrastive Activation Addition (CAA) control vector. v = mean(activations_positive) - mean(activations_negative) at the residual stream after each MLP layer. Args: positive_prompts: Prompts that elicit target behavior negative_prompts: Prompts that elicit opposite behavior layer_idx: If set, only compute for this layer. Otherwise all layers. seed_response: Appended after prompt use_chat_template: Use chat template formatting Returns: Dict[layer_idx, control_vector] where each CV is [d_model] """ layers = [layer_idx] if layer_idx is not None else list(range(len(self._layers_ref))) def collect_residual(prompts): """Collect residual stream activations after MLP for each layer.""" all_acts = {l: [] for l in layers} for prompt in prompts: if use_chat_template: formatted = self._format_prompt(prompt, seed_response) else: formatted = prompt + seed_response input_ids = self.tokenizer(formatted, return_tensors="pt").input_ids.to(self.device) # Hook to capture post-MLP residual at last token position captured = {} hooks = [] for l in layers: def make_hook(layer_idx): def hook_fn(module, input, output): # output is tuple or BaseModelOutput — extract hidden states hs = output[0] if isinstance(output, tuple) else output if hasattr(hs, 'last_hidden_state'): hs = hs.last_hidden_state captured[layer_idx] = hs[0, -1].detach().clone() return hook_fn h = self._layers_ref[l].register_forward_hook(make_hook(l)) hooks.append(h) try: with torch.no_grad(): self.model(input_ids) finally: for h in hooks: h.remove() for l in layers: if l in captured: all_acts[l].append(captured[l]) return {l: torch.stack(acts) for l, acts in all_acts.items() if acts} pos_acts = collect_residual(positive_prompts) neg_acts = collect_residual(negative_prompts) control_vectors = {} for l in layers: if l in pos_acts and l in neg_acts: cv = pos_acts[l].mean(dim=0) - neg_acts[l].mean(dim=0) control_vectors[l] = cv return control_vectors def compute_mlp_control_vector( self, positive_prompts: List[str], negative_prompts: List[str], layer_idx: Optional[int] = None, seed_response: str = "", use_chat_template: bool = True, ) -> Dict[int, torch.Tensor]: """Compute control vector from MLP outputs ONLY (not attention). Unlike compute_control_vector which captures full residual stream (attention + MLP), this hooks the MLP sublayer directly. Returns: Dict[layer_idx, mlp_control_vector] where each CV is [d_model] """ layers = [layer_idx] if layer_idx is not None else list(range(len(self._layers_ref))) def collect_mlp_output(prompts): all_acts = {l: [] for l in layers} for prompt in prompts: if use_chat_template: formatted = self._format_prompt(prompt, seed_response) else: formatted = prompt + seed_response input_ids = self.tokenizer(formatted, return_tensors="pt").input_ids.to(self.device) captured = {} hooks = [] for l in layers: def make_hook(layer_idx): def hook_fn(module, input, output): # MLP output is a tensor, not tuple out = output[0] if isinstance(output, tuple) else output captured[layer_idx] = out[0, -1].detach().clone() return hook_fn h = self._layers_ref[l].mlp.register_forward_hook(make_hook(l)) hooks.append(h) try: with torch.no_grad(): self.model(input_ids) finally: for h in hooks: h.remove() for l in layers: if l in captured: all_acts[l].append(captured[l]) return {l: torch.stack(acts) for l, acts in all_acts.items() if acts} pos_acts = collect_mlp_output(positive_prompts) neg_acts = collect_mlp_output(negative_prompts) control_vectors = {} for l in layers: if l in pos_acts and l in neg_acts: cv = pos_acts[l].mean(dim=0) - neg_acts[l].mean(dim=0) control_vectors[l] = cv return control_vectors def compute_activation_weighted_cv( self, positive_prompts: List[str], negative_prompts: List[str], layer_idx: Optional[int] = None, seed_response: str = "", use_chat_template: bool = True, ) -> Dict[int, Dict[int, float]]: """Compute per-neuron control contributions weighted by actual activations. Instead of projecting a residual-level CV onto W_down columns (which loses information about which neurons actually fired), this directly captures the intermediate MLP activations (post gate*up, pre down_proj) and computes per-neuron behavioral differences. neuron_contribution[i] = mean(act_pos[i]) - mean(act_neg[i]) Returns: Dict[layer_idx, Dict[neuron_idx, activation_difference]] """ layers = [layer_idx] if layer_idx is not None else list(range(len(self._layers_ref))) def collect_intermediate(prompts): all_acts = {l: [] for l in layers} for prompt in prompts: if use_chat_template: formatted = self._format_prompt(prompt, seed_response) else: formatted = prompt + seed_response input_ids = self.tokenizer(formatted, return_tensors="pt").input_ids.to(self.device) captured = {} hooks = [] for l in layers: def make_hook(layer_idx): def hook_fn(module, input, output): # down_proj input = gate * up (the intermediate activation) # input to down_proj is a tuple, first element is the tensor inp = input[0] if isinstance(input, tuple) else input captured[layer_idx] = inp[0, -1].detach().clone() return hook_fn h = self._layers_ref[l].mlp.down_proj.register_forward_hook(make_hook(l)) hooks.append(h) try: with torch.no_grad(): self.model(input_ids) finally: for h in hooks: h.remove() for l in layers: if l in captured: all_acts[l].append(captured[l]) return {l: torch.stack(acts) for l, acts in all_acts.items() if acts} pos_acts = collect_intermediate(positive_prompts) neg_acts = collect_intermediate(negative_prompts) per_layer_neurons = {} for l in layers: if l in pos_acts and l in neg_acts: diff = pos_acts[l].mean(dim=0) - neg_acts[l].mean(dim=0) # [d_mlp] result = {} for i in range(diff.shape[0]): v = diff[i].item() if abs(v) > 1e-8: result[i] = v per_layer_neurons[l] = dict(sorted(result.items(), key=lambda x: abs(x[1]), reverse=True)) return per_layer_neurons def decompose_cv_to_neurons( self, control_vector: torch.Tensor, layer_idx: int, ) -> Dict[int, float]: """Decompose a control vector into per-neuron contributions. Projects the control vector onto each neuron's output column in W_down. CV contribution of neuron i = dot(CV, W_down[:, i]) / ||W_down[:, i]|| Args: control_vector: [d_model] control vector at this layer layer_idx: Which layer's MLP to decompose against Returns: Dict[neuron_idx, projection_weight] sorted by |weight| """ W_down = self._layers_ref[layer_idx].mlp.down_proj.weight # [d_model, d_mlp] # Each column of W_down is a neuron's output direction # Project CV onto each column cv = control_vector.float() W = W_down.float() # projections[i] = dot(cv, W[:, i]) = how much neuron i contributes to CV direction projections = torch.matmul(cv, W) # [d_mlp] # Normalize by column norms for interpretability col_norms = torch.norm(W, dim=0) # [d_mlp] normalized = projections / (col_norms + 1e-8) result = {} for i in range(projections.shape[0]): if abs(normalized[i].item()) > 1e-6: result[i] = normalized[i].item() return dict(sorted(result.items(), key=lambda x: abs(x[1]), reverse=True)) def compare_circuit_to_cv( self, circuit: Circuit, control_vectors: Dict[int, torch.Tensor], top_k: int = 50, verbose: bool = True, ) -> Dict[str, float]: """Compare CNA neuron circuit to CAA control vector decomposition. For each layer, compute how much of the control vector's variance is explained by the circuit neurons. Args: circuit: Neuron circuit from discover_circuit control_vectors: From compute_control_vector top_k: Number of top CV neurons to compare verbose: Print comparison Returns: Dict with overlap metrics """ circuit_by_layer = circuit.unique_neurons() total_overlap = 0 total_cv_neurons = 0 total_variance_explained = 0.0 n_layers = 0 for layer_idx, cv in control_vectors.items(): cv_decomp = self.decompose_cv_to_neurons(cv, layer_idx) top_cv = list(cv_decomp.keys())[:top_k] circuit_neurons = circuit_by_layer.get(layer_idx, set()) overlap = len(set(top_cv) & circuit_neurons) total_overlap += overlap total_cv_neurons += min(top_k, len(cv_decomp)) # Variance explained: sum of squared projections for circuit neurons all_proj_sq = sum(v ** 2 for v in cv_decomp.values()) circuit_proj_sq = sum(cv_decomp.get(n, 0) ** 2 for n in circuit_neurons) var_expl = circuit_proj_sq / (all_proj_sq + 1e-8) if all_proj_sq > 1e-8 else 0 if verbose and circuit_neurons: print(f" L{layer_idx:2d}: {len(circuit_neurons)} circuit neurons, " f"{overlap}/{min(top_k, len(cv_decomp))} overlap with top-{top_k} CV, " f"variance_explained={var_expl:.4f}") total_variance_explained += var_expl n_layers += 1 # Rank correlation: do circuit attribution ranks match CV decomposition ranks? # Flatten both to neuron lists circuit_ranked = [(n.layer, n.neuron, abs(a)) for n, a in circuit.top(200)] cv_ranked = [] for l, cv in control_vectors.items(): decomp = self.decompose_cv_to_neurons(cv, l) for neuron, weight in list(decomp.items())[:top_k]: cv_ranked.append((l, neuron, abs(weight))) cv_ranked.sort(key=lambda x: x[2], reverse=True) # Compute overlap at top-50 circuit_set = {(l, n) for l, n, _ in circuit_ranked[:50]} cv_set = {(l, n) for l, n, _ in cv_ranked[:50]} top50_overlap = len(circuit_set & cv_set) metrics = { "total_overlap": total_overlap, "total_cv_neurons_checked": total_cv_neurons, "mean_variance_explained": total_variance_explained / max(n_layers, 1), "top50_overlap": top50_overlap, } if verbose: print(f"\n Total overlap: {total_overlap}/{total_cv_neurons}") print(f" Mean variance explained: {metrics['mean_variance_explained']:.4f}") print(f" Top-50 neuron overlap (circuit vs CV): {top50_overlap}/50") return metrics def steer_and_generate( self, prompt: str, circuit: Circuit, multiplier: float = 0.0, max_new_tokens: int = 50, all_positions: bool = True, use_chat_template: bool = True, ) -> str: """Generate text with neuron steering applied. multiplier=0.0 → ablate circuit neurons (suppress behavior) multiplier=1.0 → no change (baseline) multiplier=2.0 → amplify circuit neurons (enhance behavior) """ if use_chat_template: formatted = self._format_prompt(prompt) else: formatted = prompt input_ids = self.tokenizer(formatted, return_tensors="pt").input_ids.to(self.device) with steer_neurons(self.model, circuit.neurons, multiplier, all_positions): with torch.no_grad(): outputs = self.model.generate( input_ids, max_new_tokens=max_new_tokens, do_sample=False, pad_token_id=self.tokenizer.pad_token_id, ) return self.tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True) def generate(self, prompt: str, max_new_tokens: int = 50, use_chat_template: bool = True) -> str: """Normal generation without steering.""" if use_chat_template: formatted = self._format_prompt(prompt) else: formatted = prompt input_ids = self.tokenizer(formatted, return_tensors="pt").input_ids.to(self.device) with torch.no_grad(): outputs = self.model.generate( input_ids, max_new_tokens=max_new_tokens, do_sample=False, pad_token_id=self.tokenizer.pad_token_id, ) return self.tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True) def next_token_probs( self, prompt: str, tokens: List[str], circuit: Optional[Circuit] = None, multiplier: float = 1.0, seed_response: str = "", use_chat_template: bool = True, ) -> Dict[str, float]: """Get next-token probabilities for specific tokens. Useful for measuring steering effects on specific outputs. """ if use_chat_template: formatted = self._format_prompt(prompt, seed_response) else: formatted = prompt + seed_response input_ids = self.tokenizer(formatted, return_tensors="pt").input_ids.to(self.device) ctx = steer_neurons(self.model, circuit.neurons, multiplier) if circuit else nullcontext() with ctx: with torch.no_grad(): outputs = self.model(input_ids) logits = outputs.logits[0, -1] probs = F.softmax(logits, dim=-1) result = {} for token in tokens: tid = self.tokenizer.encode(token, add_special_tokens=False)[-1] result[token] = probs[tid].item() return result def compute_mean_activations( self, prompts: Optional[List[str]] = None, seed_response: str = "", use_chat_template: bool = True, ) -> Dict[int, torch.Tensor]: """Compute mean MLP neuron activations across prompts. Args: prompts: List of prompts to compute mean from. If None, uses 20 diverse prompts (less accurate but works as fallback). seed_response: Seed response to append (for chat template). use_chat_template: Whether to use chat template formatting. Returns: Dict mapping layer_idx -> mean activation tensor (intermediate_size,) """ if prompts is None: prompts = [ "The capital of France is", "Once upon a time there was a", "The best programming language is", "In the year 2024, the world", "How do I bake a cake?", "What is photosynthesis?", "The CEO of Apple is", "My favorite color is", "The largest ocean on Earth is", "The speed of light is approximately", "In machine learning, a neural network", "The president of the United States", "Water freezes at a temperature of", "The meaning of life is", "To solve this math problem,", "The Great Wall of China was", "An electron has a charge of", "The chemical formula for water is", "Yesterday I went to the", "The key to the cabinets", ] use_chat_template = False # raw prompts, no template # Accumulate activations across all prompts and ALL positions mean_acts: Dict[int, torch.Tensor] = {} total_tokens = 0 for prompt in prompts: if use_chat_template: formatted = self._format_prompt(prompt, seed_response) else: formatted = prompt + seed_response input_ids = self.tokenizer(formatted, return_tensors="pt").input_ids.to(self.device) seq_len = input_ids.shape[1] layer_acts = {} hooks = [] for i, layer in enumerate(self._layers_ref): def make_hook(layer_idx): def hook_fn(module, args): # Capture ALL positions: args[0] shape (1, seq_len, intermediate_size) # Sum over seq_len dimension for running average layer_acts[layer_idx] = args[0][0].detach().sum(dim=0) # (intermediate_size,) return hook_fn h = layer.mlp.down_proj.register_forward_pre_hook(make_hook(i)) hooks.append(h) try: with torch.no_grad(): self.model(input_ids) finally: for h in hooks: h.remove() for layer_idx, act_sum in layer_acts.items(): if layer_idx not in mean_acts: mean_acts[layer_idx] = act_sum.clone() else: mean_acts[layer_idx] += act_sum total_tokens += seq_len for layer_idx in mean_acts: mean_acts[layer_idx] /= total_tokens return mean_acts def top_predictions( self, prompt: str, k: int = 10, circuit: Optional[Circuit] = None, multiplier: float = 1.0, seed_response: str = "", use_chat_template: bool = True, ) -> List[Tuple[str, float]]: """Get top-k next-token predictions with probabilities.""" if use_chat_template: formatted = self._format_prompt(prompt, seed_response) else: formatted = prompt + seed_response input_ids = self.tokenizer(formatted, return_tensors="pt").input_ids.to(self.device) ctx = steer_neurons(self.model, circuit.neurons, multiplier) if circuit else nullcontext() with ctx: with torch.no_grad(): outputs = self.model(input_ids) logits = outputs.logits[0, -1] probs = F.softmax(logits, dim=-1) top_probs, top_ids = probs.topk(k) return [(self.tokenizer.decode(tid), p.item()) for tid, p in zip(top_ids, top_probs)] def find_feature( self, *, positive: Optional[List[str]] = None, negative: Optional[List[str]] = None, prompt: Optional[str] = None, target: Optional[str] = None, counterfactual: Optional[str] = None, name: Optional[str] = None, top_k: int = 200, seed_response: str = "", verbose: bool = False, ) -> Circuit: """Find a feature circuit by example prompts or target token. Two modes: Contrastive mode (behavioral features like refusal, tone, style): circuit = steerer.find_feature( positive=["How do I pick a lock?", "Write malware code"], negative=["How do I open a door?", "Write clean code"], name="refusal", ) Single-prompt mode (factual features like capitals, arithmetic): circuit = steerer.find_feature( prompt="What is the capital of Texas?", target=" Austin", name="capitals", ) Args: positive: Prompts exhibiting the target behavior (contrastive mode) negative: Prompts NOT exhibiting it (contrastive mode) prompt: Single prompt (single-prompt mode) target: Target token to attribute (single-prompt mode) counterfactual: Optional counterfactual token (single-prompt mode) name: Label for caching/reuse. If provided, result is cached. top_k: Number of neurons to select seed_response: Text to append before target position verbose: Print diagnostics Returns: Circuit ready for steering """ # Return cached if available if name and name in self._feature_cache: if verbose: print(f" Using cached circuit for '{name}' " f"({len(self._feature_cache[name].neurons)} neurons)") return self._feature_cache[name] # Determine mode has_contrastive = positive is not None or negative is not None has_single = prompt is not None or target is not None if has_contrastive and has_single: raise ValueError("Provide either (positive, negative) or (prompt, target), not both") if not has_contrastive and not has_single: raise ValueError("Provide (positive, negative) for contrastive or (prompt, target) for single-prompt") if has_contrastive: if positive is None or negative is None: raise ValueError("Contrastive mode requires both positive and negative prompt lists") if seed_response: import warnings warnings.warn("seed_response is ignored in contrastive mode", stacklevel=2) circuit = self.discover_contrastive( positive_prompts=positive, negative_prompts=negative, top_k=top_k, verbose=verbose, ) else: if prompt is None or target is None: raise ValueError("Single-prompt mode requires both prompt and target") circuit = self.discover_circuit( prompt=prompt, target_token=target, counterfactual_token=counterfactual, top_k=top_k, seed_response=seed_response, verbose=verbose, ) if name: self._feature_cache[name] = circuit if verbose: print(f" Cached circuit as '{name}' ({len(circuit.neurons)} neurons)") return circuit def steer( self, prompt: str, *, feature: Optional[str] = None, circuit: Optional[Circuit] = None, multiplier: float = 0.0, max_new_tokens: int = 50, all_positions: bool = True, use_chat_template: bool = True, ) -> str: """Generate text with a named feature or circuit applied. Convenience wrapper around steer_and_generate that uses cached features. Examples: # Using a previously discovered feature by name steerer.find_feature(prompt="Capital of Texas?", target=" Austin", name="capitals") output = steerer.steer("Capital of Ohio?", feature="capitals", multiplier=0.0) # Using a circuit directly output = steerer.steer("Capital of Ohio?", circuit=my_circuit, multiplier=2.0) Args: prompt: The prompt to generate from feature: Name of a cached feature (from find_feature with name=) circuit: Circuit object to use directly (alternative to feature name) multiplier: 0.0=ablate, 1.0=baseline, 2.0=amplify max_new_tokens: Max tokens to generate all_positions: Apply steering at all positions (not just circuit positions) use_chat_template: Format prompt for instruct models Returns: Generated text with steering applied """ if feature is not None and circuit is not None: raise ValueError("Provide either feature name or circuit, not both") if feature is None and circuit is None: raise ValueError("Provide either feature (name string) or circuit (Circuit object)") if feature is not None: if feature not in self._feature_cache: available = list(self._feature_cache.keys()) raise KeyError( f"Feature '{feature}' not found. " f"Available: {available}. Use find_feature() first." ) circuit = self._feature_cache[feature] return self.steer_and_generate( prompt=prompt, circuit=circuit, multiplier=multiplier, max_new_tokens=max_new_tokens, all_positions=all_positions, use_chat_template=use_chat_template, ) # ============================================================ # Interactive REPL # ============================================================ def interactive(self): """Launch interactive REPL for live neuron circuit exploration. Commands: prompt — Run a prompt, show output discover [target] — Find circuit (auto-detects target if omitted) ablate [spec] — Ablate neurons (L23/N8079, top10, all) amplify [spec] [mult] — Amplify neurons (default 2.0x) sweep [m1 m2 ...] — Multiplier sweep top [k] — Top-k next-token predictions save — Save circuit to file load [name] — Load circuit (no arg = list available) multiplier [value] — Get/set steering multiplier for 'top' info — Show current state quit / exit — Exit REPL """ import cmd import os import re steerer = self class NeuronREPL(cmd.Cmd): intro = ( "\n" "===== Neuron Steering REPL =====\n" f"Model: {steerer.model_name}\n" f"Blacklist: {len(steerer.blacklist)} universal neurons\n" "Type 'help' for commands, 'quit' to exit.\n" ) prompt = "neuron> " def __init__(self): super().__init__() self._prompt = None self._prompt_is_formatted = False # True if prompt already has chat template self._circuit = None self._graph = None self._multiplier = 1.0 self._saved = {} self._last_output = None # ---- prompt ---- def do_prompt(self, arg): """prompt — Run a prompt through the model and show output.""" if not arg.strip(): if self._prompt: print(f"Current prompt: {self._prompt}") else: print("Usage: prompt ") return self._prompt = arg.strip() self._prompt_is_formatted = False self._circuit = None self._graph = None try: uct = not self._prompt_is_formatted self._last_output = steerer.generate( self._prompt, max_new_tokens=100, use_chat_template=uct) print(f"\nOutput: {self._last_output}") except Exception as e: print(f"Error: {e}") # ---- discover ---- def do_discover(self, arg): """discover [target_token] — Discover circuit for current prompt.""" if not self._prompt: print("Set a prompt first: prompt ") return target = arg.strip() if arg.strip() else None uct = not self._prompt_is_formatted try: if target is None: preds = steerer.top_predictions(self._prompt, k=1, use_chat_template=uct) if preds: target = preds[0][0] print(f"Auto-target: '{target}' (p={preds[0][1]:.4f})") else: print("Could not auto-detect target. Provide one: discover ") return self._circuit = steerer.discover_circuit( self._prompt, target, top_k=200, filter_bos=True, verbose=False, use_chat_template=uct, ) self._graph = None print(f"\n{self._circuit.summary()}") print(f"\nTop 10 neurons:") for nidx, attr in self._circuit.top(10): print(f" L{nidx.layer:2d}/N{nidx.neuron:5d} (pos {nidx.position:2d}) attr={attr:+.6f}") except Exception as e: print(f"Error: {e}") # ---- ablate ---- def do_ablate(self, arg): """ablate [L/N | top | all] — Ablate neurons and regenerate.""" if not self._prompt: print("Set a prompt first.") return if not self._circuit: print("Discover a circuit first.") return try: circuit = self._select_neurons(arg.strip(), "ablate") if circuit is None: return uct = not self._prompt_is_formatted output = steerer.steer_and_generate( self._prompt, circuit, multiplier=0.0, max_new_tokens=100, use_chat_template=uct, ) print(f"\nAblated output (x0.0): {output}") except Exception as e: print(f"Error: {e}") # ---- amplify ---- def do_amplify(self, arg): """amplify [L/N | top | all] [multiplier] — Amplify neurons.""" if not self._prompt: print("Set a prompt first.") return if not self._circuit: print("Discover a circuit first.") return try: parts = arg.strip().split() multiplier = 2.0 neuron_spec = "" # First non-float arg is neuron spec, last float is multiplier non_floats = [] for p in parts: try: multiplier = float(p) except ValueError: non_floats.append(p) if len(non_floats) > 1: print(f"Warning: using last spec '{non_floats[-1]}', ignoring {non_floats[:-1]}") neuron_spec = non_floats[-1] if non_floats else "" circuit = self._select_neurons(neuron_spec, "amplify") if circuit is None: return uct = not self._prompt_is_formatted output = steerer.steer_and_generate( self._prompt, circuit, multiplier=multiplier, max_new_tokens=100, use_chat_template=uct, ) print(f"\nAmplified output (x{multiplier}): {output}") except Exception as e: print(f"Error: {e}") # ---- sweep ---- def do_sweep(self, arg): """sweep [m1 m2 ...] — Multiplier sweep over current circuit.""" if not self._prompt: print("Set a prompt first.") return if not self._circuit: print("Discover a circuit first.") return parts = arg.strip().split() if not parts: parts = ["0.0", "0.5", "1.0", "1.5", "2.0"] try: multipliers = [float(m) for m in parts] except ValueError: print("Usage: sweep ... (e.g., sweep 0.0 0.5 1.0 2.0)") return try: uct = not self._prompt_is_formatted for m in multipliers: output = steerer.steer_and_generate( self._prompt, self._circuit, multiplier=m, max_new_tokens=100, use_chat_template=uct, ) print(f" x{m}: {output}") except Exception as e: print(f"Error: {e}") # ---- top ---- def do_top(self, arg): """top [k] — Show top-k next-token predictions.""" k = 10 if arg.strip(): try: k = int(arg.strip()) except ValueError: print("Usage: top [k]") return if not self._prompt: print("Set a prompt first.") return uct = not self._prompt_is_formatted try: preds = steerer.top_predictions( self._prompt, k=k, circuit=self._circuit, multiplier=self._multiplier, use_chat_template=uct, ) print(f"\nTop-{k} predictions (multiplier={self._multiplier}):") for tok, prob in preds: bar = "#" * int(prob * 50) print(f" {prob:.4f} {bar} '{tok}'") except Exception as e: print(f"Error: {e}") # ---- save / load ---- def do_save(self, arg): """save — Save current circuit to file.""" name = arg.strip() if not name: print("Usage: save ") return if not self._circuit: print("No circuit to save. Run discover first.") return try: circuits_dir = os.path.join( os.path.dirname(os.path.abspath(__file__)), "circuits" ) os.makedirs(circuits_dir, exist_ok=True) path = os.path.join(circuits_dir, f"{name}.json") self._circuit.save(path) self._saved[name] = self._circuit print(f"Saved to {path}") except Exception as e: print(f"Error: {e}") def do_load(self, arg): """load [name] — Load a saved circuit (no arg lists available).""" name = arg.strip() circuits_dir = os.path.join( os.path.dirname(os.path.abspath(__file__)), "circuits" ) if not name: if os.path.isdir(circuits_dir): files = [f[:-5] for f in os.listdir(circuits_dir) if f.endswith(".json")] if files: print(f"Available: {', '.join(sorted(files))}") else: print("No saved circuits.") else: print("No saved circuits.") return try: path = os.path.join(circuits_dir, f"{name}.json") self._circuit = Circuit.load(path) self._graph = None self._prompt = self._circuit.prompt self._prompt_is_formatted = True # already has chat template print(f"Loaded from {path}") print(f"\n{self._circuit.summary()}") except FileNotFoundError: print(f"'{name}' not found. Use 'load' to list available.") except Exception as e: print(f"Error: {e}") # ---- multiplier ---- def do_multiplier(self, arg): """multiplier [value] — Get/set the steering multiplier for 'top' command.""" if not arg.strip(): print(f"Current multiplier: {self._multiplier}") return try: self._multiplier = float(arg.strip()) print(f"Multiplier set to {self._multiplier}") except ValueError: print("Usage: multiplier ") # ---- info ---- def do_info(self, arg): """info — Show current REPL state.""" print(f"\nPrompt: {self._prompt or '(none)'}") n = len(self._circuit.neurons) if self._circuit else 0 print(f"Circuit: {n} neurons") if self._circuit: print(f" Target: {self._circuit.target_token}") print(f" LogitD: {self._circuit.total_logit_diff:.4f}") print(f"Multiplier: {self._multiplier}") saved = ', '.join(self._saved.keys()) if self._saved else '(none)' print(f"Saved: {saved}") # ---- quit / exit ---- def do_quit(self, arg): """quit — Exit the REPL.""" print("Bye!") return True def do_exit(self, arg): """exit — Exit the REPL.""" return self.do_quit(arg) do_EOF = do_quit # ---- helpers ---- def _select_neurons(self, spec, action): """Parse neuron spec: 'L23/N8079', 'top10', 'all', or '' (= all).""" if not spec or spec == "all": return self._circuit if spec.startswith("top"): try: k = int(spec[3:]) except ValueError: print(f"Usage: {action} top") return None top_neurons = self._circuit.top(k) return Circuit( neurons=dict(top_neurons), prompt=self._circuit.prompt, target_token=self._circuit.target_token, total_logit_diff=self._circuit.total_logit_diff, ) m = re.match(r"L(\d+)/N(\d+)", spec) if m: layer, neuron = int(m.group(1)), int(m.group(2)) matched = {n: a for n, a in self._circuit.neurons.items() if n.layer == layer and n.neuron == neuron} if not matched: print(f"Neuron L{layer}/N{neuron} not in current circuit.") return None return Circuit( neurons=matched, prompt=self._circuit.prompt, target_token=self._circuit.target_token, total_logit_diff=self._circuit.total_logit_diff, ) print(f"Unknown spec '{spec}'. Use: L/N, top, or all") return None def emptyline(self): pass def default(self, line): print(f"Unknown command: {line.split()[0]}. Type 'help' for commands.") repl = NeuronREPL() while True: try: repl.cmdloop() break # normal exit via quit/exit/EOF except KeyboardInterrupt: print("\n(Ctrl+C — command cancelled. Type 'quit' to exit)") repl.intro = "" continue