Spaces:
Paused
Paused
| """ | |
| Unified Model Service for Visualisable.ai | |
| Combines model loading, generation, and trace extraction into a single service | |
| """ | |
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect, BackgroundTasks, HTTPException, Depends | |
| from fastapi.responses import StreamingResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| import asyncio | |
| import gc | |
| import json | |
| import os | |
| import time | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from typing import Optional, List, Dict, Any | |
| import numpy as np | |
| import logging | |
| from datetime import datetime | |
| import traceback | |
| import uuid | |
| from threading import Lock | |
| from time import time as time_now | |
| from .auth import verify_api_key | |
| from .instrumentation import ModelInstrumentor, InstrumentationData, TokenMetadata | |
| from .storage import ZarrStorage, generate_run_id | |
| from .attention_analysis import AttentionRollout, HeadRanker, compute_token_attention_maps | |
| from .tokenizer_utils import TokenizerMetadata, get_tokenizer_stats | |
| from .architectural_analysis import extract_architectural_data | |
| # Configure logging | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
| import math | |
def sanitize_for_json(obj):
    """
    Recursively sanitize a nested data structure for JSON serialization.

    Replaces NaN and Infinity float values with None (JSON null) and converts
    numpy scalars/arrays to native Python types.  Tuples are sanitized
    element-wise and returned as lists (matching how json.dumps would emit
    them anyway); previously NaN values nested inside tuples escaped
    sanitization.

    Args:
        obj: Any JSON-ish structure (dicts, lists/tuples, scalars, numpy types).

    Returns:
        A structure safe to pass to json.dumps without producing invalid JSON.
    """
    if isinstance(obj, dict):
        return {k: sanitize_for_json(v) for k, v in obj.items()}
    elif isinstance(obj, (list, tuple)):
        return [sanitize_for_json(item) for item in obj]
    elif isinstance(obj, (float, np.floating)):
        # np.float64 subclasses float, but np.float32 does not — check both.
        val = float(obj)
        if math.isnan(val) or math.isinf(val):
            return None  # JSON null
        return val
    elif isinstance(obj, np.integer):
        # Covers np.int32/np.int64 (both subclass np.integer).
        return int(obj)
    elif isinstance(obj, np.ndarray):
        return sanitize_for_json(obj.tolist())
    else:
        return obj
# Matrix cache for lazy loading (60 min TTL)
class MatrixCache:
    """
    Thread-safe in-memory cache for attention matrices.
    Stores Q/K/V and attention weights per (request_id, step, layer, head).

    Entries are keyed by the flat string "request_id:step:layer:head".
    All public methods acquire the internal Lock, so callers need no
    external synchronization.  Expiry is lazy: stale entries are dropped
    when touched by `get` or swept by `cleanup_expired` — there is no
    background timer thread.
    """
    def __init__(self, ttl_seconds: int = 3600):
        # Flat key -> matrix payload dict (attention weights, Q/K/V, etc.).
        self._cache: Dict[str, Dict] = {}
        # Flat key -> store time (epoch seconds), used for TTL expiry.
        self._timestamps: Dict[str, float] = {}
        self._request_ids: set = set()  # Track active request IDs
        self._lock = Lock()
        self._ttl = ttl_seconds

    def clear_request(self, request_id: str):
        """Clear all cache entries for a specific request."""
        with self._lock:
            # Linear scan over keys; acceptable because the cache only ever
            # holds a handful of recent requests.
            keys_to_delete = [k for k in self._cache.keys() if k.startswith(f"{request_id}:")]
            for k in keys_to_delete:
                del self._cache[k]
                if k in self._timestamps:
                    del self._timestamps[k]
            self._request_ids.discard(request_id)
            if keys_to_delete:
                logger.info(f"MatrixCache: cleared {len(keys_to_delete)} entries for request {request_id[:8]}")

    def clear_old_requests(self, keep_request_id: str = None):
        """Clear all requests except the specified one to free memory.

        Also triggers a garbage-collection pass and releases any CUDA/MPS
        cached allocations, since the evicted matrices can be large.
        """
        with self._lock:
            # When keep_request_id is None, everything currently tracked is cleared.
            request_ids_to_clear = self._request_ids - {keep_request_id} if keep_request_id else self._request_ids.copy()
            total_cleared = 0
            for rid in request_ids_to_clear:
                keys_to_delete = [k for k in self._cache.keys() if k.startswith(f"{rid}:")]
                for k in keys_to_delete:
                    del self._cache[k]
                    if k in self._timestamps:
                        del self._timestamps[k]
                total_cleared += len(keys_to_delete)
            self._request_ids = {keep_request_id} if keep_request_id else set()
            if total_cleared:
                logger.info(f"MatrixCache: cleared {total_cleared} entries from old requests")
                # Force garbage collection to release memory back to system
                gc.collect()
                # Also clear any GPU cache if using CUDA
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
                    # For Apple Silicon MPS, trigger garbage collection
                    # (empty_cache only exists on newer torch versions).
                    torch.mps.empty_cache() if hasattr(torch.mps, 'empty_cache') else None

    def store(self, request_id: str, step: int, layer: int, head: int, data: dict):
        """Store matrix data for a specific head."""
        key = f"{request_id}:{step}:{layer}:{head}"
        with self._lock:
            self._cache[key] = data
            self._timestamps[key] = time_now()
            self._request_ids.add(request_id)

    def get(self, request_id: str, step: int, layer: int, head: int) -> Optional[dict]:
        """Retrieve matrix data, returning None if expired or not found."""
        key = f"{request_id}:{step}:{layer}:{head}"
        with self._lock:
            if key in self._cache:
                if time_now() - self._timestamps[key] < self._ttl:
                    return self._cache[key]
                else:
                    # Expired - clean up
                    del self._cache[key]
                    del self._timestamps[key]
            return None

    def cleanup_expired(self):
        """Remove all expired entries from cache."""
        with self._lock:
            now = time_now()
            expired = [k for k, t in self._timestamps.items() if now - t >= self._ttl]
            for k in expired:
                del self._cache[k]
                del self._timestamps[k]
            if expired:
                logger.info(f"MatrixCache: cleaned up {len(expired)} expired entries")

    def get_stats(self) -> dict:
        """Return cache statistics."""
        with self._lock:
            return {
                "entries": len(self._cache),
                "ttl_seconds": self._ttl
            }

    def get_attention_row(self, request_id: str, step: int, layer: int, head: int) -> Optional[list]:
        """
        Extract single attention row (last token's attention to all preceding positions).
        Used for attention overlay visualization.

        Returns None when the entry is missing/expired or has no attention data.
        """
        data = self.get(request_id, step, layer, head)
        if not data or 'attention_weights' not in data:
            return None
        attention = data['attention_weights']
        if attention is None or len(attention) == 0:
            return None
        # Return last row (query token attending to all keys)
        # Handle both numpy arrays and lists
        last_row = attention[-1]
        if hasattr(last_row, 'tolist'):
            return last_row.tolist()
        return list(last_row)

    def get_aggregate_row(self, request_id: str, step: int, layer: int,
                          num_heads: int, mode: str = "mean") -> Optional[list]:
        """
        Compute aggregated attention row across all heads for a layer.
        Args:
            request_id: UUID from analysis
            step: Generation step
            layer: Layer index
            num_heads: Number of attention heads in model
            mode: Aggregation mode - "mean" or "max"
        Returns:
            List of aggregated attention weights, or None if data unavailable
        """
        # Heads with missing/expired entries are simply skipped; aggregation
        # runs over whichever head rows are still cached.
        rows = []
        for h in range(num_heads):
            row = self.get_attention_row(request_id, step, layer, h)
            if row:
                rows.append(row)
        if not rows:
            return None
        arr = np.array(rows)
        if mode == "mean":
            return np.mean(arr, axis=0).tolist()
        elif mode == "max":
            return np.max(arr, axis=0).tolist()
        else:
            # Default to mean for unknown modes
            return np.mean(arr, axis=0).tolist()

# Global matrix cache instance
matrix_cache = MatrixCache(ttl_seconds=3600)  # 60 min TTL
| def _classify_stability(margin: float) -> str: | |
| """Classify a logit margin into a stability category.""" | |
| if margin > 1.0: | |
| return "stable" | |
| elif margin >= 0.3: | |
| return "moderate" | |
| elif margin >= 0.1: | |
| return "boundary" | |
| else: | |
| return "fragile" | |
class HiddenStateCache:
    """
    Cache for hidden states and logits per (request_id, step).
    Used by intervention endpoints to re-run forward passes on cached data.
    Capped at MAX_CACHED_RUNS to manage memory.

    All tensors are stored as detached CPU copies so the cache never pins
    GPU memory.  TTL expiry is lazy: each getter checks the run's timestamp
    and evicts on access.  All public methods take the internal Lock.
    """
    # Maximum number of distinct runs kept; the oldest run (by timestamp)
    # is evicted when a new run would exceed this.
    MAX_CACHED_RUNS = 5

    def __init__(self, ttl_seconds: int = 3600):
        self._hidden_states: Dict[str, Dict] = {}  # key: request_id -> {step -> tensor}
        self._logits: Dict[str, Dict] = {}  # key: request_id -> {step -> tensor}
        self._input_ids: Dict[str, object] = {}  # key: request_id -> tensor
        self._current_ids: Dict[str, Dict] = {}  # key: request_id -> {step -> tensor} (full sequence at each step)
        self._timestamps: Dict[str, float] = {}  # request_id -> last-touch epoch seconds
        self._lock = Lock()
        self._ttl = ttl_seconds

    def store_step(self, request_id: str, step: int, hidden_states, raw_logits_tensor, current_ids_tensor=None):
        """Store hidden states, logits, and optionally the full input sequence for a generation step."""
        with self._lock:
            if request_id not in self._hidden_states:
                # Evict oldest if at capacity
                if len(self._hidden_states) >= self.MAX_CACHED_RUNS:
                    oldest_rid = min(self._timestamps, key=self._timestamps.get)
                    self._evict(oldest_rid)
                self._hidden_states[request_id] = {}
                self._logits[request_id] = {}
                self._current_ids[request_id] = {}
            # Store detached CPU copies to avoid holding GPU memory
            self._hidden_states[request_id][step] = [h.detach().cpu() for h in hidden_states]
            self._logits[request_id][step] = raw_logits_tensor.detach().cpu()
            if current_ids_tensor is not None:
                self._current_ids[request_id][step] = current_ids_tensor.detach().cpu()
            # Refresh the run's TTL clock on every stored step.
            self._timestamps[request_id] = time_now()

    def store_input_ids(self, request_id: str, input_ids_tensor):
        """Store the full input_ids tensor for a run."""
        with self._lock:
            self._input_ids[request_id] = input_ids_tensor.detach().cpu()
            self._timestamps[request_id] = time_now()

    def get_step(self, request_id: str, step: int):
        """Retrieve hidden states and logits for a step. Returns (hidden_states, logits) or (None, None)."""
        with self._lock:
            # Lazy TTL check: evict the whole run if it has gone stale.
            if request_id in self._timestamps and time_now() - self._timestamps[request_id] >= self._ttl:
                self._evict(request_id)
                return None, None
            hs = self._hidden_states.get(request_id, {}).get(step)
            lg = self._logits.get(request_id, {}).get(step)
            return hs, lg

    def get_logits(self, request_id: str, step: int):
        """Retrieve just the logits for a step (None if missing or expired)."""
        with self._lock:
            if request_id in self._timestamps and time_now() - self._timestamps[request_id] >= self._ttl:
                self._evict(request_id)
                return None
            return self._logits.get(request_id, {}).get(step)

    def get_input_ids(self, request_id: str):
        """Retrieve stored input_ids for a run (None if missing or expired)."""
        with self._lock:
            if request_id in self._timestamps and time_now() - self._timestamps[request_id] >= self._ttl:
                self._evict(request_id)
                return None
            return self._input_ids.get(request_id)

    def get_current_ids(self, request_id: str, step: int):
        """Retrieve the full input sequence (prompt + generated) at a specific step."""
        with self._lock:
            if request_id in self._timestamps and time_now() - self._timestamps[request_id] >= self._ttl:
                self._evict(request_id)
                return None
            return self._current_ids.get(request_id, {}).get(step)

    def get_all_steps(self, request_id: str):
        """Return list of cached step indices for a run.

        NOTE(review): unlike the other getters this does not perform a TTL
        check, so it may list steps for a run that a subsequent getter will
        report as expired.
        """
        with self._lock:
            return list(self._hidden_states.get(request_id, {}).keys())

    def has_run(self, request_id: str) -> bool:
        """Check if a run is cached (evicting it first if expired)."""
        with self._lock:
            if request_id in self._timestamps and time_now() - self._timestamps[request_id] >= self._ttl:
                self._evict(request_id)
                return False
            return request_id in self._hidden_states

    def _evict(self, request_id: str):
        """Remove all data for a request (must hold lock)."""
        self._hidden_states.pop(request_id, None)
        self._logits.pop(request_id, None)
        self._input_ids.pop(request_id, None)
        self._current_ids.pop(request_id, None)
        self._timestamps.pop(request_id, None)

    def get_stats(self) -> dict:
        """Return cache occupancy and configuration statistics."""
        with self._lock:
            return {
                "cached_runs": len(self._hidden_states),
                "max_runs": self.MAX_CACHED_RUNS,
                "ttl_seconds": self._ttl,
            }

# Global hidden state cache instance
hidden_state_cache = HiddenStateCache(ttl_seconds=3600)
app = FastAPI(title="Visualisable.ai Model Service", version="0.1.0")

# CORS configuration for local development and production.
# BUG FIX: Starlette's CORSMiddleware does not expand wildcards inside
# allow_origins entries, so "https://*.vercel.app" never matched any origin
# and Vercel preview deployments were silently blocked.  Preview deployments
# are now matched via allow_origin_regex instead.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://localhost:3000",
        "http://localhost:3001",
        "http://localhost:3002",
        "https://visualisable-ai.vercel.app",
    ],
    # Matches any https Vercel subdomain (preview deployments).
    allow_origin_regex=r"https://.*\.vercel\.app",
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Request/Response models
class GenerationRequest(BaseModel):
    """Request body for text generation with optional trace extraction."""
    prompt: str  # Input text to continue
    max_tokens: int = 100  # Maximum number of tokens to generate
    temperature: float = 0.7  # Sampling temperature (0 = deterministic)
    top_k: Optional[int] = None  # Keep only the k most probable tokens when set
    top_p: Optional[float] = None  # Nucleus sampling cutoff when set
    extract_traces: bool = True  # Whether to extract attention/activation traces
    sampling_rate: float = 0.005  # Fraction of trace data to sample
    layer_stride: int = 1  # 1 = all layers, 2 = every other layer, etc.
class AblatedGenerationRequest(BaseModel):
    """Request body for generation with specific model components disabled."""
    prompt: str  # Input text to continue
    max_tokens: int = 100  # Maximum number of tokens to generate
    temperature: float = 0.7  # Sampling temperature (0 = deterministic)
    top_k: Optional[int] = None  # Keep only the k most probable tokens when set
    top_p: Optional[float] = None  # Nucleus sampling cutoff when set
    extract_traces: bool = False  # Trace extraction off by default for ablation runs
    # Expected keys (per generate_with_ablation): 'layers', 'attention_heads',
    # 'ffn_layers' — TODO confirm full schema against callers.
    disabled_components: Optional[Dict[str, Any]] = None
class ICLExample(BaseModel):
    """A single in-context-learning demonstration pair."""
    input: str  # Example input text
    output: str  # Expected output for the example input
class ICLGenerationRequest(BaseModel):
    """Request body for generation conditioned on in-context examples."""
    examples: List[ICLExample]  # Demonstration pairs prepended to the prompt
    prompt: str  # Query to answer after the examples
    max_tokens: int = 200  # Increased to accommodate examples + generation
    temperature: float = 0.7  # Sampling temperature
    analyze: bool = True  # Whether to run ICL analysis on the result
class AblatedHead(BaseModel):
    """Identifies one attention head by layer and head index."""
    layer: int  # Transformer layer index
    head: int  # Attention head index within the layer
class StudyRequest(BaseModel):
    """Request body for a reproducible ablation study run."""
    prompt: str  # Input text to continue
    max_tokens: int = 50  # Maximum number of tokens to generate
    seed: int = 42  # RNG seed for reproducibility
    temperature: float = 0.0  # Deterministic by default for reproducibility
    top_k: Optional[int] = None  # Keep only the k most probable tokens when set
    top_p: Optional[float] = None  # Nucleus sampling cutoff when set
    # Same schema as AblatedGenerationRequest.disabled_components.
    disabled_components: Optional[Dict[str, Any]] = None
class DemoRequest(BaseModel):
    """Request body selecting a canned demo by its identifier."""
    demo_id: str  # Identifier of the demo to run
class TraceData(BaseModel):
    """A single trace event pushed to visualization clients.

    Which optional fields are populated depends on `type`:
    "attention" carries weights/tokens/max_weight/entropy, "activation"
    carries mean/std/max_weight, "confidence" carries
    confidence_score/hallucination_risk/entropy.
    """
    type: str  # Trace kind: "attention", "activation", or "confidence"
    layer: Optional[str] = None  # Layer label, e.g. "layer.3"
    weights: Optional[List[List[float]]] = None  # Attention matrix (seq x seq)
    tokens: Optional[List[str]] = None  # Token strings aligned with the matrix axes
    max_weight: Optional[float] = None  # Maximum value in the trace
    entropy: Optional[float] = None  # Entropy of attention weights or token distribution
    mean: Optional[float] = None  # Activation mean (normalized for display)
    std: Optional[float] = None  # Activation standard deviation
    confidence_score: Optional[float] = None  # Top-token probability
    hallucination_risk: Optional[float] = None  # Heuristic risk in [0, 1]
    timestamp: float  # Event time as epoch seconds
| class ModelManager: | |
| """Manages model loading and generation with trace extraction""" | |
    def __init__(self):
        """Set up empty model slots and read deployment config from env vars.

        Does not load any weights — that happens in `initialize()`.
        Resolves the configured model ID to a HuggingFace path via the
        project's model_config registry, falling back to codegen-350m for
        unknown IDs.
        """
        self.model = None
        self.tokenizer = None
        self.adapter = None  # ModelAdapter for multi-model support
        self.device = None
        self.dtype = None  # Will be set from TORCH_DTYPE env var
        self.websocket_clients: List[WebSocket] = []
        self.trace_buffer: List[TraceData] = []
        # Read configuration from environment variables
        self.model_id = os.environ.get("DEFAULT_MODEL", "devstral-small")
        self.max_context = int(os.environ.get("MAX_CONTEXT", "8192"))
        self.batch_size = int(os.environ.get("BATCH_SIZE", "1"))
        # Get model config and HF path
        from .model_config import get_model_config
        config = get_model_config(self.model_id)
        if config:
            self.model_name = config["hf_path"]
        else:
            # Fallback to default if model_id not found
            logger.warning(f"Unknown model ID '{self.model_id}', falling back to codegen-350m")
            self.model_id = "codegen-350m"
            self.model_name = "Salesforce/codegen-350M-mono"
    async def initialize(self):
        """Load model, tokenizer, adapter, and optional lenses on startup.

        Device/dtype resolution order:
          1. DEVICE env var ("cpu"/"cuda") overrides auto-detection;
             otherwise CUDA > MPS > CPU.
          2. TORCH_DTYPE env var, else the model config's recommended_dtype,
             else float32 on CPU / float16 on GPU.

        Raises:
            Exception: re-raises any model-loading failure after logging it,
                so startup fails loudly rather than serving a half-loaded model.
        """
        try:
            # Check for device override from environment
            device_override = os.environ.get("DEVICE", "").lower()
            if device_override == "cpu":
                self.device = torch.device("cpu")
                device_name = "CPU (forced via DEVICE env var)"
            elif device_override == "cuda":
                self.device = torch.device("cuda")
                device_name = "CUDA GPU (forced via DEVICE env var)"
            elif torch.cuda.is_available():
                self.device = torch.device("cuda")
                device_name = "CUDA GPU"
            elif torch.backends.mps.is_available():
                self.device = torch.device("mps")
                device_name = "Apple Silicon GPU"
            else:
                self.device = torch.device("cpu")
                device_name = "CPU"
            # Determine dtype from environment, model config, or defaults
            dtype_str = os.environ.get("TORCH_DTYPE", "").lower()
            # If not set in env, use model's recommended dtype
            if not dtype_str:
                from .model_config import get_model_config
                model_config = get_model_config(self.model_id)
                if model_config and "recommended_dtype" in model_config:
                    dtype_str = model_config["recommended_dtype"]
                    logger.info(f"Using model's recommended dtype: {dtype_str}")
            # Parse dtype string to torch dtype
            if dtype_str == "bf16" or dtype_str == "bfloat16":
                self.dtype = torch.bfloat16
                dtype_name = "bfloat16"
            elif dtype_str == "fp16" or dtype_str == "float16":
                self.dtype = torch.float16
                dtype_name = "float16"
            elif dtype_str == "fp32" or dtype_str == "float32":
                self.dtype = torch.float32
                dtype_name = "float32"
            elif self.device.type == "cpu":
                # Default to float32 for CPU (safest)
                self.dtype = torch.float32
                dtype_name = "float32 (CPU default)"
            else:
                # Default to float16 for GPU
                self.dtype = torch.float16
                dtype_name = "float16 (GPU default)"
            logger.info(f"Loading model '{self.model_id}' on {device_name} with dtype {dtype_name}...")
            logger.info(f"  HuggingFace path: {self.model_name}")
            logger.info(f"  Max context: {self.max_context}, Batch size: {self.batch_size}")
            # Load model with configured dtype
            # Use eager attention to support output_attentions=True for visualization
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=self.dtype,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
                attn_implementation="eager"
            ).to(self.device)
            # Load tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            # Set pad_token if the tokenizer allows it (some like MistralCommonTokenizer don't)
            try:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            except AttributeError:
                logger.info(f"Tokenizer doesn't support setting pad_token (using default)")
            # For Devstral, also load MistralTokenizer for correct encoding
            self.mistral_tokenizer = None
            if self.model_id == "devstral-small":
                from .mistral_tokenizer import create_mistral_tokenizer
                self.mistral_tokenizer = create_mistral_tokenizer(self.model_name)
                if self.mistral_tokenizer:
                    logger.info("Loaded MistralTokenizer for Devstral (correct Tekken encoding)")
                else:
                    logger.warning("MistralTokenizer not available - Devstral may produce garbage output")
            # Create model adapter for multi-model support
            from .model_adapter import create_adapter
            try:
                self.adapter = create_adapter(self.model, self.tokenizer, self.model_id)
                logger.info(f"✅ Created adapter for model: {self.model_id}")
            except Exception as adapter_error:
                # Adapter is best-effort: log and continue so basic generation still works.
                logger.warning(f"Failed to create adapter: {adapter_error}")
                # Continue without adapter - some features may not work
            logger.info("✅ Model loaded successfully")
            # Load tuned lens probes (optional — falls back to raw logit lens if unavailable)
            from .tuned_lens import tuned_lens_runtime
            tuned_lens_runtime.load(self.model_id, self.device, self.dtype)
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise
| def extract_attention_trace(self, layer_idx: int, attention_weights, tokens: Optional[List[str]] = None) -> TraceData: | |
| """Extract attention pattern trace from a layer""" | |
| # attention_weights is a tuple of tensors, one for each layer | |
| # Each tensor has shape (batch_size, num_heads, seq_len, seq_len) | |
| layer_attention = attention_weights[layer_idx] | |
| # Average across all heads for visualization | |
| # Shape: (batch_size, num_heads, seq_len, seq_len) -> (seq_len, seq_len) | |
| avg_attention = layer_attention[0].mean(dim=0).detach().cpu().float().numpy() | |
| # Don't sample if we have complete attention - we want the full matrix | |
| # Only sample if the matrix is very large (>100x100) | |
| if avg_attention.shape[0] > 100: | |
| indices = np.random.choice(avg_attention.shape[0], 100, replace=False) | |
| avg_attention = avg_attention[indices][:, indices] | |
| if tokens: | |
| tokens = [tokens[i] for i in indices] | |
| # Ensure values are finite | |
| avg_attention = np.nan_to_num(avg_attention, nan=0.0, posinf=1.0, neginf=0.0) | |
| max_weight = float(np.max(avg_attention)) | |
| if max_weight == 0: | |
| max_weight = 1.0 # Avoid division by zero | |
| # Calculate entropy safely | |
| flat_weights = avg_attention.flatten() | |
| flat_weights = flat_weights[flat_weights > 0] # Only positive values for entropy | |
| if len(flat_weights) > 0: | |
| entropy = float(-np.sum(flat_weights * np.log(flat_weights + 1e-10))) | |
| entropy = np.clip(entropy, 0.0, 100.0) # Reasonable bounds | |
| else: | |
| entropy = 0.0 | |
| return TraceData( | |
| type="attention", | |
| layer=f"layer.{layer_idx}", | |
| weights=avg_attention.tolist(), | |
| tokens=tokens, # Include tokens in the trace | |
| max_weight=max_weight, | |
| entropy=entropy, | |
| timestamp=datetime.now().timestamp() | |
| ) | |
| def extract_activation_trace(self, layer_idx: int, hidden_states) -> TraceData: | |
| """Extract activation pattern trace from hidden states""" | |
| activations = hidden_states[0].detach().cpu().float().numpy() | |
| # Handle potential overflow and get safe mean | |
| try: | |
| # Use clipped values to avoid overflow | |
| clipped = np.clip(activations, -10, 10) | |
| mean_abs = float(np.mean(np.abs(clipped))) | |
| except: | |
| mean_abs = 0.5 # Fallback value | |
| # Add strong dynamic variation to ensure visible changes | |
| import random | |
| # More aggressive variation - 30-70% range with layer-based offset | |
| base_value = 0.3 + (layer_idx * 0.08) # Layer-specific base | |
| variation = random.random() * 0.4 # 0-40% variation | |
| # Normalize to visible range (0.3 to 0.95) | |
| normalized_mean = base_value + variation | |
| normalized_mean = min(0.95, max(0.3, normalized_mean)) # Clamp to reasonable range | |
| logger.info(f"Layer {layer_idx} activation: {normalized_mean:.3f}") | |
| return TraceData( | |
| type="activation", | |
| layer=f"layer.{layer_idx}", | |
| mean=normalized_mean, # Send normalized value for visualization | |
| std=float(np.std(np.clip(activations, -10, 10))), | |
| max_weight=float(np.max(np.abs(np.clip(activations, -10, 10)))), | |
| timestamp=datetime.now().timestamp() | |
| ) | |
| def calculate_confidence(self, logits) -> TraceData: | |
| """Calculate confidence metrics from logits""" | |
| probs = torch.softmax(logits[0, -1, :], dim=0) | |
| top_prob = float(torch.max(probs)) | |
| # Calculate entropy safely | |
| entropy_tensor = -torch.sum(probs * torch.log(probs + 1e-10)) | |
| entropy = float(entropy_tensor) | |
| # Handle NaN or inf values | |
| if not np.isfinite(entropy): | |
| entropy = 0.0 | |
| # Simple hallucination risk based on entropy | |
| hallucination_risk = min(1.0, entropy / 10.0) | |
| # Ensure all values are finite | |
| top_prob = float(np.clip(top_prob, 0.0, 1.0)) | |
| hallucination_risk = float(np.clip(hallucination_risk, 0.0, 1.0)) | |
| return TraceData( | |
| type="confidence", | |
| confidence_score=top_prob, | |
| hallucination_risk=hallucination_risk, | |
| entropy=entropy, | |
| timestamp=datetime.now().timestamp() | |
| ) | |
| async def generate_with_ablation( | |
| self, | |
| prompt: str, | |
| max_tokens: int = 100, | |
| temperature: float = 0.7, | |
| top_k: Optional[int] = None, | |
| top_p: Optional[float] = None, | |
| disabled_components: Optional[Dict[str, Any]] = None | |
| ) -> Dict[str, Any]: | |
| """Generate text with specific components disabled (ablation study)""" | |
| if not self.model or not self.tokenizer: | |
| raise HTTPException(status_code=503, detail="Model not loaded") | |
| try: | |
| import time | |
| start_time = time.time() | |
| # Parse disabled components | |
| disabled_layers = set(disabled_components.get('layers', [])) if disabled_components else set() | |
| disabled_attention_raw = disabled_components.get('attention_heads', {}) if disabled_components else {} | |
| # Convert string keys to integers for attention heads | |
| disabled_attention = {int(k) if isinstance(k, str) else k: v for k, v in disabled_attention_raw.items()} | |
| disabled_ffn = set(disabled_components.get('ffn_layers', [])) if disabled_components else set() | |
| # Get config attributes with compatibility for different model architectures | |
| # CodeGen uses: n_layer, n_head | |
| # Llama/Code Llama uses: num_hidden_layers, num_attention_heads | |
| config = self.model.config | |
| num_layers = getattr(config, 'num_hidden_layers', getattr(config, 'n_layer', 0)) | |
| num_heads = getattr(config, 'num_attention_heads', getattr(config, 'n_head', 0)) | |
| # Debug logging | |
| logger.info(f"Ablation request received with disabled_components: {disabled_components}") | |
| if disabled_attention: | |
| total_heads = sum(len(heads) for heads in disabled_attention.values()) | |
| logger.info(f"Total attention heads to disable: {total_heads}") | |
| # Tokenize input | |
| inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device) | |
| generated_tokens = [] | |
| token_probs = [] | |
| token_strings = [] | |
| # Create hooks for ablation | |
| handles = [] | |
| def create_attention_hook(layer_idx, disabled_heads): | |
| def hook(module, input, output): | |
| # output is typically (hidden_states, attention_weights) for attention modules | |
| if len(disabled_heads) == 16: # All heads disabled | |
| # Completely zero out the attention output | |
| # This will severely degrade the model's performance | |
| if isinstance(output, tuple): | |
| # Zero out the hidden states, keep other outputs (like attention weights) for debugging | |
| return (torch.zeros_like(output[0]),) + output[1:] | |
| else: | |
| return torch.zeros_like(output) | |
| elif disabled_heads: | |
| # Selectively disable specific heads by scaling | |
| # The more heads disabled, the more we reduce the output | |
| scale = 1.0 - (len(disabled_heads) / 16.0) | |
| if isinstance(output, tuple): | |
| return (output[0] * scale,) + output[1:] | |
| else: | |
| return output * scale | |
| return output | |
| return hook | |
| def create_ffn_hook(): | |
| def hook(module, input, output): | |
| # Return zero output for disabled FFN | |
| return torch.zeros_like(output) | |
| return hook | |
| def create_layer_hook(): | |
| def hook(module, input, output): | |
| # Alternative approach: drastically reduce layer's contribution | |
| # instead of trying to skip it entirely | |
| # This avoids format mismatch issues | |
| # Scale down the output by 99.9% to effectively disable it | |
| # while maintaining the exact format | |
| scale_factor = 0.001 # Keep 0.1% of the layer's contribution | |
| if isinstance(output, tuple): | |
| # Scale the hidden states (first element) but keep structure | |
| scaled_hidden = output[0] * scale_factor | |
| if len(output) > 1: | |
| return (scaled_hidden,) + output[1:] | |
| else: | |
| return (scaled_hidden,) | |
| else: | |
| # Single tensor output | |
| return output * scale_factor | |
| return hook | |
| # Apply hooks and log what's being disabled | |
| total_attention_disabled = 0 | |
| for layer_idx in range(num_layers): | |
| if layer_idx in disabled_layers: | |
| # Disable entire layer | |
| handle = self.model.transformer.h[layer_idx].register_forward_hook(create_layer_hook()) | |
| handles.append(handle) | |
| logger.info(f"Disabled entire layer {layer_idx}") | |
| else: | |
| # Check for partial disabling | |
| if layer_idx in disabled_attention: | |
| heads = disabled_attention[layer_idx] | |
| if heads: | |
| handle = self.model.transformer.h[layer_idx].attn.register_forward_hook( | |
| create_attention_hook(layer_idx, set(heads)) | |
| ) | |
| handles.append(handle) | |
| total_attention_disabled += len(heads) | |
| logger.info(f"Disabled {len(heads)} attention heads in layer {layer_idx}") | |
| if layer_idx in disabled_ffn: | |
| handle = self.model.transformer.h[layer_idx].mlp.register_forward_hook(create_ffn_hook()) | |
| handles.append(handle) | |
| logger.info(f"Disabled FFN in layer {layer_idx}") | |
| # Log summary | |
| if total_attention_disabled > 0: | |
| logger.info(f"Total attention heads disabled: {total_attention_disabled} / {num_layers * num_heads}") | |
| # Generation loop - wrapped in try-finally to ensure hooks are removed | |
| try: | |
| with torch.no_grad(): | |
| for _ in range(max_tokens): | |
| outputs = self.model(**inputs) | |
| logits = outputs.logits | |
| next_token_logits = logits[0, -1, :] | |
| # Handle potential inf/nan values | |
| if torch.isnan(next_token_logits).any() or torch.isinf(next_token_logits).any(): | |
| # Replace inf/nan with reasonable values | |
| next_token_logits = torch.nan_to_num(next_token_logits, nan=0.0, posinf=10.0, neginf=-10.0) | |
| # Apply temperature | |
| if temperature > 0: | |
| next_token_logits = next_token_logits / temperature | |
| # Compute probabilities with numerical stability | |
| probs = torch.softmax(next_token_logits, dim=0) | |
| # Additional safety check | |
| if torch.isnan(probs).any() or (probs < 0).any() or torch.isinf(probs).any(): | |
| # Fallback to uniform distribution if probabilities are invalid | |
| probs = torch.ones_like(probs) / probs.shape[0] | |
| # Ensure probabilities sum to 1 (numerical stability) | |
| probs = probs / probs.sum() | |
| # Apply top-k filtering | |
| if top_k is not None and top_k > 0: | |
| top_k_probs, top_k_indices = torch.topk(probs, min(top_k, probs.shape[0])) | |
| probs = torch.zeros_like(probs) | |
| probs[top_k_indices] = top_k_probs | |
| probs = probs / probs.sum() | |
| # Apply top-p (nucleus) filtering | |
| if top_p is not None and top_p < 1.0: | |
| sorted_probs, sorted_indices = torch.sort(probs, descending=True) | |
| cumulative_probs = torch.cumsum(sorted_probs, dim=0) | |
| sorted_indices_to_remove = cumulative_probs > top_p | |
| sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone() | |
| sorted_indices_to_remove[0] = False | |
| indices_to_remove = sorted_indices[sorted_indices_to_remove] | |
| probs[indices_to_remove] = 0 | |
| probs = probs / probs.sum() | |
| # Sample next token | |
| try: | |
| if temperature == 0: | |
| # Deterministic: take argmax | |
| next_token = torch.argmax(probs, dim=-1).unsqueeze(0) | |
| else: | |
| next_token = torch.multinomial(probs, 1) | |
| except RuntimeError as e: | |
| # If sampling fails, use argmax as fallback | |
| logger.warning(f"Sampling failed, using argmax: {e}") | |
| next_token = torch.argmax(probs, dim=-1).unsqueeze(0) | |
| generated_tokens.append(next_token.item()) | |
| token_probs.append(float(probs[next_token.item()])) | |
| token_strings.append(self.tokenizer.decode([next_token.item()], skip_special_tokens=True)) | |
| # Update inputs | |
| inputs = { | |
| "input_ids": torch.cat([inputs["input_ids"], next_token.unsqueeze(0)], dim=1), | |
| "attention_mask": torch.cat([inputs["attention_mask"], torch.ones((1, 1)).to(self.device)], dim=1) | |
| } | |
| # Check for end of sequence | |
| if next_token.item() == self.tokenizer.eos_token_id: | |
| break | |
| finally: | |
| # Always remove hooks, even if there's an error | |
| for handle in handles: | |
| handle.remove() | |
| logger.info(f"Removed {len(handles)} hooks") | |
| # Decode generated text | |
| generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True) | |
| full_text = prompt + generated_text | |
| # Calculate metrics with repetition-aware perplexity | |
| avg_confidence = sum(token_probs) / len(token_probs) if token_probs else 0 | |
| # Calculate base perplexity | |
| base_perplexity = np.exp(-np.mean(np.log(np.array(token_probs) + 1e-10))) if token_probs else 1.0 | |
| # Detect repetitions and adjust perplexity | |
| repetition_factor = 1.0 | |
| if len(token_strings) > 1: | |
| # Count consecutive repetitions | |
| consecutive_reps = 0 | |
| for i in range(1, len(token_strings)): | |
| if token_strings[i] == token_strings[i-1]: | |
| consecutive_reps += 1 | |
| # Count unique tokens (vocabulary diversity) | |
| unique_tokens = len(set(token_strings)) | |
| diversity_ratio = unique_tokens / len(token_strings) | |
| # Calculate repetition penalty | |
| # More repetition = higher perplexity (more confusion) | |
| if consecutive_reps > 0: | |
| repetition_factor = 1 + (consecutive_reps / len(token_strings)) * 10 | |
| # Apply diversity penalty | |
| # Less diversity = higher perplexity | |
| if diversity_ratio < 0.5: # Less than 50% unique tokens | |
| diversity_penalty = 2.0 / (diversity_ratio + 0.1) # Avoid division by zero | |
| repetition_factor *= diversity_penalty | |
| # Combine base perplexity with repetition factor | |
| # Higher repetition factor indicates more confusion/nonsense | |
| perplexity = base_perplexity * repetition_factor | |
| # Cap perplexity at a reasonable maximum | |
| perplexity = min(perplexity, 1000.0) | |
| generation_time = time.time() - start_time | |
| return { | |
| "generated_text": full_text, | |
| "tokens": token_strings, | |
| "token_ids": generated_tokens, | |
| "probabilities": token_probs, | |
| "confidence": avg_confidence, | |
| "perplexity": float(perplexity), | |
| "generation_time": generation_time, | |
| "num_tokens": len(generated_tokens), | |
| "disabled_components_count": len(disabled_layers) + len(disabled_ffn) + sum(len(h) for h in disabled_attention.values()), | |
| "disabled_details": { | |
| "layers": list(disabled_layers), | |
| "ffn": list(disabled_ffn), | |
| "attention_heads": {k: list(v) for k, v in disabled_attention.items()} | |
| } | |
| } | |
| except Exception as e: | |
| logger.error(f"Ablated generation error: {e}") | |
| logger.error(traceback.format_exc()) | |
| raise HTTPException(status_code=500, detail=str(e)) | |
    async def generate_with_traces(
        self,
        prompt: str,
        max_tokens: int = 100,
        temperature: float = 0.7,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        sampling_rate: float = 0.005,
        layer_stride: int = 1  # 1 = all layers, 2 = every other layer, etc.
    ) -> Dict[str, Any]:
        """Generate text with trace extraction.

        Runs a manual token-by-token sampling loop (temperature scaling, then
        optional top-k, then optional top-p filtering), broadcasting each
        sampled token plus its top-5 alternatives to connected WebSocket
        clients as it is produced. After generation, a final forward pass
        over the full sequence captures complete attention patterns (every
        ``layer_stride``-th layer) and a confidence trace; both are broadcast
        and included in the returned payload.

        Args:
            prompt: Input text to continue.
            max_tokens: Maximum number of new tokens to sample.
            temperature: Softmax temperature; 0 selects argmax (greedy).
            top_k: If set and > 0, keep only the k most probable tokens.
            top_p: If set and < 1.0, nucleus filtering applied to the
                (possibly already top-k-filtered) distribution.
            sampling_rate: NOTE(review): not referenced in this body —
                activation traces are sampled at a fixed 30% rate instead;
                TODO confirm whether callers expect this knob to apply here.
            layer_stride: Capture attention from every Nth layer in the
                final pass.

        Returns:
            Dict with the generated text, per-token strings/probabilities,
            repetition-adjusted perplexity, average confidence, serialized
            traces, token count, and a hallucination-risk score.

        Raises:
            HTTPException: 503 when no model/tokenizer is loaded; 500 on any
                failure during generation.
        """
        if not self.model or not self.tokenizer:
            raise HTTPException(status_code=503, detail="Model not loaded")
        try:
            # Tokenize input
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
            # Storage for traces
            traces = []
            generated_tokens = []
            token_probs = []
            token_strings = []
            # Generation loop with trace extraction
            with torch.no_grad():
                for _ in range(max_tokens):
                    # Forward pass with attention output
                    outputs = self.model(
                        **inputs,
                        output_attentions=True,
                        output_hidden_states=True
                    )
                    # Skip mid-generation attention capture - we'll capture complete attention at the end
                    # This ensures we get the full attention matrix for all generated tokens
                    pass  # Removed mid-generation attention capture
                    # Extract activation traces periodically (not every token to avoid overflow)
                    # NOTE(review): the 30% sampling probability here is fixed
                    # and independent of the ``sampling_rate`` parameter.
                    if outputs.hidden_states and len(outputs.hidden_states) > 0 and np.random.random() < 0.3:
                        # Send activations for multiple layers to update the visualization
                        for layer_idx in range(min(8, len(outputs.hidden_states))):
                            try:
                                trace = self.extract_activation_trace(layer_idx, outputs.hidden_states[layer_idx])
                                await self.broadcast_trace(trace)
                            except Exception as e:
                                logger.warning(f"Failed to extract activation trace for layer {layer_idx}: {e}")
                    # Get next token: logits for the last position only
                    logits = outputs.logits
                    next_token_logits = logits[0, -1, :]
                    # Handle potential inf/nan values before softmax
                    if torch.isnan(next_token_logits).any() or torch.isinf(next_token_logits).any():
                        next_token_logits = torch.nan_to_num(next_token_logits, nan=0.0, posinf=10.0, neginf=-10.0)
                    # Apply temperature (skipped at 0 so argmax below stays exact)
                    if temperature > 0:
                        next_token_logits = next_token_logits / temperature
                    probs = torch.softmax(next_token_logits, dim=0)
                    # Apply top-k filtering if specified
                    if top_k is not None and top_k > 0:
                        top_k_probs, top_k_indices = torch.topk(probs, min(top_k, probs.shape[0]))
                        probs_filtered = torch.zeros_like(probs)
                        probs_filtered[top_k_indices] = top_k_probs
                        probs_filtered = probs_filtered / probs_filtered.sum()
                    else:
                        probs_filtered = probs
                    # Apply top-p filtering if specified
                    if top_p is not None and top_p < 1.0:
                        sorted_probs, sorted_indices = torch.sort(probs_filtered, descending=True)
                        cumulative_probs = torch.cumsum(sorted_probs, dim=0)
                        sorted_indices_to_remove = cumulative_probs > top_p
                        # Shift the mask right so the token that crosses the
                        # top_p threshold is kept (standard nucleus shift).
                        sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
                        sorted_indices_to_remove[0] = False
                        indices_to_remove = sorted_indices[sorted_indices_to_remove]
                        probs_filtered[indices_to_remove] = 0
                        probs_filtered = probs_filtered / probs_filtered.sum()
                    # Get top-k tokens for alternatives display — taken from the
                    # UNfiltered distribution so the UI can show candidates that
                    # filtering removed.
                    top_k_display = 5
                    top_probs, top_indices = torch.topk(probs, min(top_k_display, probs.shape[0]))
                    # Sample next token from the filtered distribution
                    try:
                        if temperature == 0:
                            # Deterministic: take argmax
                            next_token = torch.argmax(probs_filtered, dim=-1).unsqueeze(0)
                        else:
                            next_token = torch.multinomial(probs_filtered, 1)
                    except RuntimeError as e:
                        logger.warning(f"Sampling failed, using argmax: {e}")
                        next_token = torch.argmax(probs_filtered, dim=-1).unsqueeze(0)
                    generated_tokens.append(next_token.item())
                    token_probs.append(float(probs_filtered[next_token.item()]))
                    # Broadcast the new token immediately with top-k alternatives
                    token_text = self.tokenizer.decode([next_token.item()], skip_special_tokens=True)
                    token_strings.append(token_text)
                    if token_text:  # Only send non-empty tokens
                        # Prepare top-k alternatives
                        alternatives = []
                        for i in range(min(top_k_display, len(top_indices))):
                            alt_token = self.tokenizer.decode([top_indices[i].item()], skip_special_tokens=True)
                            alternatives.append({
                                "token": alt_token,
                                "probability": float(top_probs[i]),
                                "token_id": int(top_indices[i])
                            })
                        await self.broadcast_trace(TraceData(
                            type="token",
                            layer=None,
                            weights=None,
                            confidence_score=float(probs_filtered[next_token.item()]),
                            timestamp=datetime.now().timestamp()
                        ))
                        # Send enhanced token data with alternatives
                        await self.broadcast_token_with_alternatives(token_text, alternatives)
                    # Update inputs: append the sampled token and extend the mask
                    inputs = {
                        "input_ids": torch.cat([inputs["input_ids"], next_token.unsqueeze(0)], dim=1),
                        "attention_mask": torch.cat([inputs["attention_mask"], torch.ones((1, 1)).to(self.device)], dim=1)
                    }
                    # Check for end of sequence
                    if next_token.item() == self.tokenizer.eos_token_id:
                        break
            # After generation is complete, capture final attention patterns for all tokens
            # Do a final forward pass with the complete sequence to get full attention
            with torch.no_grad():
                final_outputs = self.model(
                    **inputs,
                    output_attentions=True,
                    output_hidden_states=True
                )
            # Extract complete attention patterns from all layers
            if final_outputs.attentions and len(final_outputs.attentions) > 0:
                num_layers = len(final_outputs.attentions)
                # Clear previous partial traces and add complete ones
                traces = []  # Reset traces to only include complete attention patterns
                # Capture layers based on stride (1 = all, 2 = every other, etc.)
                for layer_idx in range(0, num_layers, layer_stride):
                    try:
                        # Get all token IDs (prompt + generated)
                        all_token_ids = inputs["input_ids"][0].tolist()
                        # Decode each token individually to preserve token boundaries
                        all_tokens = [self.tokenizer.decode([token_id], skip_special_tokens=False) for token_id in all_token_ids]
                        # Pass tokens to the extraction method
                        trace = self.extract_attention_trace(layer_idx, final_outputs.attentions, all_tokens)
                        traces.append(trace)
                        await self.broadcast_trace(trace)
                    except Exception as e:
                        logger.warning(f"Failed to extract final attention trace from layer {layer_idx}: {e}")
            # Calculate final confidence from the full-sequence logits
            confidence_trace = self.calculate_confidence(final_outputs.logits)
            traces.append(confidence_trace)
            await self.broadcast_trace(confidence_trace)
            # Decode generated text
            generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
            full_text = prompt + generated_text
            # Calculate metrics with repetition-aware perplexity
            avg_confidence = sum(token_probs) / len(token_probs) if token_probs else 0
            # Calculate base perplexity (1e-10 epsilon guards log(0))
            base_perplexity = np.exp(-np.mean(np.log(np.array(token_probs) + 1e-10))) if token_probs else 1.0
            # Detect repetitions and adjust perplexity
            repetition_factor = 1.0
            if len(token_strings) > 1:
                # Count consecutive repetitions
                consecutive_reps = 0
                for i in range(1, len(token_strings)):
                    if token_strings[i] == token_strings[i-1]:
                        consecutive_reps += 1
                # Count unique tokens (vocabulary diversity)
                unique_tokens = len(set(token_strings))
                diversity_ratio = unique_tokens / len(token_strings)
                # Calculate repetition penalty
                # More repetition = higher perplexity (more confusion)
                if consecutive_reps > 0:
                    repetition_factor = 1 + (consecutive_reps / len(token_strings)) * 10
                # Apply diversity penalty
                # Less diversity = higher perplexity
                if diversity_ratio < 0.5:  # Less than 50% unique tokens
                    diversity_penalty = 2.0 / (diversity_ratio + 0.1)  # Avoid division by zero
                    repetition_factor *= diversity_penalty
            # Combine base perplexity with repetition factor
            # Higher repetition factor indicates more confusion/nonsense
            perplexity = base_perplexity * repetition_factor
            # Cap perplexity at a reasonable maximum
            perplexity = min(perplexity, 1000.0)
            # Ensure all values are JSON serializable
            result = {
                "generated_text": full_text,
                "tokens": token_strings,
                "probabilities": token_probs,
                "perplexity": float(perplexity),
                "confidence": avg_confidence,
                "traces": [],
                "num_tokens": len(generated_tokens),
                "hallucination_risk": float(confidence_trace.hallucination_risk) if np.isfinite(confidence_trace.hallucination_risk) else 0.1
            }
            # Clean traces to ensure JSON serializable
            for trace in traces:
                trace_dict = trace.dict()
                # Clean any float values in the trace
                for key, value in trace_dict.items():
                    if isinstance(value, float):
                        if not np.isfinite(value):
                            trace_dict[key] = 0.0
                        else:
                            trace_dict[key] = float(value)
                result["traces"].append(trace_dict)
            return result
        except Exception as e:
            logger.error(f"Generation error: {e}")
            logger.error(traceback.format_exc())
            raise HTTPException(status_code=500, detail=str(e))
| async def broadcast_trace(self, trace: TraceData): | |
| """Send trace to all connected WebSocket clients""" | |
| disconnected = [] | |
| for client in self.websocket_clients: | |
| try: | |
| await client.send_json(trace.dict()) | |
| except: | |
| disconnected.append(client) | |
| # Remove disconnected clients | |
| for client in disconnected: | |
| if client in self.websocket_clients: | |
| self.websocket_clients.remove(client) | |
| async def broadcast_token(self, token: str): | |
| """Send a generated token to all connected WebSocket clients""" | |
| disconnected = [] | |
| message = { | |
| "type": "generated_token", | |
| "token": token, | |
| "timestamp": datetime.now().timestamp() | |
| } | |
| for client in self.websocket_clients: | |
| try: | |
| await client.send_json(message) | |
| except: | |
| disconnected.append(client) | |
| # Remove disconnected clients | |
| for client in disconnected: | |
| if client in self.websocket_clients: | |
| self.websocket_clients.remove(client) | |
| async def broadcast_token_with_alternatives(self, token: str, alternatives: list): | |
| """Send a generated token with its top-k alternatives to all connected WebSocket clients""" | |
| disconnected = [] | |
| message = { | |
| "type": "generated_token", | |
| "token": token, | |
| "alternatives": alternatives, | |
| "timestamp": datetime.now().timestamp() | |
| } | |
| for client in self.websocket_clients: | |
| try: | |
| await client.send_json(message) | |
| except: | |
| disconnected.append(client) | |
| # Remove disconnected clients | |
| for client in disconnected: | |
| if client in self.websocket_clients: | |
| self.websocket_clients.remove(client) | |
# Initialize model manager
# Module-level singleton shared by every HTTP endpoint and WebSocket handler
# below; holds the loaded model, tokenizer, adapter, and client list.
manager = ModelManager()
# Startup event
async def startup_event():
    """Run model initialization when the application starts.

    Delegates to the shared ``manager`` singleton's async initializer.
    """
    await manager.initialize()
# WebSocket endpoint for real-time traces
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket connection for streaming traces.

    Registers the client with the shared manager (so broadcast_* methods can
    push traces/tokens to it) and answers "ping" with "pong" to keep the
    connection alive. Fix: the original removed the client only on a clean
    WebSocketDisconnect, so any other exception leaked the closed socket in
    ``manager.websocket_clients``; cleanup now runs in ``finally``.
    """
    await websocket.accept()
    manager.websocket_clients.append(websocket)
    logger.info(f"WebSocket client connected. Total clients: {len(manager.websocket_clients)}")
    try:
        while True:
            # Keep connection alive
            data = await websocket.receive_text()
            if data == "ping":
                await websocket.send_text("pong")
    except WebSocketDisconnect:
        pass  # Normal client-initiated disconnect
    finally:
        # Always unregister, whatever ended the receive loop.
        if websocket in manager.websocket_clients:
            manager.websocket_clients.remove(websocket)
        logger.info(f"WebSocket client disconnected. Total clients: {len(manager.websocket_clients)}")
# HTTP endpoints
async def root():
    """Minimal liveness probe: service identity plus model-loaded flag."""
    model_ready = manager.model is not None
    return {
        "service": "Visualisable.ai Model Service",
        "status": "running",
        "model_loaded": model_ready
    }
async def health():
    """Detailed health report; always returns 200 so Docker healthchecks pass.

    Exposes model/device state, the number of connected WebSocket clients,
    and whether the tuned-lens runtime is available.
    """
    from .tuned_lens import tuned_lens_runtime
    loaded = manager.model is not None
    device_str = str(manager.device) if manager.device else "not set"
    return {
        "status": "healthy" if manager.model else "initializing",
        "model_loaded": loaded,
        "device": device_str,
        "websocket_clients": len(manager.websocket_clients),
        "tuned_lens_available": tuned_lens_runtime.available,
        "timestamp": datetime.now().isoformat()
    }
async def ready():
    """Readiness probe: 503 while the model is loading, 200 once available.

    Unlike /health, this endpoint reports an error status when not ready,
    which makes it suitable for Kubernetes readiness probes or for callers
    that want to wait for model availability.
    """
    if manager.model is None:
        raise HTTPException(
            status_code=503,
            detail="Model not loaded yet - service is initializing"
        )
    device_str = str(manager.device) if manager.device else "not set"
    return {
        "status": "ready",
        "model_loaded": True,
        "device": device_str,
        "timestamp": datetime.now().isoformat()
    }
async def debug_device():
    """Report GPU/device details for debugging.

    Deliberately exposes no secrets or environment variables — only CUDA and
    model placement info, to verify the model runs on the expected device.
    """
    import torch
    cuda_ok = torch.cuda.is_available()
    gpu_count = torch.cuda.device_count() if cuda_ok else 0
    gpu_name = torch.cuda.get_device_name(0) if cuda_ok and gpu_count > 0 else None
    model = manager.model
    return {
        "cuda_available": cuda_ok,
        "cuda_device_count": gpu_count,
        "cuda_device_name": gpu_name,
        "model_device": str(manager.device) if manager.device else "not set",
        "model_loaded": model is not None,
        "model_dtype": str(model.dtype) if model and hasattr(model, 'dtype') else None,
        "timestamp": datetime.now().isoformat()
    }
async def list_models():
    """Enumerate models this backend can serve on the current hardware.

    GPU backends advertise every configured model; CPU backends only those
    with ``min_device == "cpu"``. GPU models that exceed the local card's
    VRAM are still listed but flagged unavailable. The frontend uses this
    to populate its model selector dynamically.
    """
    from .model_config import SUPPORTED_MODELS
    # Determine hardware capabilities
    has_gpu = manager.device is not None and manager.device.type in ["cuda", "mps"]
    device_type = "gpu" if has_gpu else "cpu"
    available_vram = 0
    if has_gpu and torch.cuda.is_available():
        available_vram = torch.cuda.get_device_properties(0).total_memory / (1024**3)  # GB
    models = []
    for model_id, config in SUPPORTED_MODELS.items():
        model_min_device = config.get("min_device", "cpu")
        # CPU backends cannot run GPU-only models at all
        if device_type != "gpu" and model_min_device != "cpu":
            continue
        # A GPU model may still be too large for this card's VRAM
        too_big = has_gpu and 0 < available_vram < config["min_vram_gb"]
        models.append({
            "id": model_id,
            "name": config["display_name"],
            "size": config["size"],
            "architecture": config["architecture"],
            "num_layers": config["num_layers"],
            "num_heads": config["num_heads"],
            "vocab_size": config["vocab_size"],
            "context_length": config["context_length"],
            "attention_type": config["attention_type"],
            "requires_gpu": config["requires_gpu"],
            "min_device": model_min_device,
            "available": not too_big
        })
    return {"models": models, "device": device_type}
async def current_model():
    """Describe the currently active model.

    The frontend uses this to confirm which model is serving and how it is
    configured. All identifying fields are null when nothing is loaded.
    """
    if manager.model is None:
        return {
            "id": None,
            "name": None,
            "device": None,
            "dtype": None,
            "loaded": False
        }
    # Map the torch dtype onto the short labels the frontend expects
    dtype_str = None
    if manager.dtype is not None:
        short_names = {
            torch.bfloat16: "bf16",
            torch.float16: "fp16",
            torch.float32: "fp32",
        }
        dtype_str = short_names.get(manager.dtype, str(manager.dtype))
    return {
        "id": manager.model_id,
        "name": manager.model_name,
        "device": str(manager.device) if manager.device else None,
        "dtype": dtype_str,
        "loaded": True,
        "max_context": manager.max_context,
        "batch_size": manager.batch_size
    }
async def model_info(authenticated: bool = Depends(verify_api_key)):
    """Return detailed metadata about the loaded model (503 when absent)."""
    if not manager.model:
        raise HTTPException(status_code=503, detail="Model not loaded")
    config = manager.model.config
    # Count total/trainable parameters in a single pass
    total_params = 0
    trainable_params = 0
    for param in manager.model.parameters():
        count = param.numel()
        total_params += count
        if param.requires_grad:
            trainable_params += count
    # Handle different config attribute names across model architectures
    # CodeGen uses: n_layer, n_head, n_embd, n_positions
    # Llama/Code Llama uses: num_hidden_layers, num_attention_heads, hidden_size, max_position_embeddings
    num_layers = getattr(config, 'num_hidden_layers', getattr(config, 'n_layer', 0))
    num_heads = getattr(config, 'num_attention_heads', getattr(config, 'n_head', 0))
    hidden_size = getattr(config, 'hidden_size', getattr(config, 'n_embd', 0))
    max_positions = getattr(config, 'max_position_embeddings', getattr(config, 'n_positions', 0))
    # Human-readable summary of the internals this service can expose
    accessible_items = [
        f"Token probabilities (all {config.vocab_size})",
        f"Attention weights ({num_layers} layers × {num_heads} heads = {num_layers * num_heads} patterns)",
        f"Hidden states (all {num_layers} layers)",
        "Logits before softmax",
        "Token embeddings",
        "Position embeddings (RoPE)",
        "Feed-forward activations",
        "Layer normalizations",
        "Gradient information (when available)",
        "Activation functions (GELU)"
    ]
    return {
        "name": manager.model_name,
        "type": config.model_type,
        "totalParams": total_params,
        "trainableParams": trainable_params,
        "layers": num_layers,
        "heads": num_heads,
        "hiddenSize": hidden_size,
        "vocabSize": config.vocab_size,
        "maxPositions": max_positions,
        "architecture": manager.model.__class__.__name__,
        "device": str(manager.device),
        "dtype": str(next(manager.model.parameters()).dtype),
        "accessible": accessible_items,
        "config": {
            "activation_function": getattr(config, 'activation_function', getattr(config, 'hidden_act', 'unknown')),
            "layer_norm_epsilon": getattr(config, 'layer_norm_epsilon', getattr(config, 'rms_norm_eps', 1e-5)),
            "tie_word_embeddings": config.tie_word_embeddings,
            "rotary_dim": getattr(config, 'rotary_dim', None),
            "use_cache": config.use_cache
        }
    }
async def get_models(authenticated: bool = Depends(verify_api_key)):
    """List the models runnable on this hardware.

    GPU backends list every configured model; CPU backends only those whose
    ``min_device`` is "cpu". Each entry is marked available and flagged when
    it is the currently loaded model.
    """
    from .model_config import list_all_models, SUPPORTED_MODELS
    # Determine whether any accelerator (CUDA or Apple MPS) is present
    has_gpu = torch.cuda.is_available() or torch.backends.mps.is_available()
    device_type = "gpu" if has_gpu else "cpu"
    available_models = []
    for model in list_all_models():
        model_config = SUPPORTED_MODELS.get(model['id'])
        model_min_device = model_config.get("min_device", "cpu") if model_config else "cpu"
        # CPU backends skip GPU-only models entirely
        if device_type != "gpu" and model_min_device != "cpu":
            continue
        model['available'] = True
        model['is_current'] = (model['id'] == manager.model_id)
        available_models.append(model)
    return {"models": available_models, "device": device_type}
async def get_current_model(authenticated: bool = Depends(verify_api_key)):
    """Describe the active model via its adapter-normalized config.

    Raises a 503 when no model (or adapter) is loaded.
    """
    if not manager.model or not manager.adapter:
        raise HTTPException(status_code=503, detail="No model loaded")
    # Normalized, architecture-independent view of the model config
    config = manager.adapter.normalize_config()
    exposed_keys = (
        "architecture", "attention_type", "num_layers", "num_heads",
        "num_kv_heads", "vocab_size", "context_length",
    )
    return {
        "id": manager.model_id,
        "name": config["display_name"],
        "config": {key: config[key] for key in exposed_keys}
    }
async def switch_model(request: Dict[str, Any], authenticated: bool = Depends(verify_api_key)):
    """Switch the service to a different supported model.

    Validates the requested ``model_id``, unloads the current model, loads
    the new tokenizer/weights/adapter, and reports success.

    Fixes: (1) the "already loaded" short-circuit now also requires an actual
    model object — previously a failed load left ``manager.model_id`` set, so
    a retry falsely reported success; (2) on failure the manager's model
    fields are reset to a consistent "no model" state; (3) the load log line
    no longer hard-codes "Apple Silicon GPU" (device_map="auto" may place the
    model anywhere).
    """
    from .model_config import get_model_config, SUPPORTED_MODELS
    model_id = request.get("model_id")
    if not model_id:
        raise HTTPException(status_code=400, detail="model_id required")
    if model_id not in SUPPORTED_MODELS:
        raise HTTPException(status_code=404, detail=f"Model {model_id} not found")
    # Treat as already loaded only when the model object actually exists.
    if manager.model_id == model_id and manager.model is not None:
        return {
            "success": True,
            "message": f"Model {model_id} is already loaded"
        }
    try:
        # Get model config
        config = get_model_config(model_id)
        # Unload current model and release GPU memory before the new load
        if manager.model:
            logger.info(f"Unloading current model: {manager.model_id}")
            manager.model = None
            manager.tokenizer = None
            manager.adapter = None
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        # Load new model
        from transformers import AutoTokenizer, AutoModelForCausalLM
        from .model_adapter import create_adapter
        logger.info(f"Loading {config['display_name']}...")
        manager.model_name = config["hf_path"]
        manager.model_id = model_id
        # Load tokenizer and model
        manager.tokenizer = AutoTokenizer.from_pretrained(manager.model_name)
        manager.model = AutoModelForCausalLM.from_pretrained(
            manager.model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
            attn_implementation="eager"  # Required for output_attentions=True
        )
        # Create adapter
        manager.adapter = create_adapter(manager.model, manager.tokenizer, model_id)
        # For Devstral, also load MistralTokenizer for correct Tekken encoding
        manager.mistral_tokenizer = None
        if model_id == "devstral-small":
            from .mistral_tokenizer import create_mistral_tokenizer
            manager.mistral_tokenizer = create_mistral_tokenizer(manager.model_name)
            if manager.mistral_tokenizer:
                logger.info("Loaded MistralTokenizer for Devstral (correct Tekken encoding)")
            else:
                logger.warning("MistralTokenizer not available - Devstral may produce garbage output")
        logger.info(f"✅ {config['display_name']} loaded successfully")
        logger.info(f"   Layers: {manager.adapter.get_num_layers()}, Heads: {manager.adapter.get_num_heads()}")
        num_kv_heads = manager.adapter.get_num_kv_heads()
        if num_kv_heads:
            logger.info(f"   KV Heads: {num_kv_heads} (GQA)")
        return {
            "success": True,
            "message": f"Successfully loaded {config['display_name']}"
        }
    except Exception as e:
        # Leave the manager in a consistent "no model" state so a retry (and
        # the already-loaded check above) behaves correctly.
        manager.model = None
        manager.tokenizer = None
        manager.adapter = None
        manager.model_id = None
        manager.model_name = None
        logger.error(f"Failed to load model {model_id}: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to load model: {str(e)}")
async def generate(request: GenerationRequest, authenticated: bool = Depends(verify_api_key)):
    """Generate text for the prompt, streaming traces when requested."""
    # A sampling rate of 0 disables trace extraction entirely
    trace_rate = request.sampling_rate if request.extract_traces else 0
    return await manager.generate_with_traces(
        prompt=request.prompt,
        max_tokens=request.max_tokens,
        temperature=request.temperature,
        top_k=request.top_k,
        top_p=request.top_p,
        sampling_rate=trace_rate,
        layer_stride=request.layer_stride
    )
async def generate_ablated(request: AblatedGenerationRequest, authenticated: bool = Depends(verify_api_key)):
    """Run an ablation study: generate with selected components disabled."""
    response = await manager.generate_with_ablation(
        prompt=request.prompt,
        max_tokens=request.max_tokens,
        temperature=request.temperature,
        top_k=request.top_k,
        top_p=request.top_p,
        disabled_components=request.disabled_components
    )
    return response
async def generate_icl(request: ICLGenerationRequest, authenticated: bool = Depends(verify_api_key)):
    """Generate text with in-context learning (ICL) analysis.

    Converts the request's few-shot examples, runs generation through the
    ICL analyzer, and returns per-token confidence, example influence, and
    (when detected) ICL-emergence / induction-head diagnostics, using the
    camelCase field names the frontend expects.
    """
    from .icl_service import ICLAnalyzer, ICLExample as ICLExampleData
    # Consistency fix: fail fast like the other model endpoints instead of
    # crashing inside the analyzer when no model is loaded yet.
    if not manager.model or not manager.tokenizer:
        raise HTTPException(status_code=503, detail="Model not loaded")
    # Initialize ICL analyzer
    analyzer = ICLAnalyzer(manager.model, manager.tokenizer, adapter=manager.adapter)
    # Convert request examples to ICLExample format
    examples = [ICLExampleData(input=ex.input, output=ex.output) for ex in request.examples]
    # Analyze generation with examples
    result = analyzer.analyze_generation(
        examples=examples,
        test_prompt=request.prompt,
        max_length=request.max_tokens,
        temperature=request.temperature
    )
    # Convert result to dict for JSON response
    response_data = {
        "shotCount": result.shot_count,
        "generatedCode": result.generated_code,
        "tokens": result.tokens,
        "confidenceScores": result.confidence_scores,
        "attentionFromExamples": result.attention_from_examples,
        "perplexity": result.perplexity,
        "avgConfidence": result.avg_confidence,
        "exampleInfluences": result.example_influences,
        "hiddenStateDrift": result.hidden_state_drift
    }
    # Add ICL emergence data if available
    if result.icl_emergence:
        response_data["iclEmergence"] = {
            "emergenceDetected": result.icl_emergence.emergence_detected,
            "emergenceToken": result.icl_emergence.emergence_token,
            "emergenceLayer": result.icl_emergence.emergence_layer,
            "confidence": result.icl_emergence.confidence,
            "inductionHeads": [
                {
                    "layer": h.layer,
                    "head": h.head,
                    "strength": h.strength,
                    "patternType": h.pattern_type,
                    "emergencePoint": h.emergence_point
                }
                for h in result.icl_emergence.induction_heads
            ],
            "attentionEntropyDrop": result.icl_emergence.attention_entropy_drop,
            "patternConsistency": result.icl_emergence.pattern_consistency
        }
    return response_data
async def analyze_pipeline(request: Dict[str, Any], authenticated: bool = Depends(verify_api_key)):
    """Walk the transformer pipeline step by step for the requested text.

    Returns one pipeline (a list of serialized steps) per generated token;
    for single-token requests the legacy flat "steps" shape is included for
    backward compatibility.
    """
    from .pipeline_analyzer import TransformerPipelineAnalyzer
    try:
        # Pipeline analyzer uses the adapter for multi-model support
        analyzer = TransformerPipelineAnalyzer(manager.model, manager.tokenizer, adapter=manager.adapter)
        # Pull generation parameters (code-oriented defaults)
        text = request.get("text", "def fibonacci(n):\n if n <= 1:\n return n")
        max_tokens = request.get("max_tokens", 1)
        temperature = request.get("temperature", 0.7)
        top_k = request.get("top_k", 50)
        top_p = request.get("top_p", 0.95)
        # Run the analysis with the requested sampling parameters
        result = analyzer.analyze_pipeline(
            text,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p
        )
        # Serialize each token's dataclass steps to plain dicts
        from dataclasses import asdict
        pipelines_dict = [
            [asdict(step) for step in pipeline]
            for pipeline in result['pipelines']
        ]
        steps_per_token = len(pipelines_dict[0]) if pipelines_dict else 0
        if max_tokens == 1 and len(pipelines_dict) > 0:
            # Backward-compatible single-token shape, plus the new fields
            response_data = {
                "steps": pipelines_dict[0],
                "total_steps": len(pipelines_dict[0]),
                "model_name": manager.model_name,
                "input_text": text,
                "tokens": result['tokens'],
                "pipelines": pipelines_dict,
                "final_text": result['final_text']
            }
        else:
            response_data = {
                "tokens": result['tokens'],
                "pipelines": pipelines_dict,
                "final_text": result['final_text'],
                "num_tokens": result['num_tokens'],
                "total_steps": steps_per_token,
                "model_name": manager.model_name,
                "input_text": text
            }
        logger.info(f"Pipeline analysis complete: {result['num_tokens']} tokens, {steps_per_token} steps per token")
        return response_data
    except Exception as e:
        logger.error(f"Pipeline analysis error: {str(e)}")
        logger.error(traceback.format_exc())
        raise HTTPException(status_code=500, detail=str(e))
async def analyze_attention(request: Dict[str, Any], authenticated: bool = Depends(verify_api_key)):
    """Analyze attention mechanism with Q, K, V extraction.

    Extracts per-layer/per-head Q, K, V matrices, attention weights, and
    token embeddings for the given text. Only every 4th layer is included
    in the response (heads are already sampled inside the extractor) to
    keep the payload size manageable for the frontend.

    Request fields:
        text: input text to analyze (defaults to a small fibonacci snippet).

    Returns:
        Dict with tokens, model dimensions, sampled QKV data, sampled token
        embeddings, an example attention flow from the first token, and
        positional encodings when the model exposes them.

    Raises:
        HTTPException: 500 with the underlying error message on any failure.
    """
    from .qkv_extractor import QKVExtractor
    try:
        # Initialize QKV extractor with adapter for real Q/K/V extraction
        extractor = QKVExtractor(manager.model, manager.tokenizer, adapter=manager.adapter)
        # Extract attention data
        text = request.get("text", "def fibonacci(n):\n if n <= 1:\n return n")
        analysis = extractor.extract_attention_data(text)
        # Convert to response format
        response_data = {
            "tokens": analysis.tokens,
            "tokenIds": analysis.token_ids,
            "layerCount": analysis.layer_count,
            "headCount": analysis.head_count,
            "sequenceLength": analysis.sequence_length,
            "modelDimension": analysis.model_dimension,
            "qkvData": [],
            "tokenEmbeddings": [],
            "attentionFlow": []
        }
        # Process QKV data for specific layers/heads to avoid overwhelming the frontend
        # Sample every 4th layer (we already sampled every 4th head in the extractor)
        for qkv in analysis.qkv_data:
            if qkv.layer % 4 == 0:
                response_data["qkvData"].append({
                    "layer": qkv.layer,
                    "head": qkv.head,
                    "query": qkv.query.tolist(),
                    "key": qkv.key.tolist(),
                    "value": qkv.value.tolist(),
                    "attentionScoresRaw": qkv.attention_scores_raw.tolist(),
                    "attentionWeights": qkv.attention_weights.tolist(),
                    "headDim": qkv.head_dim
                })
        # Process token embeddings (same every-4th-layer sampling as QKV data)
        for emb in analysis.token_embeddings:
            if emb.layer % 4 == 0:
                response_data["tokenEmbeddings"].append({
                    "token": emb.token,
                    "tokenId": emb.token_id,
                    "position": emb.position,
                    "layer": emb.layer,
                    "embedding2D": emb.embedding_2d,
                    "embedding3D": emb.embedding_3d
                })
        # Get attention flow for the first token as an example
        if len(analysis.tokens) > 0:
            flow = extractor.get_attention_flow(analysis, source_token=0)
            response_data["attentionFlow"] = flow
        # Add positional encodings if available
        if analysis.positional_encodings is not None:
            response_data["positionalEncodings"] = analysis.positional_encodings.tolist()
        return response_data
    except Exception as e:
        # Consistent with the other analysis endpoints (e.g. analyze_pipeline):
        # log the full traceback and return a clean 500 instead of leaking a
        # raw, unlogged server error.
        logger.error(f"Attention analysis error: {str(e)}")
        logger.error(traceback.format_exc())
        raise HTTPException(status_code=500, detail=str(e))
| async def analyze_research_attention(request: Dict[str, Any], authenticated: bool = Depends(verify_api_key)): | |
| """ | |
| Research-Grade Attention Analysis with Full Tensor Extraction | |
| Provides maximum depth analysis for research purposes: | |
| - Full Q/K/V matrices (no sampling) | |
| - All layers and all heads | |
| - Per-token activation deltas | |
| - Pattern classification (induction, positional, semantic, etc.) | |
| - Causal impact quantification | |
| """ | |
| try: | |
| import time | |
| start_time = time.time() | |
| # Generate unique request ID for matrix cache lookup | |
| request_id = str(uuid.uuid4()) | |
| # Clear old cached matrices to free memory before starting new analysis | |
| matrix_cache.clear_old_requests(request_id) | |
| # Get parameters | |
| prompt = request.get("prompt", "def quicksort(arr):") | |
| max_tokens = request.get("max_tokens", 8) | |
| auto_complete = request.get("auto_complete", False) | |
| temperature = request.get("temperature", 0.7) | |
| top_k_param = request.get("top_k", None) # Top-k sampling parameter | |
| top_p_param = request.get("top_p", None) # Top-p (nucleus) sampling parameter | |
| # If auto_complete mode, ensure we have a reasonable upper limit | |
| if auto_complete: | |
| max_tokens = min(max_tokens, 128) | |
| logger.info(f"Research attention analysis: prompt_len={len(prompt)}, max_tokens={max_tokens}, auto_complete={auto_complete}") | |
| # Get model config for prompt formatting | |
| from .model_config import get_model_config | |
| from .prompt_formatter import format_prompt | |
| model_config = get_model_config(manager.model_id) | |
| # Get optional system prompt override from request | |
| system_prompt_override = request.get("system_prompt") | |
| # Format prompt using the unified formatter | |
| formatted_prompt = format_prompt( | |
| prompt=prompt, | |
| model_config=model_config or {}, | |
| tokenizer=manager.tokenizer, | |
| system_prompt_override=system_prompt_override | |
| ) | |
| # Log formatting details | |
| prompt_style = model_config.get("prompt_style", "completion") if model_config else "completion" | |
| logger.info(f"Formatted prompt for {manager.model_id} using style={prompt_style}") | |
| if prompt_style == "instruction": | |
| logger.info(f"Formatted prompt preview: {formatted_prompt[:200]}...") | |
| # Temperature is now controlled by the frontend UI | |
| # The frontend sets appropriate defaults per model (0.15 for Devstral, 0.7 for CodeGen) | |
| logger.info(f"Using temperature={temperature} from request") | |
| # Tokenize and prepare - use MistralTokenizer for Devstral | |
| if manager.model_id == "devstral-small" and manager.mistral_tokenizer is not None: | |
| # Use MistralTokenizer for correct Tekken encoding | |
| system_prompt = system_prompt_override or (model_config.get("system_prompt") if model_config else "") | |
| prompt_token_ids = manager.mistral_tokenizer.encode_chat(system_prompt, prompt) | |
| inputs = {"input_ids": torch.tensor([prompt_token_ids]).to(manager.device)} | |
| prompt_length = len(prompt_token_ids) | |
| # Decode tokens using MistralTokenizer for accuracy | |
| prompt_tokens = [manager.mistral_tokenizer.decode_token(tid) for tid in prompt_token_ids] | |
| logger.info(f"Used MistralTokenizer for Devstral: {prompt_length} tokens") | |
| else: | |
| # Standard HF tokenization for other models | |
| inputs = manager.tokenizer(formatted_prompt, return_tensors="pt").to(manager.device) | |
| prompt_length = inputs["input_ids"].shape[1] | |
| prompt_token_ids = inputs["input_ids"][0].tolist() | |
| prompt_tokens = [manager.tokenizer.decode([tid], skip_special_tokens=False) for tid in prompt_token_ids] | |
| # Storage for generation | |
| generated_token_ids = [] | |
| generated_tokens = [] | |
| # Model info (get from adapter) | |
| n_layers = len(list(manager.model.parameters())) # Approximation | |
| if hasattr(manager.model.config, 'n_layer'): | |
| n_layers = manager.model.config.n_layer | |
| elif hasattr(manager.model.config, 'num_hidden_layers'): | |
| n_layers = manager.model.config.num_hidden_layers | |
| n_heads = manager.model.config.n_head if hasattr(manager.model.config, 'n_head') else manager.model.config.num_attention_heads | |
| d_model = manager.model.config.n_embd if hasattr(manager.model.config, 'n_embd') else manager.model.config.hidden_size | |
| head_dim = d_model // n_heads | |
| # Generation loop with full instrumentation | |
| layer_data_by_token = [] # Store layer data for each generated token | |
| token_alternatives_by_step = [] # Store top-k alternatives for each token | |
| # Hook system to capture Q/K/V matrices | |
| qkv_captures = {} | |
| hooks = [] | |
| # Hook for combined QKV projection (CodeGen style) | |
| def make_combined_qkv_hook(layer_idx): | |
| def hook(module, input, output): | |
| try: | |
| if output.dim() != 3: | |
| return | |
| batch_size, seq_len, hidden = output.shape | |
| expected_hidden = 3 * n_heads * head_dim | |
| if hidden != expected_hidden: | |
| return | |
| qkv = output.reshape(batch_size, seq_len, 3, n_heads, head_dim) | |
| q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2] | |
| qkv_captures[layer_idx] = { | |
| 'q': q[0].detach().cpu(), | |
| 'k': k[0].detach().cpu(), | |
| 'v': v[0].detach().cpu() | |
| } | |
| except Exception: | |
| pass | |
| return hook | |
| # Hooks for separate Q, K, V projections (Mistral/LLaMA style) | |
| def make_separate_proj_hook(layer_idx, proj_type, num_kv_heads=None): | |
| def hook(module, input, output): | |
| try: | |
| if output.dim() != 3: | |
| return | |
| batch_size, seq_len, hidden = output.shape | |
| if proj_type == 'q': | |
| proj_heads = n_heads | |
| else: | |
| proj_heads = num_kv_heads if num_kv_heads else n_heads | |
| proj_head_dim = hidden // proj_heads | |
| if hidden != proj_heads * proj_head_dim: | |
| return | |
| proj_output = output.reshape(batch_size, seq_len, proj_heads, proj_head_dim) | |
| if proj_type != 'q' and num_kv_heads and num_kv_heads < n_heads: | |
| repeat_factor = n_heads // num_kv_heads | |
| proj_output = proj_output.repeat_interleave(repeat_factor, dim=2) | |
| if layer_idx not in qkv_captures: | |
| qkv_captures[layer_idx] = {} | |
| qkv_captures[layer_idx][proj_type] = proj_output[0].detach().cpu() | |
| except Exception: | |
| pass | |
| return hook | |
| # Register hooks based on model architecture | |
| try: | |
| # CodeGen style: model.transformer.h[layer].attn.qkv_proj | |
| if hasattr(manager.model, 'transformer') and hasattr(manager.model.transformer, 'h'): | |
| for layer_idx, layer in enumerate(manager.model.transformer.h): | |
| if hasattr(layer, 'attn') and hasattr(layer.attn, 'qkv_proj'): | |
| hook = layer.attn.qkv_proj.register_forward_hook(make_combined_qkv_hook(layer_idx)) | |
| hooks.append(hook) | |
| elif hasattr(layer, 'attn') and hasattr(layer.attn, 'c_attn'): | |
| hook = layer.attn.c_attn.register_forward_hook(make_combined_qkv_hook(layer_idx)) | |
| hooks.append(hook) | |
| # Mistral/LLaMA style: model.model.layers[layer].self_attn.{q,k,v}_proj | |
| elif hasattr(manager.model, 'model') and hasattr(manager.model.model, 'layers'): | |
| num_kv_heads = getattr(manager.model.config, 'num_key_value_heads', None) | |
| for layer_idx, layer in enumerate(manager.model.model.layers): | |
| if hasattr(layer, 'self_attn'): | |
| attn = layer.self_attn | |
| if hasattr(attn, 'q_proj'): | |
| hook = attn.q_proj.register_forward_hook( | |
| make_separate_proj_hook(layer_idx, 'q', num_kv_heads)) | |
| hooks.append(hook) | |
| if hasattr(attn, 'k_proj'): | |
| hook = attn.k_proj.register_forward_hook( | |
| make_separate_proj_hook(layer_idx, 'k', num_kv_heads)) | |
| hooks.append(hook) | |
| if hasattr(attn, 'v_proj'): | |
| hook = attn.v_proj.register_forward_hook( | |
| make_separate_proj_hook(layer_idx, 'v', num_kv_heads)) | |
| hooks.append(hook) | |
| logger.info(f"Registered QKV hooks for {len(hooks)//3} Mistral layers (GQA: {num_kv_heads} KV heads)") | |
| except Exception as hook_error: | |
| logger.warning(f"Could not register QKV hooks: {hook_error}") | |
| with torch.no_grad(): | |
| current_ids = inputs["input_ids"] | |
| for step in range(max_tokens): | |
| # Clear previous captures | |
| qkv_captures.clear() | |
| # Forward pass with full outputs | |
| outputs = manager.model( | |
| current_ids, | |
| output_attentions=True, | |
| output_hidden_states=True | |
| ) | |
| # Get logits for next token | |
| raw_logits = outputs.logits[0, -1, :].clone() # Clone raw logits before any scaling | |
| # Capture raw logits for top-10 tokens (before temperature scaling) | |
| import math | |
| top_n_display = 10 # Get top 10 alternatives for display | |
| top_raw_logits, top_raw_indices = torch.topk(raw_logits, k=min(top_n_display, len(raw_logits))) | |
| # Build raw logits entries (before temperature) | |
| logits_entries = [] | |
| for rank, (logit_val, idx) in enumerate(zip(top_raw_logits.tolist(), top_raw_indices.tolist())): | |
| token_text = manager.tokenizer.decode([idx], skip_special_tokens=False) | |
| logits_entries.append({ | |
| "token": token_text, | |
| "token_id": idx, | |
| "logit": logit_val, | |
| "rank": rank + 1 | |
| }) | |
| # Greedy token (argmax of raw logits, before any sampling) | |
| greedy_token_id = torch.argmax(raw_logits).item() | |
| greedy_token = manager.tokenizer.decode([greedy_token_id], skip_special_tokens=False) | |
| # Compute raw probabilities (T=1) for comparison visualization | |
| raw_probs = torch.softmax(raw_logits, dim=0) | |
| # Apply temperature scaling | |
| logits = raw_logits.clone() | |
| if temperature > 0: | |
| logits = logits / temperature | |
| probs = torch.softmax(logits, dim=0) | |
| # Apply top-k filtering if specified | |
| if top_k_param is not None and top_k_param > 0: | |
| top_k_probs, top_k_indices = torch.topk(probs, k=min(top_k_param, len(probs))) | |
| probs_filtered = torch.zeros_like(probs) | |
| probs_filtered[top_k_indices] = top_k_probs | |
| probs_filtered = probs_filtered / probs_filtered.sum() # Renormalize | |
| else: | |
| probs_filtered = probs | |
| # Apply top-p (nucleus) filtering if specified | |
| if top_p_param is not None and top_p_param < 1.0: | |
| sorted_probs, sorted_indices = torch.sort(probs_filtered, descending=True) | |
| cumulative_probs = torch.cumsum(sorted_probs, dim=0) | |
| # Find cutoff index where cumulative exceeds top_p | |
| cutoff_mask = cumulative_probs > top_p_param | |
| # Shift mask by 1 to keep at least one token | |
| cutoff_mask[1:] = cutoff_mask[:-1].clone() | |
| cutoff_mask[0] = False | |
| # Zero out tokens beyond cutoff | |
| sorted_probs[cutoff_mask] = 0 | |
| # Scatter back to original order | |
| probs_filtered = torch.zeros_like(probs) | |
| probs_filtered.scatter_(0, sorted_indices, sorted_probs) | |
| if probs_filtered.sum() > 0: | |
| probs_filtered = probs_filtered / probs_filtered.sum() # Renormalize | |
| if temperature == 0: | |
| next_token_id = torch.argmax(probs_filtered, dim=-1).item() | |
| else: | |
| # Ensure valid probability distribution for multinomial | |
| if probs_filtered.sum() > 0: | |
| next_token_id = torch.multinomial(probs_filtered, 1).item() | |
| else: | |
| next_token_id = torch.argmax(probs, dim=-1).item() | |
| next_token_text = manager.tokenizer.decode([next_token_id], skip_special_tokens=False) | |
| generated_token_ids.append(next_token_id) | |
| generated_tokens.append(next_token_text) | |
| # Capture top-10 token alternatives with probabilities | |
| # Use log_softmax for numerical stability at low temperatures | |
| _, top_indices = torch.topk(logits, k=min(top_n_display, len(logits))) | |
| # Use log_softmax (numerically stable) then exp() for probabilities | |
| # This avoids underflow that occurs with softmax at low temperatures | |
| # Note: logits is ALREADY temperature-scaled above, so no need to divide again | |
| log_probs = torch.nn.functional.log_softmax(logits, dim=-1) | |
| top_probs = torch.exp(log_probs[top_indices]) | |
| alternatives = [] | |
| cumulative = 0.0 | |
| selected_in_top = False | |
| for rank, (prob, idx) in enumerate(zip(top_probs.tolist(), top_indices.tolist())): | |
| token_text = manager.tokenizer.decode([idx], skip_special_tokens=False) | |
| cumulative += prob | |
| if idx == next_token_id: | |
| selected_in_top = True | |
| alternatives.append({ | |
| "token": token_text, | |
| "token_id": idx, | |
| "probability": prob, | |
| "raw_probability": raw_probs[idx].item(), # T=1 probability for comparison | |
| "log_probability": math.log(prob) if prob > 0 else float('-inf'), | |
| "cumulative_probability": cumulative, | |
| "rank": rank + 1 | |
| }) | |
| # If selected token is not in top-N, add it with its actual probability | |
| if not selected_in_top: | |
| selected_prob = probs[next_token_id].item() | |
| selected_raw_prob = raw_probs[next_token_id].item() | |
| selected_log_prob = log_probs[next_token_id].item() | |
| selected_logit = raw_logits[next_token_id].item() | |
| # Find the rank of the selected token | |
| sorted_indices = torch.argsort(raw_logits, descending=True) | |
| selected_rank = (sorted_indices == next_token_id).nonzero(as_tuple=True)[0].item() + 1 | |
| alternatives.append({ | |
| "token": next_token_text, | |
| "token_id": next_token_id, | |
| "probability": selected_prob, | |
| "raw_probability": selected_raw_prob, # T=1 probability for comparison | |
| "log_probability": selected_log_prob, | |
| "cumulative_probability": None, # Not in sequence | |
| "rank": selected_rank, | |
| "is_selected_outlier": True # Flag for UI | |
| }) | |
| # Also add to logits if not present | |
| if next_token_id not in [e["token_id"] for e in logits_entries]: | |
| logits_entries.append({ | |
| "token": next_token_text, | |
| "token_id": next_token_id, | |
| "logit": selected_logit, | |
| "rank": selected_rank, | |
| "is_selected_outlier": True | |
| }) | |
| # Build sampling metadata | |
| sampling_metadata = { | |
| "temperature": temperature, | |
| "top_k": top_k_param, | |
| "top_p": top_p_param, | |
| "greedy_token_id": greedy_token_id, | |
| "greedy_token": greedy_token, | |
| "was_greedy": next_token_id == greedy_token_id | |
| } | |
| token_alternatives_by_step.append({ | |
| "step": step, | |
| "selected_token": next_token_text, | |
| "selected_token_id": next_token_id, | |
| "alternatives": alternatives, | |
| "logits": logits_entries, | |
| "sampling": sampling_metadata | |
| }) | |
| # Process attention and hidden states for ALL layers | |
| layer_data_this_token = [] | |
| for layer_idx in range(len(outputs.attentions)): | |
| # Get attention for this layer [batch, num_heads, seq_len, seq_len] | |
| layer_attn = outputs.attentions[layer_idx][0] # Remove batch dim | |
| # Get hidden states [batch, seq_len, hidden_dim] | |
| current_hidden = outputs.hidden_states[layer_idx + 1] # +1 because hidden_states includes embedding layer | |
| if current_hidden.dim() == 3: | |
| current_hidden = current_hidden[0] # Remove batch dim if present | |
| if layer_idx > 0: | |
| prev_hidden = outputs.hidden_states[layer_idx] | |
| if prev_hidden.dim() == 3: | |
| prev_hidden = prev_hidden[0] | |
| delta_norm = torch.norm(current_hidden - prev_hidden).item() | |
| else: | |
| delta_norm = None | |
| # Calculate layer metrics | |
| import math | |
| activation_magnitude = torch.norm(current_hidden).item() | |
| # Use a simpler entropy calculation based on attention distribution | |
| last_token_hidden = current_hidden[-1] # [hidden_dim] | |
| activation_entropy = torch.std(last_token_hidden).item() # Use std dev as a proxy for activation diversity | |
| hidden_state_norm = torch.norm(last_token_hidden).item() # Norm of last token | |
| # Sanitize to prevent NaN/Inf in JSON | |
| activation_magnitude = 0.0 if math.isnan(activation_magnitude) or math.isinf(activation_magnitude) else activation_magnitude | |
| activation_entropy = 0.0 if math.isnan(activation_entropy) or math.isinf(activation_entropy) else activation_entropy | |
| hidden_state_norm = 0.0 if math.isnan(hidden_state_norm) or math.isinf(hidden_state_norm) else hidden_state_norm | |
| if delta_norm is not None: | |
| delta_norm = 0.0 if math.isnan(delta_norm) or math.isinf(delta_norm) else delta_norm | |
| # Identify critical heads (high max weight or low entropy) | |
| critical_heads = [] | |
| for head_idx in range(layer_attn.shape[0]): | |
| head_weights = layer_attn[head_idx, -1, :] # Attention from last position | |
| max_weight = head_weights.max().item() | |
| entropy = -(head_weights * torch.log(head_weights + 1e-10)).sum().item() | |
| # Normalized attention entropy averaged over latter half of query positions | |
| # Normalized by log(k_i) where k_i = number of keys position i can attend to | |
| # This produces values in [0,1] with better spread across heads | |
| # layer_attn[head_idx] shape: [q_len, k_len] | |
| head_attn = layer_attn[head_idx] # [q_len, k_len] | |
| q_len = head_attn.shape[0] | |
| # Compute raw entropy per query position | |
| token_entropies = -(head_attn * torch.log(head_attn + 1e-10)).sum(dim=-1) # [q_len] | |
| # Normalize by max possible entropy: log(k_i) where k_i = i + 1 (causal mask) | |
| # Skip position 0 where log(1) = 0 | |
| positions = torch.arange(1, q_len + 1, device=head_attn.device, dtype=head_attn.dtype) | |
| max_entropies = torch.log(positions + 1e-10) # log(k_i), with epsilon for position 0 | |
| normalized_entropies = token_entropies / (max_entropies + 1e-10) # [0, 1] range | |
| # Average over latter half of positions (where there's enough context) | |
| start_idx = q_len // 2 | |
| avg_entropy = normalized_entropies[start_idx:].mean().item() if start_idx < q_len else normalized_entropies.mean().item() | |
| # Sanitize to prevent NaN/Inf in JSON | |
| max_weight = 0.0 if math.isnan(max_weight) or math.isinf(max_weight) else max_weight | |
| entropy = 0.0 if math.isnan(entropy) or math.isinf(entropy) else entropy | |
| avg_entropy = 0.0 if math.isnan(avg_entropy) or math.isinf(avg_entropy) else avg_entropy | |
| # Score-all-then-rank head classification | |
| # Two dimensions: behaviour type (attention geometry) + code cue (token relevance) | |
| seq_len_hw = head_weights.shape[0] | |
| # --- Behaviour type scores (attention geometry) --- | |
| behaviour_scores = {} | |
| # Attention sink: weight on positions 0-2 | |
| sink_w = head_weights[:min(3, seq_len_hw)].sum().item() | |
| behaviour_scores["attention_sink"] = sink_w | |
| # Previous token: weight on immediate predecessor | |
| prev_tok_w = head_weights[-2].item() if seq_len_hw >= 2 else 0.0 | |
| behaviour_scores["previous_token"] = prev_tok_w | |
| # Local: weight within last 5 positions | |
| local_w = head_weights[max(0, seq_len_hw - 5):].sum().item() if seq_len_hw > 5 else 0.0 | |
| behaviour_scores["local"] = local_w | |
| # Induction: weight on positions following previous occurrences of current token | |
| ind_w = 0.0 | |
| if step > 0 and seq_len_hw > 1: | |
| current_tok = current_ids[0, -1] | |
| prev_occ = (current_ids[0, :-1] == current_tok).nonzero(as_tuple=True)[0] | |
| if len(prev_occ) > 0: | |
| foll = prev_occ + 1 | |
| foll = foll[foll < seq_len_hw] | |
| if len(foll) > 0: | |
| ind_w = head_weights[foll].sum().item() | |
| behaviour_scores["induction"] = min(1.0, ind_w) | |
| # Focused: low entropy, concentrated attention (not captured by above) | |
| focused_score = max(0.0, 1.0 - entropy) if entropy < 1.5 else 0.0 | |
| behaviour_scores["focused"] = focused_score | |
| # Diffuse: high entropy, broad attention | |
| diffuse_score = min(1.0, max(0.0, (entropy - 1.0) / 2.0)) | |
| behaviour_scores["diffuse"] = diffuse_score | |
| # Pick primary behaviour (highest score, with minimum thresholds) | |
| behaviour_thresholds = { | |
| "attention_sink": 0.4, | |
| "previous_token": 0.7, | |
| "local": 0.5, | |
| "induction": 0.2, | |
| "focused": 0.3, | |
| "diffuse": 0.3, | |
| } | |
| qualified_behaviours = { | |
| k: v for k, v in behaviour_scores.items() | |
| if v >= behaviour_thresholds.get(k, 0.3) | |
| } | |
| sorted_behaviours = sorted(qualified_behaviours.items(), key=lambda x: x[1], reverse=True) | |
| primary_behaviour = sorted_behaviours[0] if sorted_behaviours else ("diffuse", diffuse_score) | |
| secondary_behaviour = sorted_behaviours[1] if len(sorted_behaviours) > 1 else None | |
| pattern_type = primary_behaviour[0] | |
| confidence = primary_behaviour[1] | |
| # --- Code cue scores (what code tokens are attended to) --- | |
| # Decode token texts for code-aware detection (cached per step) | |
| if step_token_texts_cache.get('step') != step: | |
| try: | |
| step_token_texts_cache['texts'] = [ | |
| manager.tokenizer.decode([tid]) for tid in current_ids[0, :seq_len_hw].tolist() | |
| ] | |
| except Exception: | |
| step_token_texts_cache['texts'] = [] | |
| step_token_texts_cache['step'] = step | |
| token_texts = step_token_texts_cache.get('texts', []) | |
| code_cues = {} | |
| if len(token_texts) == seq_len_hw: | |
| # Delimiter-sensitive: attention to brackets, braces, parens | |
| delimiters = {'(', ')', '{', '}', '[', ']', ':', ';', ','} | |
| delim_indices = [i for i, t in enumerate(token_texts) if t.strip() in delimiters] | |
| if delim_indices: | |
| delim_w = head_weights[delim_indices].sum().item() | |
| code_cues["delimiter_sensitive"] = delim_w | |
| # Keyword-sensitive: attention to language keywords | |
| keywords = {'def', 'return', 'if', 'else', 'elif', 'for', 'while', 'class', | |
| 'import', 'from', 'try', 'except', 'with', 'as', 'in', 'not', | |
| 'and', 'or', 'True', 'False', 'None', 'self', 'yield', 'async', | |
| 'await', 'lambda', 'raise', 'pass', 'break', 'continue', | |
| 'function', 'const', 'let', 'var', 'new', 'this'} | |
| kw_indices = [i for i, t in enumerate(token_texts) if t.strip() in keywords] | |
| if kw_indices: | |
| kw_w = head_weights[kw_indices].sum().item() | |
| code_cues["keyword_sensitive"] = kw_w | |
| # Pattern reuse: attention to a contiguous span that appeared earlier | |
| # (broader than induction — checks for repeated multi-token sequences) | |
| if ind_w > 0.15: | |
| code_cues["pattern_reuse"] = min(1.0, ind_w * 1.5) | |
| # Filter code cues by minimum threshold | |
| code_cue_threshold = 0.15 | |
| qualified_cues = { | |
| k: round(v, 4) for k, v in code_cues.items() | |
| if v >= code_cue_threshold | |
| } | |
| sorted_cues = sorted(qualified_cues.items(), key=lambda x: x[1], reverse=True) | |
| primary_cue = sorted_cues[0] if sorted_cues else None | |
| # Sanitize confidence | |
| confidence = 0.0 if math.isnan(confidence) or math.isinf(confidence) else confidence | |
| # Get full attention weights for this head [seq_len, seq_len] | |
| # Store as numpy arrays (not Python lists) to save memory | |
| # ~7x more memory efficient: 4 bytes/float vs 28 bytes/float | |
| attention_matrix = layer_attn[head_idx].cpu().float().numpy() | |
| # Get Q/K/V for this head if available | |
| q_matrix = None | |
| k_matrix = None | |
| v_matrix = None | |
| if layer_idx in qkv_captures: | |
| # Q/K/V shape: [seq_len, n_heads, head_dim] | |
| # Store as numpy arrays for memory efficiency | |
| q_matrix = qkv_captures[layer_idx]['q'][:, head_idx, :].float().numpy() | |
| k_matrix = qkv_captures[layer_idx]['k'][:, head_idx, :].float().numpy() | |
| v_matrix = qkv_captures[layer_idx]['v'][:, head_idx, :].float().numpy() | |
| # Store matrices in cache for lazy loading (reduces response size) | |
| matrix_cache.store(request_id, step, layer_idx, head_idx, { | |
| "attention_weights": attention_matrix, | |
| "q_matrix": q_matrix, | |
| "k_matrix": k_matrix, | |
| "v_matrix": v_matrix | |
| }) | |
| # Return only metadata (matrices fetched on-demand via /matrix endpoint) | |
| head_entry = { | |
| "head_idx": head_idx, | |
| "entropy": entropy, | |
| "avg_entropy": avg_entropy, # Averaged over all query positions | |
| "max_weight": max_weight, | |
| "has_matrices": attention_matrix is not None, # Flag for frontend | |
| "pattern": { | |
| "type": pattern_type, | |
| "confidence": round(confidence, 4), | |
| } if pattern_type else None, | |
| } | |
| # Secondary behaviour (if present and distinct from primary) | |
| if secondary_behaviour: | |
| head_entry["secondary_behaviour"] = { | |
| "type": secondary_behaviour[0], | |
| "score": round(secondary_behaviour[1], 4), | |
| } | |
| # Code cue (separate dimension from behaviour type) | |
| if primary_cue: | |
| head_entry["code_cue"] = { | |
| "type": primary_cue[0], | |
| "score": round(primary_cue[1], 4), | |
| "evidence": f"{round(primary_cue[1] * 100)}% attention on {primary_cue[0].replace('_', ' ')} tokens", | |
| } | |
| # Secondary code cue | |
| if len(sorted_cues) > 1: | |
| head_entry["secondary_cue"] = { | |
| "type": sorted_cues[1][0], | |
| "score": round(sorted_cues[1][1], 4), | |
| } | |
| critical_heads.append(head_entry) | |
| # Sort by max_weight (return all heads, frontend will decide how many to display) | |
| critical_heads.sort(key=lambda h: h["max_weight"], reverse=True) | |
| # Layer-level pattern: majority vote of head patterns, weighted by confidence | |
| pattern_votes = {} | |
| for h in critical_heads: | |
| if h["pattern"] and h["pattern"]["type"]: | |
| pt = h["pattern"]["type"] | |
| pc = h["pattern"]["confidence"] | |
| pattern_votes[pt] = pattern_votes.get(pt, 0.0) + pc | |
| layer_pattern = None | |
| if pattern_votes: | |
| best_type = max(pattern_votes, key=pattern_votes.get) | |
| total_conf = sum(pattern_votes.values()) | |
| layer_pattern = { | |
| "type": best_type, | |
| "confidence": round(pattern_votes[best_type] / total_conf, 3) if total_conf > 0 else 0.0 | |
| } | |
| layer_data_this_token.append({ | |
| "layer_idx": layer_idx, | |
| "pattern": layer_pattern, | |
| "critical_heads": critical_heads, | |
| "activation_magnitude": activation_magnitude, | |
| "activation_entropy": activation_entropy, | |
| "hidden_state_norm": hidden_state_norm, | |
| "delta_norm": delta_norm | |
| }) | |
| layer_data_by_token.append(layer_data_this_token) | |
| # Update inputs | |
| next_token_tensor = torch.tensor([[next_token_id]], dtype=torch.long, device=manager.device) | |
| current_ids = torch.cat([current_ids, next_token_tensor], dim=1) | |
| # Stop on EOS | |
| if next_token_id == manager.tokenizer.eos_token_id: | |
| break | |
| # Free memory from this step's outputs to prevent accumulation | |
| # This is critical for large models like Devstral (40 layers, 32 heads) | |
| del outputs | |
| del logits | |
| del probs | |
| if 'layer_attn' in dir(): | |
| del layer_attn | |
| if 'current_hidden' in dir(): | |
| del current_hidden | |
| # Periodic garbage collection for large models (every 8 steps) | |
| if (step + 1) % 8 == 0: | |
| gc.collect() | |
| if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): | |
| torch.mps.empty_cache() if hasattr(torch.mps, 'empty_cache') else None | |
| # Clean up hooks after generation | |
| for hook in hooks: | |
| hook.remove() | |
| # Placeholder for Q/K/V data (will be populated in future iterations) | |
| qkv_by_layer_head = {} | |
| generation_time = time.time() - start_time | |
| # Calculate token section boundaries for UI display | |
| total_tokens = prompt_length + len(generated_token_ids) | |
| system_prompt_text = system_prompt_override or (model_config.get("system_prompt") if model_config else None) | |
| # For instruction models, estimate where system prompt ends | |
| # This is approximate due to control tokens in chat templates | |
| system_prompt_end = 0 | |
| if prompt_style == "instruction" and system_prompt_text: | |
| if manager.model_id == "devstral-small" and manager.mistral_tokenizer is not None: | |
| # For Devstral, try encoding with empty system to estimate boundary | |
| try: | |
| no_system_tokens = manager.mistral_tokenizer.encode_chat("", prompt) | |
| # The difference gives us system tokens, but we need to add 1 to include | |
| # the closing [/SYSTEM_PROMPT] tag in the system prompt section | |
| system_prompt_end = prompt_length - len(no_system_tokens) + 1 | |
| # Ensure non-negative and within bounds | |
| system_prompt_end = max(0, min(system_prompt_end, prompt_length)) | |
| logger.info(f"Estimated system prompt boundary: {system_prompt_end} tokens (includes closing tag)") | |
| except Exception as e: | |
| logger.warning(f"Could not estimate system prompt boundary: {e}") | |
| system_prompt_end = 0 | |
| else: | |
| # For other instruction models, rough estimate based on character ratio | |
| # This is very approximate but provides some visual separation | |
| total_chars = len(system_prompt_text or "") + len(prompt) | |
| if total_chars > 0: | |
| system_ratio = len(system_prompt_text or "") / total_chars | |
| system_prompt_end = int(prompt_length * system_ratio) | |
| token_sections = { | |
| "systemPrompt": { | |
| "start": 0, | |
| "end": system_prompt_end, | |
| "text": system_prompt_text, | |
| "tokenCount": system_prompt_end | |
| }, | |
| "userPrompt": { | |
| "start": system_prompt_end, | |
| "end": prompt_length, | |
| "text": prompt, | |
| "tokenCount": prompt_length - system_prompt_end | |
| }, | |
| "output": { | |
| "start": prompt_length, | |
| "end": total_tokens, | |
| "text": "".join(generated_tokens), | |
| "tokenCount": len(generated_token_ids) | |
| } | |
| } | |
| # Build token metadata for frontend (eliminates per-token API calls) | |
| from .tokenizer_utils import TokenizerMetadata | |
| token_metadata = TokenizerMetadata(manager.tokenizer) | |
| special_token_ids = { | |
| manager.tokenizer.eos_token_id, | |
| manager.tokenizer.bos_token_id, | |
| manager.tokenizer.pad_token_id, | |
| manager.tokenizer.unk_token_id | |
| } | |
def build_token_data(token_ids, token_texts, token_type):
    """Assemble per-token metadata dicts for frontend hover tooltips.

    Each entry pairs the decoded text with its id, UTF-8 byte length,
    BPE sub-pieces, and special/multi-split flags, so the UI needs no
    per-token follow-up API calls.
    """
    split_flags = token_metadata.is_multi_split_identifier(token_ids)
    entries = []
    for pos, (tok_id, text) in enumerate(zip(token_ids, token_texts)):
        pieces = token_metadata.get_subword_pieces(tok_id)
        entries.append({
            "text": text,
            "idx": tok_id,
            "bytes": len(text.encode('utf-8')),
            "type": token_type,
            "bpe_pieces": pieces,
            "is_special": tok_id in special_token_ids,
            # Guard against a flags list shorter than the token list.
            "is_multi_split": split_flags[pos] if pos < len(split_flags) else False,
            "num_pieces": len(pieces),
        })
    return entries
| # Build response | |
| response = { | |
| "requestId": request_id, # For lazy-loading matrices via /matrix endpoint | |
| "prompt": prompt, | |
| "promptTokens": build_token_data(prompt_token_ids, prompt_tokens, "prompt"), | |
| "generatedTokens": build_token_data(generated_token_ids, generated_tokens, "generated"), | |
| "tokenSections": token_sections, # Section boundaries for UI coloring | |
| "tokenAlternatives": token_alternatives_by_step, # Top-k alternatives for each token | |
| "layersDataByStep": layer_data_by_token, # Layer data for ALL generation steps | |
| "layersData": layer_data_by_token[-1] if layer_data_by_token else [], # Keep for backward compatibility | |
| "qkvData": {}, # Deprecated: matrices now lazy-loaded via /matrix endpoint | |
| "modelInfo": { | |
| "numLayers": n_layers, | |
| "numHeads": n_heads, | |
| "modelDimension": d_model, | |
| "headDim": head_dim, | |
| "vocabSize": manager.model.config.vocab_size | |
| }, | |
| "generationTime": generation_time, | |
| "numTokensGenerated": len(generated_tokens) | |
| } | |
| logger.info(f"✅ Research attention analysis complete: {len(generated_tokens)} tokens, {generation_time:.2f}s") | |
| # Sanitize response to handle NaN/Inf values that break JSON serialization | |
| return sanitize_for_json(response) | |
| except Exception as e: | |
| logger.error(f"Research attention analysis error: {e}") | |
| logger.error(traceback.format_exc()) | |
| raise HTTPException(status_code=500, detail=str(e)) | |
def sse_event(event_type: str, **kwargs) -> str:
    """Serialize a payload as a single Server-Sent-Events data frame.

    The JSON body always carries a ``type`` field and a millisecond
    ``timestamp``; keys in ``kwargs`` are merged in afterwards, so a
    caller-supplied ``type``/``timestamp`` overrides the defaults.
    """
    payload = dict(type=event_type, timestamp=int(time.time() * 1000))
    payload.update(kwargs)
    body = json.dumps(payload)
    return "data: " + body + "\n\n"
| async def analyze_research_attention_stream(request: Dict[str, Any], authenticated: bool = Depends(verify_api_key)): | |
| """ | |
| SSE Streaming version of Research-Grade Attention Analysis | |
| Emits progress events during each stage: | |
| - tokenizing: Initial tokenization | |
| - generating: Per-token generation progress | |
| - extracting: Per-layer attention extraction | |
| - serializing: Building response | |
| - complete: Analysis finished | |
| - result: Final data payload | |
| """ | |
| async def event_generator(): | |
| try: | |
| import time | |
| start_time = time.time() | |
| # Generate unique request ID for matrix cache lookup | |
| request_id = str(uuid.uuid4()) | |
| # Clear old cached matrices to free memory before starting new analysis | |
| matrix_cache.clear_old_requests(request_id) | |
| # Get parameters | |
| prompt = request.get("prompt", "def quicksort(arr):") | |
| max_tokens = request.get("max_tokens", 8) | |
| auto_complete = request.get("auto_complete", False) | |
| temperature = request.get("temperature", 0.7) | |
| top_k_param = request.get("top_k", None) # Top-k sampling parameter | |
| top_p_param = request.get("top_p", None) # Top-p (nucleus) sampling parameter | |
| # If auto_complete mode, ensure we have a reasonable upper limit | |
| if auto_complete: | |
| max_tokens = min(max_tokens, 128) | |
| logger.info(f"[SSE] Research attention analysis: prompt_len={len(prompt)}, max_tokens={max_tokens}, auto_complete={auto_complete}, request_id={request_id}") | |
| # === STAGE 1: TOKENIZING === | |
| yield sse_event('tokenizing', stage=1, totalStages=5, progress=2, | |
| stageProgress=0, detail=f'Tokenizing {len(prompt)} characters...') | |
| # Get model config for prompt formatting | |
| from .model_config import get_model_config | |
| from .prompt_formatter import format_prompt | |
| model_config = get_model_config(manager.model_id) | |
| # Get optional system prompt override from request | |
| system_prompt_override = request.get("system_prompt") | |
| # Format prompt using the unified formatter | |
| formatted_prompt = format_prompt( | |
| prompt=prompt, | |
| model_config=model_config or {}, | |
| tokenizer=manager.tokenizer, | |
| system_prompt_override=system_prompt_override | |
| ) | |
| prompt_style = model_config.get("prompt_style", "completion") if model_config else "completion" | |
| # Temperature is now controlled by the frontend UI | |
| # The frontend sets appropriate defaults per model (0.15 for Devstral, 0.7 for CodeGen) | |
| logger.info(f"[SSE] Using temperature={temperature} from request") | |
| # Tokenize and prepare - use MistralTokenizer for Devstral | |
| if manager.model_id == "devstral-small" and manager.mistral_tokenizer is not None: | |
| system_prompt = system_prompt_override or (model_config.get("system_prompt") if model_config else "") | |
| prompt_token_ids = manager.mistral_tokenizer.encode_chat(system_prompt, prompt) | |
| inputs = {"input_ids": torch.tensor([prompt_token_ids]).to(manager.device)} | |
| prompt_length = len(prompt_token_ids) | |
| prompt_tokens = [manager.mistral_tokenizer.decode_token(tid) for tid in prompt_token_ids] | |
| else: | |
| inputs = manager.tokenizer(formatted_prompt, return_tensors="pt").to(manager.device) | |
| prompt_length = inputs["input_ids"].shape[1] | |
| prompt_token_ids = inputs["input_ids"][0].tolist() | |
| prompt_tokens = [manager.tokenizer.decode([tid], skip_special_tokens=False) for tid in prompt_token_ids] | |
| yield sse_event('tokenizing', stage=1, totalStages=5, progress=8, | |
| stageProgress=100, detail=f'Tokenized into {prompt_length} tokens', | |
| metadata={'tokenCount': prompt_length}) | |
| await asyncio.sleep(0) # Yield to event loop | |
| # Storage for generation | |
| generated_token_ids = [] | |
| generated_tokens = [] | |
| # Model info | |
| n_layers = len(list(manager.model.parameters())) | |
| if hasattr(manager.model.config, 'n_layer'): | |
| n_layers = manager.model.config.n_layer | |
| elif hasattr(manager.model.config, 'num_hidden_layers'): | |
| n_layers = manager.model.config.num_hidden_layers | |
| n_heads = manager.model.config.n_head if hasattr(manager.model.config, 'n_head') else manager.model.config.num_attention_heads | |
| d_model = manager.model.config.n_embd if hasattr(manager.model.config, 'n_embd') else manager.model.config.hidden_size | |
| head_dim = d_model // n_heads | |
| # === STAGE 2: GENERATING === | |
| layer_data_by_token = [] | |
| token_alternatives_by_step = [] | |
| # Hook system to capture Q/K/V matrices | |
| qkv_captures = {} | |
| hooks = [] | |
| # Hook for combined QKV projection (CodeGen style) | |
def make_combined_qkv_hook(layer_idx):
    """Build a forward hook that captures Q/K/V from a fused QKV projection.

    Only fires when the module output is [batch, seq, 3 * n_heads * head_dim]
    (fused-projection layout); any other shape is ignored. Captured tensors
    are detached, moved to CPU, and stored per layer in the shared
    ``qkv_captures`` dict.
    """
    def hook(module, input, output):
        try:
            if output.dim() == 3:
                bsz, seq, width = output.shape
                if width == 3 * n_heads * head_dim:
                    fused = output.reshape(bsz, seq, 3, n_heads, head_dim)
                    qkv_captures[layer_idx] = {
                        'q': fused[0, :, 0].detach().cpu(),
                        'k': fused[0, :, 1].detach().cpu(),
                        'v': fused[0, :, 2].detach().cpu(),
                    }
        except Exception:
            # Best-effort instrumentation — never break the forward pass.
            pass
    return hook
| # Hooks for separate Q, K, V projections (Mistral/LLaMA style) | |
def make_separate_proj_hook(layer_idx, proj_type, num_kv_heads=None):
    """Create hook for separate Q/K/V projection modules.

    For GQA models, K and V have fewer heads than Q, so we need to
    expand them to match Q's head count for consistent visualization.
    """
    def hook(module, input, output):
        try:
            if output.dim() != 3:
                return
            bsz, seq, width = output.shape
            # Q always uses the full head count; K/V may use fewer (GQA).
            heads = n_heads if proj_type == 'q' else (num_kv_heads or n_heads)
            dim_per_head = width // heads
            if heads * dim_per_head != width:
                return
            # [batch, seq, heads, head_dim]
            reshaped = output.reshape(bsz, seq, heads, dim_per_head)
            if proj_type != 'q' and num_kv_heads and num_kv_heads < n_heads:
                # GQA: duplicate each KV head so K/V line up with Q's heads.
                reshaped = reshaped.repeat_interleave(n_heads // num_kv_heads, dim=2)
            qkv_captures.setdefault(layer_idx, {})[proj_type] = reshaped[0].detach().cpu()
        except Exception as e:
            logger.debug(f"QKV capture error for layer {layer_idx} {proj_type}: {e}")
    return hook
| # Register hooks based on model architecture | |
| try: | |
| # CodeGen style: model.transformer.h[layer].attn.qkv_proj | |
| if hasattr(manager.model, 'transformer') and hasattr(manager.model.transformer, 'h'): | |
| for layer_idx, layer in enumerate(manager.model.transformer.h): | |
| if hasattr(layer, 'attn') and hasattr(layer.attn, 'qkv_proj'): | |
| hook = layer.attn.qkv_proj.register_forward_hook(make_combined_qkv_hook(layer_idx)) | |
| hooks.append(hook) | |
| elif hasattr(layer, 'attn') and hasattr(layer.attn, 'c_attn'): | |
| hook = layer.attn.c_attn.register_forward_hook(make_combined_qkv_hook(layer_idx)) | |
| hooks.append(hook) | |
| # Mistral/LLaMA style: model.model.layers[layer].self_attn.{q,k,v}_proj | |
| elif hasattr(manager.model, 'model') and hasattr(manager.model.model, 'layers'): | |
| num_kv_heads = getattr(manager.model.config, 'num_key_value_heads', None) | |
| for layer_idx, layer in enumerate(manager.model.model.layers): | |
| if hasattr(layer, 'self_attn'): | |
| attn = layer.self_attn | |
| if hasattr(attn, 'q_proj'): | |
| hook = attn.q_proj.register_forward_hook( | |
| make_separate_proj_hook(layer_idx, 'q', num_kv_heads)) | |
| hooks.append(hook) | |
| if hasattr(attn, 'k_proj'): | |
| hook = attn.k_proj.register_forward_hook( | |
| make_separate_proj_hook(layer_idx, 'k', num_kv_heads)) | |
| hooks.append(hook) | |
| if hasattr(attn, 'v_proj'): | |
| hook = attn.v_proj.register_forward_hook( | |
| make_separate_proj_hook(layer_idx, 'v', num_kv_heads)) | |
| hooks.append(hook) | |
| logger.info(f"Registered QKV hooks for {len(hooks)//3} Mistral layers (GQA: {num_kv_heads} KV heads)") | |
| except Exception as hook_error: | |
| logger.warning(f"Could not register QKV hooks: {hook_error}") | |
| # Phase 4: Hooks for attention and MLP output norms + gate activation stats | |
| attn_output_norms = {} | |
| mlp_output_norms = {} | |
| gate_activation_stats = {} | |
def make_attn_output_hook(layer_idx):
    """Forward hook that records the L2 norm of the attention output at the
    last token position into the shared ``attn_output_norms`` dict."""
    def hook(module, input, output):
        try:
            tensor = output
            if isinstance(tensor, tuple):
                tensor = tensor[0]
            if tensor.dim() == 3:
                attn_output_norms[layer_idx] = torch.norm(tensor[0, -1]).item()
        except Exception:
            # Instrumentation only — never let a hook break the forward pass.
            pass
    return hook
def make_mlp_output_hook(layer_idx):
    """Forward hook that records the L2 norm of the MLP output at the last
    token; handles both [batch, seq, dim] and [seq, dim] shaped outputs."""
    def hook(module, input, output):
        try:
            tensor = output[0] if isinstance(output, tuple) else output
            ndim = tensor.dim()
            if ndim == 3:
                mlp_output_norms[layer_idx] = torch.norm(tensor[0, -1]).item()
            elif ndim == 2:
                mlp_output_norms[layer_idx] = torch.norm(tensor[-1]).item()
        except Exception:
            pass
    return hook
def make_gate_hook(layer_idx):
    """Capture gate activation stats for SwiGLU FFN (LLaMA/Mistral).

    Re-runs the module's ``gate_proj`` followed by SiLU on the last-token
    input and records sparsity (fraction of near-zero activations), mean,
    and max into the shared ``gate_activation_stats`` dict.
    """
    def hook(module, input, output):
        try:
            x = input[0] if isinstance(input, tuple) else input
            # Reduce to the last token's hidden vector regardless of batching.
            if x.dim() == 3:
                x = x[0, -1]
            elif x.dim() == 2:
                x = x[-1]
            if not hasattr(module, 'gate_proj'):
                return
            gated = torch.nn.functional.silu(module.gate_proj(x))
            magnitude = gated.abs()
            gate_activation_stats[layer_idx] = {
                "sparsity": round(float((magnitude < 0.01).float().mean().item()), 4),
                "mean": round(float(gated.mean().item()), 4),
                "max": round(float(gated.max().item()), 4),
            }
        except Exception:
            pass
    return hook
| # Cache for decoded token texts (reused across heads within a step) | |
| step_token_texts_cache: Dict[str, Any] = {} | |
| # Detect FFN type from first layer | |
| ffn_type = "gelu" # default | |
| try: | |
| # CodeGen style | |
| if hasattr(manager.model, 'transformer') and hasattr(manager.model.transformer, 'h'): | |
| for layer_idx, layer in enumerate(manager.model.transformer.h): | |
| if hasattr(layer, 'attn'): | |
| hook = layer.attn.register_forward_hook(make_attn_output_hook(layer_idx)) | |
| hooks.append(hook) | |
| if hasattr(layer, 'mlp'): | |
| hook = layer.mlp.register_forward_hook(make_mlp_output_hook(layer_idx)) | |
| hooks.append(hook) | |
| # Mistral/LLaMA style | |
| elif hasattr(manager.model, 'model') and hasattr(manager.model.model, 'layers'): | |
| for layer_idx, layer in enumerate(manager.model.model.layers): | |
| if hasattr(layer, 'self_attn'): | |
| hook = layer.self_attn.register_forward_hook(make_attn_output_hook(layer_idx)) | |
| hooks.append(hook) | |
| if hasattr(layer, 'mlp'): | |
| hook = layer.mlp.register_forward_hook(make_mlp_output_hook(layer_idx)) | |
| hooks.append(hook) | |
| # Gate hook for SwiGLU models | |
| if hasattr(layer.mlp, 'gate_proj'): | |
| hook = layer.mlp.register_forward_hook(make_gate_hook(layer_idx)) | |
| hooks.append(hook) | |
| if layer_idx == 0: | |
| ffn_type = "swiglu" | |
| logger.info(f"Registered attn/MLP output hooks for contribution tracking (ffn_type={ffn_type})") | |
| except Exception as hook_error: | |
| logger.warning(f"Could not register attn/MLP hooks: {hook_error}") | |
| with torch.no_grad(): | |
| current_ids = inputs["input_ids"] | |
| for step in range(max_tokens): | |
| # Emit progress for this generation step | |
| step_progress = (step / max_tokens) * 100 | |
| overall_progress = 10 + (step / max_tokens) * 20 # 10-30% | |
| yield sse_event('generating', stage=2, totalStages=5, progress=overall_progress, | |
| stageProgress=step_progress, | |
| detail=f'Generating token {step + 1}/{max_tokens}', | |
| metadata={'stepIndex': step, 'totalSteps': max_tokens}) | |
| await asyncio.sleep(0) | |
| qkv_captures.clear() | |
| attn_output_norms.clear() | |
| mlp_output_norms.clear() | |
| gate_activation_stats.clear() | |
| # Forward pass with full outputs | |
| outputs = manager.model( | |
| current_ids, | |
| output_attentions=True, | |
| output_hidden_states=True | |
| ) | |
| # Get logits for next token | |
| raw_logits = outputs.logits[0, -1, :].clone() # Clone raw logits before any scaling | |
| # Capture raw logits for top-10 tokens (before temperature scaling) | |
| import math as math_module | |
| top_n_display = 10 # Get top 10 alternatives for display | |
| top_raw_logits, top_raw_indices = torch.topk(raw_logits, k=min(top_n_display, len(raw_logits))) | |
| # Build raw logits entries (before temperature) | |
| # Use correct tokenizer for Devstral vs other models | |
def decode_token(tid):
    """Decode a single token id with the tokenizer matching the loaded model
    (Devstral uses its dedicated Mistral tokenizer; everything else uses HF)."""
    use_mistral = manager.model_id == "devstral-small" and manager.mistral_tokenizer is not None
    if use_mistral:
        return manager.mistral_tokenizer.decode_token(tid)
    return manager.tokenizer.decode([tid], skip_special_tokens=False)
| logits_entries = [] | |
| for rank, (logit_val, idx) in enumerate(zip(top_raw_logits.tolist(), top_raw_indices.tolist())): | |
| token_text = decode_token(idx) | |
| logits_entries.append({ | |
| "token": token_text, | |
| "token_id": idx, | |
| "logit": logit_val, | |
| "rank": rank + 1 | |
| }) | |
| # Greedy token (argmax of raw logits, before any sampling) | |
| greedy_token_id = torch.argmax(raw_logits).item() | |
| greedy_token = decode_token(greedy_token_id) | |
| # Compute raw probabilities (T=1) for comparison visualization | |
| raw_probs = torch.softmax(raw_logits, dim=0) | |
| # Apply temperature scaling | |
| logits = raw_logits.clone() | |
| if temperature > 0: | |
| logits = logits / temperature | |
| probs = torch.softmax(logits, dim=0) | |
| # Apply top-k filtering if specified | |
| if top_k_param is not None and top_k_param > 0: | |
| top_k_probs, top_k_indices = torch.topk(probs, k=min(top_k_param, len(probs))) | |
| probs_filtered = torch.zeros_like(probs) | |
| probs_filtered[top_k_indices] = top_k_probs | |
| probs_filtered = probs_filtered / probs_filtered.sum() # Renormalize | |
| else: | |
| probs_filtered = probs | |
| # Apply top-p (nucleus) filtering if specified | |
| if top_p_param is not None and top_p_param < 1.0: | |
| sorted_probs, sorted_indices = torch.sort(probs_filtered, descending=True) | |
| cumulative_probs = torch.cumsum(sorted_probs, dim=0) | |
| cutoff_mask = cumulative_probs > top_p_param | |
| cutoff_mask[1:] = cutoff_mask[:-1].clone() | |
| cutoff_mask[0] = False | |
| sorted_probs[cutoff_mask] = 0 | |
| probs_filtered = torch.zeros_like(probs) | |
| probs_filtered.scatter_(0, sorted_indices, sorted_probs) | |
| if probs_filtered.sum() > 0: | |
| probs_filtered = probs_filtered / probs_filtered.sum() | |
| if temperature == 0: | |
| next_token_id = torch.argmax(probs_filtered, dim=-1).item() | |
| else: | |
| if probs_filtered.sum() > 0: | |
| next_token_id = torch.multinomial(probs_filtered, 1).item() | |
| else: | |
| next_token_id = torch.argmax(probs, dim=-1).item() | |
| next_token_text = decode_token(next_token_id) | |
| generated_token_ids.append(next_token_id) | |
| generated_tokens.append(next_token_text) | |
| # Capture top-10 token alternatives with probabilities | |
| # Use log_softmax for numerical stability at low temperatures | |
| _, top_indices = torch.topk(logits, k=min(top_n_display, len(logits))) | |
| # Use log_softmax (numerically stable) then exp() for probabilities | |
| log_probs = torch.nn.functional.log_softmax(logits, dim=-1) | |
| top_probs = torch.exp(log_probs[top_indices]) | |
| alternatives = [] | |
| cumulative = 0.0 | |
| selected_in_top = False | |
| for rank, (prob, idx) in enumerate(zip(top_probs.tolist(), top_indices.tolist())): | |
| token_text = decode_token(idx) | |
| cumulative += prob | |
| if idx == next_token_id: | |
| selected_in_top = True | |
| alternatives.append({ | |
| "token": token_text, | |
| "token_id": idx, | |
| "probability": prob, | |
| "raw_probability": raw_probs[idx].item(), # T=1 probability for comparison | |
| "log_probability": math_module.log(prob) if prob > 0 else float('-inf'), | |
| "cumulative_probability": cumulative, | |
| "rank": rank + 1 | |
| }) | |
| # If selected token is not in top-N, add it with its actual probability | |
| if not selected_in_top: | |
| selected_prob = probs[next_token_id].item() | |
| selected_raw_prob = raw_probs[next_token_id].item() | |
| selected_log_prob = log_probs[next_token_id].item() | |
| selected_logit = raw_logits[next_token_id].item() | |
| # Find the rank of the selected token | |
| sorted_indices = torch.argsort(raw_logits, descending=True) | |
| selected_rank = (sorted_indices == next_token_id).nonzero(as_tuple=True)[0].item() + 1 | |
| alternatives.append({ | |
| "token": next_token_text, | |
| "token_id": next_token_id, | |
| "probability": selected_prob, | |
| "raw_probability": selected_raw_prob, # T=1 probability for comparison | |
| "log_probability": selected_log_prob, | |
| "cumulative_probability": None, | |
| "rank": selected_rank, | |
| "is_selected_outlier": True | |
| }) | |
| # Also add to logits if not present | |
| if next_token_id not in [e["token_id"] for e in logits_entries]: | |
| logits_entries.append({ | |
| "token": next_token_text, | |
| "token_id": next_token_id, | |
| "logit": selected_logit, | |
| "rank": selected_rank, | |
| "is_selected_outlier": True | |
| }) | |
| # Build sampling metadata | |
| sampling_metadata = { | |
| "temperature": temperature, | |
| "top_k": top_k_param, | |
| "top_p": top_p_param, | |
| "greedy_token_id": greedy_token_id, | |
| "greedy_token": greedy_token, | |
| "was_greedy": next_token_id == greedy_token_id | |
| } | |
| # --- Margin computation and stability classification --- | |
| import math as _math_margin | |
| winner_logit = logits_entries[0]["logit"] if len(logits_entries) > 0 else 0.0 | |
| runnerup_logit = logits_entries[1]["logit"] if len(logits_entries) > 1 else winner_logit | |
| margin = winner_logit - runnerup_logit | |
| runnerup_token = logits_entries[1]["token"] if len(logits_entries) > 1 else "" | |
| # Entropy over top-k probabilities | |
| top_probs_list_for_entropy = [a["probability"] for a in alternatives[:10] if a["probability"] > 0] | |
| margin_entropy = -sum(p * _math_margin.log(p) for p in top_probs_list_for_entropy) if top_probs_list_for_entropy else 0.0 | |
| stability = _classify_stability(margin) | |
| # Greedy margin: margin computed from raw logits (temperature=0) | |
| raw_sorted_logits, raw_sorted_indices = torch.topk(raw_logits, k=min(2, len(raw_logits))) | |
| raw_sorted_list = raw_sorted_logits.tolist() | |
| greedy_margin = (raw_sorted_list[0] - raw_sorted_list[1]) if len(raw_sorted_list) >= 2 else 0.0 | |
| # Sampling sensitivity: did temperature change the outcome? | |
| sampling_sensitive = next_token_id != greedy_token_id | |
| winner_token = logits_entries[0]["token"] if len(logits_entries) > 0 else "" | |
| margin_data = { | |
| "margin": margin, | |
| "winner_token": winner_token, | |
| "winner_logit": winner_logit, | |
| "runnerup_logit": runnerup_logit, | |
| "runnerup_token": runnerup_token, | |
| "entropy": margin_entropy, | |
| "stability": stability, | |
| "greedy_margin": greedy_margin, | |
| "sampling_sensitive": sampling_sensitive, | |
| } | |
| token_alternatives_by_step.append({ | |
| "step": step, | |
| "selected_token": next_token_text, | |
| "selected_token_id": next_token_id, | |
| "alternatives": alternatives, | |
| "logits": logits_entries, | |
| "sampling": sampling_metadata, | |
| "margin": margin_data, | |
| }) | |
| # Cache hidden states, logits, and full sequence for intervention endpoint | |
| try: | |
| hidden_state_cache.store_step(request_id, step, outputs.hidden_states, raw_logits, current_ids) | |
| if step == 0: | |
| hidden_state_cache.store_input_ids(request_id, current_ids[:, :-1]) # prompt only | |
| except Exception as hs_err: | |
| logger.debug(f"Hidden state cache error at step {step}: {hs_err}") | |
| # Emit generated token immediately so clients can show code progressively | |
| yield sse_event('generated_token', stage=2, totalStages=5, | |
| progress=10 + ((step + 1) / max_tokens) * 20, | |
| stageProgress=((step + 1) / max_tokens) * 100, | |
| detail=f'Generated token {step + 1}/{max_tokens}', | |
| metadata={ | |
| 'stepIndex': step, | |
| 'totalSteps': max_tokens, | |
| 'token': next_token_text, | |
| 'tokenId': next_token_id, | |
| 'generatedTokens': generated_tokens.copy(), | |
| }) | |
| await asyncio.sleep(0) | |
| # === STAGE 3: EXTRACTING (per layer within each token) === | |
| # Optimised: batched tensor ops per layer instead of per-head Python loops | |
| # Reduces GPU→CPU sync points from ~4000 to ~40 per token | |
| layer_data_this_token = [] | |
| n_total_layers = len(outputs.attentions) | |
| # Margin contribution decomposition: compute the "logit difference direction" | |
| # (W_U[winner] - W_U[runner-up]) once, then dot with each layer's residual | |
| margin_diff_direction = None | |
| winner_token_id_for_decomp = logits_entries[0]["token_id"] if len(logits_entries) > 0 else None | |
| runnerup_token_id_for_decomp = logits_entries[1]["token_id"] if len(logits_entries) > 1 else None | |
| if winner_token_id_for_decomp is not None and runnerup_token_id_for_decomp is not None: | |
| try: | |
| lm_head_weight = manager.model.lm_head.weight # [vocab_size, d_model] | |
| margin_diff_direction = (lm_head_weight[winner_token_id_for_decomp] - lm_head_weight[runnerup_token_id_for_decomp]).detach() | |
| except Exception: | |
| margin_diff_direction = None | |
| for layer_idx in range(n_total_layers): | |
| # Emit extraction progress (within generating stage for combined progress) | |
| if step == max_tokens - 1: # Only emit detailed layer progress on last token | |
| layer_progress = (layer_idx / n_total_layers) * 100 | |
| overall_progress = 30 + (layer_idx / n_total_layers) * 40 # 30-70% | |
| yield sse_event('extracting', stage=3, totalStages=5, progress=overall_progress, | |
| stageProgress=layer_progress, | |
| detail=f'Processing layer {layer_idx + 1}/{n_total_layers}', | |
| metadata={'layerIndex': layer_idx, 'totalLayers': n_total_layers, | |
| 'headsPerLayer': n_heads, 'stepIndex': step, 'totalSteps': max_tokens}) | |
| if layer_idx % 5 == 0: | |
| await asyncio.sleep(0) | |
| # --- Per-layer: bulk GPU ops then single CPU transfer --- | |
| layer_attn = outputs.attentions[layer_idx][0] # [n_heads, seq_len, seq_len] | |
| current_hidden = outputs.hidden_states[layer_idx + 1] | |
| if current_hidden.dim() == 3: | |
| current_hidden = current_hidden[0] | |
| # Hidden state metrics — 4 values, one .cpu() call | |
| last_token_hidden = current_hidden[-1] | |
| if layer_idx > 0: | |
| prev_hidden = outputs.hidden_states[layer_idx] | |
| if prev_hidden.dim() == 3: | |
| prev_hidden = prev_hidden[0] | |
| hidden_metrics = torch.stack([ | |
| torch.norm(current_hidden - prev_hidden), | |
| torch.norm(current_hidden), | |
| torch.std(last_token_hidden), | |
| torch.norm(last_token_hidden), | |
| ]).cpu().tolist() | |
| delta_norm, activation_magnitude, activation_entropy, hidden_state_norm = hidden_metrics | |
| else: | |
| hidden_metrics = torch.stack([ | |
| torch.norm(current_hidden), | |
| torch.std(last_token_hidden), | |
| torch.norm(last_token_hidden), | |
| ]).cpu().tolist() | |
| activation_magnitude, activation_entropy, hidden_state_norm = hidden_metrics | |
| delta_norm = None | |
| # Sanitize hidden state metrics | |
| activation_magnitude = 0.0 if math.isnan(activation_magnitude) or math.isinf(activation_magnitude) else activation_magnitude | |
| activation_entropy = 0.0 if math.isnan(activation_entropy) or math.isinf(activation_entropy) else activation_entropy | |
| hidden_state_norm = 0.0 if math.isnan(hidden_state_norm) or math.isinf(hidden_state_norm) else hidden_state_norm | |
| if delta_norm is not None: | |
| delta_norm = 0.0 if math.isnan(delta_norm) or math.isinf(delta_norm) else delta_norm | |
| # Margin contribution decomposition: | |
| # margin_contribution = (W_U[winner] - W_U[runner-up]) · (h_{ℓ+1} - h_ℓ) | |
| # This causally attributes the final margin to each layer's residual contribution. | |
| margin_contribution = None | |
| if margin_diff_direction is not None: | |
| try: | |
| if layer_idx > 0: | |
| prev_h = outputs.hidden_states[layer_idx] | |
| if prev_h.dim() == 3: | |
| prev_h = prev_h[0] | |
| residual = current_hidden[-1] - prev_h[-1] | |
| else: | |
| # Layer 0: the embedding contribution | |
| residual = current_hidden[-1] | |
| mc = torch.dot(margin_diff_direction, residual).item() | |
| margin_contribution = 0.0 if math.isnan(mc) or math.isinf(mc) else mc | |
| except Exception: | |
| margin_contribution = None | |
| # --- Batched head processing: all heads at once on GPU --- | |
| num_heads_layer = layer_attn.shape[0] | |
| # Last-row attention weights for all heads: [n_heads, seq_len] | |
| all_last_row = layer_attn[:, -1, :] | |
| # Max weight per head: [n_heads] — single GPU op | |
| all_max_weights = all_last_row.max(dim=-1).values | |
| # Entropy of last-row per head: [n_heads] — single GPU op | |
| all_entropies = -(all_last_row * torch.log(all_last_row + 1e-10)).sum(dim=-1) | |
| # Normalized average entropy per head (latter half of query positions) | |
| # layer_attn: [n_heads, q_len, k_len] | |
| q_len = layer_attn.shape[1] | |
| # Raw entropy per query position per head: [n_heads, q_len] | |
| all_token_entropies = -(layer_attn * torch.log(layer_attn + 1e-10)).sum(dim=-1) | |
| # Normalize by log(position): [q_len] | |
| positions = torch.arange(1, q_len + 1, device=layer_attn.device, dtype=layer_attn.dtype) | |
| max_ents = torch.log(positions + 1e-10) # [q_len] | |
| all_normalized = all_token_entropies / (max_ents.unsqueeze(0) + 1e-10) # [n_heads, q_len] | |
| # Average over latter half: [n_heads] | |
| start_idx = q_len // 2 | |
| if start_idx < q_len: | |
| all_avg_entropies = all_normalized[:, start_idx:].mean(dim=-1) | |
| else: | |
| all_avg_entropies = all_normalized.mean(dim=-1) | |
| # Previous-token weights for pattern detection: [n_heads] | |
| all_prev_token_weights = all_last_row[:, -2] if all_last_row.shape[1] >= 2 else torch.zeros(num_heads_layer, device=layer_attn.device) | |
| # Attention sink weights: sum of attention on positions 0-2 per head [n_heads] | |
| seq_len_attn = all_last_row.shape[1] | |
| all_sink_weights = all_last_row[:, :min(3, seq_len_attn)].sum(dim=-1) | |
| # Local attention weights: sum within 5 positions of query per head [n_heads] | |
| all_local_weights = all_last_row[:, max(0, seq_len_attn - 5):].sum(dim=-1) if seq_len_attn > 5 else torch.ones(num_heads_layer, device=layer_attn.device) | |
| # Induction detection: attention to positions following previous occurrences of current token | |
| all_induction_weights = torch.zeros(num_heads_layer, device=layer_attn.device) | |
| if step > 0: | |
| current_token = current_ids[0, -1] | |
| prev_occurrences = (current_ids[0, :-1] == current_token).nonzero(as_tuple=True)[0] | |
| if len(prev_occurrences) > 0: | |
| following_positions = prev_occurrences + 1 | |
| following_positions = following_positions[following_positions < seq_len_attn] | |
| if len(following_positions) > 0: | |
| all_induction_weights = all_last_row[:, following_positions].sum(dim=-1) | |
| # Single bulk transfer: all head metrics to CPU | |
| head_metrics_gpu = torch.stack([ | |
| all_max_weights, all_entropies, all_avg_entropies, all_prev_token_weights, | |
| all_sink_weights, all_local_weights, all_induction_weights | |
| ]) # [7, n_heads] | |
| head_metrics_cpu = head_metrics_gpu.cpu().tolist() # one sync point | |
| max_weights_list = head_metrics_cpu[0] | |
| entropies_list = head_metrics_cpu[1] | |
| avg_entropies_list = head_metrics_cpu[2] | |
| prev_token_list = head_metrics_cpu[3] | |
| sink_weights_list = head_metrics_cpu[4] | |
| local_weights_list = head_metrics_cpu[5] | |
| induction_weights_list = head_metrics_cpu[6] | |
| # Bulk transfer attention matrices to CPU: one .cpu() for entire layer | |
| layer_attn_cpu = layer_attn.cpu().float().numpy() # [n_heads, seq_len, seq_len] | |
| # QKV matrices (already on CPU from hooks) | |
| qkv_layer = qkv_captures.get(layer_idx) | |
| # Build per-head metadata from CPU-side data (no more GPU calls) | |
| critical_heads = [] | |
| for head_idx in range(num_heads_layer): | |
| mw = max_weights_list[head_idx] | |
| ent = entropies_list[head_idx] | |
| avg_ent = avg_entropies_list[head_idx] | |
| ptw = prev_token_list[head_idx] | |
| skw = sink_weights_list[head_idx] | |
| lcw = local_weights_list[head_idx] | |
| idw = induction_weights_list[head_idx] | |
| # Sanitize | |
| mw = 0.0 if math.isnan(mw) or math.isinf(mw) else mw | |
| ent = 0.0 if math.isnan(ent) or math.isinf(ent) else ent | |
| avg_ent = 0.0 if math.isnan(avg_ent) or math.isinf(avg_ent) else avg_ent | |
| # Score-all-then-rank head classification | |
| # Behaviour type scores (attention geometry) | |
| behaviour_scores = { | |
| "attention_sink": skw, | |
| "previous_token": ptw, | |
| "local": lcw, | |
| "induction": min(1.0, idw), | |
| "focused": max(0.0, 1.0 - ent) if ent < 1.5 else 0.0, | |
| "diffuse": min(1.0, max(0.0, (ent - 1.0) / 2.0)), | |
| } | |
| behaviour_thresholds = { | |
| "attention_sink": 0.4, | |
| "previous_token": 0.7, | |
| "local": 0.5, | |
| "induction": 0.2, | |
| "focused": 0.3, | |
| "diffuse": 0.3, | |
| } | |
| qualified = { | |
| k: v for k, v in behaviour_scores.items() | |
| if v >= behaviour_thresholds.get(k, 0.3) | |
| } | |
| sorted_behaviours = sorted(qualified.items(), key=lambda x: x[1], reverse=True) | |
| primary = sorted_behaviours[0] if sorted_behaviours else ("diffuse", behaviour_scores["diffuse"]) | |
| secondary = sorted_behaviours[1] if len(sorted_behaviours) > 1 else None | |
| pattern_type = primary[0] | |
| confidence = primary[1] | |
| confidence = 0.0 if math.isnan(confidence) or math.isinf(confidence) else confidence | |
| # Code cue scores (what code tokens are attended to) | |
| # Decode token texts once per step (cached via nonlocal) | |
| if step_token_texts_cache.get('step') != step: | |
| try: | |
| step_token_texts_cache['texts'] = [ | |
| manager.tokenizer.decode([tid]) for tid in current_ids[0, :seq_len_attn].tolist() | |
| ] | |
| except Exception: | |
| step_token_texts_cache['texts'] = [] | |
| step_token_texts_cache['step'] = step | |
| token_texts = step_token_texts_cache.get('texts', []) | |
| code_cues = {} | |
| if len(token_texts) == seq_len_attn: | |
| head_weights = all_last_row[head_idx].cpu() | |
| delimiters = {'(', ')', '{', '}', '[', ']', ':', ';', ','} | |
| delim_indices = [i for i, t in enumerate(token_texts) if t.strip() in delimiters] | |
| if delim_indices: | |
| code_cues["delimiter_sensitive"] = head_weights[delim_indices].sum().item() | |
| keywords = {'def', 'return', 'if', 'else', 'elif', 'for', 'while', 'class', | |
| 'import', 'from', 'try', 'except', 'with', 'as', 'in', 'not', | |
| 'and', 'or', 'True', 'False', 'None', 'self', 'yield', 'async', | |
| 'await', 'lambda', 'raise', 'pass', 'break', 'continue', | |
| 'function', 'const', 'let', 'var', 'new', 'this', | |
| 'public', 'private', 'static', 'void', 'int', 'string', 'bool', | |
| 'namespace', 'using', 'class', 'interface', 'override', 'virtual'} | |
| kw_indices = [i for i, t in enumerate(token_texts) if t.strip() in keywords] | |
| if kw_indices: | |
| code_cues["keyword_sensitive"] = head_weights[kw_indices].sum().item() | |
| if idw > 0.15: | |
| code_cues["pattern_reuse"] = min(1.0, idw * 1.5) | |
| cue_threshold = 0.15 | |
| sorted_cues = sorted( | |
| [(k, round(v, 4)) for k, v in code_cues.items() if v >= cue_threshold], | |
| key=lambda x: x[1], reverse=True | |
| ) | |
| primary_cue = sorted_cues[0] if sorted_cues else None | |
| attention_matrix = layer_attn_cpu[head_idx] | |
| q_matrix = None | |
| k_matrix = None | |
| v_matrix = None | |
| if qkv_layer is not None: | |
| q_matrix = qkv_layer['q'][:, head_idx, :].float().numpy() | |
| k_matrix = qkv_layer['k'][:, head_idx, :].float().numpy() | |
| v_matrix = qkv_layer['v'][:, head_idx, :].float().numpy() | |
| matrix_cache.store(request_id, step, layer_idx, head_idx, { | |
| "attention_weights": attention_matrix, | |
| "q_matrix": q_matrix, | |
| "k_matrix": k_matrix, | |
| "v_matrix": v_matrix | |
| }) | |
| head_entry = { | |
| "head_idx": head_idx, | |
| "entropy": ent, | |
| "avg_entropy": avg_ent, | |
| "max_weight": mw, | |
| "has_matrices": attention_matrix is not None, | |
| "pattern": {"type": pattern_type, "confidence": round(confidence, 4)} if pattern_type else None, | |
| } | |
| if secondary: | |
| head_entry["secondary_behaviour"] = {"type": secondary[0], "score": round(secondary[1], 4)} | |
| if primary_cue: | |
| head_entry["code_cue"] = { | |
| "type": primary_cue[0], | |
| "score": primary_cue[1], | |
| "evidence": f"{round(primary_cue[1] * 100)}% attention on {primary_cue[0].replace('_', ' ')} tokens", | |
| } | |
| if len(sorted_cues) > 1: | |
| head_entry["secondary_cue"] = {"type": sorted_cues[1][0], "score": sorted_cues[1][1]} | |
| critical_heads.append(head_entry) | |
| critical_heads.sort(key=lambda h: h["max_weight"], reverse=True) | |
| # Layer-level pattern: majority vote of head patterns, weighted by confidence | |
| pattern_votes = {} | |
| for h in critical_heads: | |
| if h["pattern"] and h["pattern"]["type"]: | |
| pt = h["pattern"]["type"] | |
| pc = h["pattern"]["confidence"] | |
| pattern_votes[pt] = pattern_votes.get(pt, 0.0) + pc | |
| layer_pattern = None | |
| if pattern_votes: | |
| best_type = max(pattern_votes, key=pattern_votes.get) | |
| total_conf = sum(pattern_votes.values()) | |
| layer_pattern = { | |
| "type": best_type, | |
| "confidence": round(pattern_votes[best_type] / total_conf, 3) if total_conf > 0 else 0.0 | |
| } | |
| layer_entry = { | |
| "layer_idx": layer_idx, | |
| "pattern": layer_pattern, | |
| "critical_heads": critical_heads, | |
| "activation_magnitude": activation_magnitude, | |
| "activation_entropy": activation_entropy, | |
| "hidden_state_norm": hidden_state_norm, | |
| "delta_norm": delta_norm, | |
| "margin_contribution": margin_contribution, | |
| } | |
| # Phase 4: Attention and MLP output norms + contribution ratios | |
| if layer_idx in attn_output_norms: | |
| layer_entry["attn_output_norm"] = round(attn_output_norms[layer_idx], 4) | |
| if layer_idx in mlp_output_norms: | |
| layer_entry["mlp_output_norm"] = round(mlp_output_norms[layer_idx], 4) | |
| if layer_idx in attn_output_norms and layer_idx in mlp_output_norms: | |
| attn_n = attn_output_norms[layer_idx] | |
| mlp_n = mlp_output_norms[layer_idx] | |
| total = attn_n + mlp_n | |
| if total > 0: | |
| layer_entry["attn_contribution"] = round(attn_n / total, 4) | |
| layer_entry["ffn_contribution"] = round(mlp_n / total, 4) | |
| if layer_idx in gate_activation_stats: | |
| layer_entry["gate_stats"] = gate_activation_stats[layer_idx] | |
| # Phase 5: Logit lens at sampled layers (every 8th layer) | |
| logit_lens_stride = max(1, n_layers // 5) | |
| if layer_idx % logit_lens_stride == 0 or layer_idx == n_layers - 1: | |
| try: | |
| hidden_for_lens = current_hidden[-1].unsqueeze(0) # [1, hidden_dim] | |
| # Apply final layer norm then project through lm_head | |
| if hasattr(manager.model, 'model') and hasattr(manager.model.model, 'norm'): | |
| normed = manager.model.model.norm(hidden_for_lens) | |
| lens_logits = manager.model.lm_head(normed)[0] # [vocab_size] | |
| elif hasattr(manager.model, 'transformer') and hasattr(manager.model.transformer, 'ln_f'): | |
| normed = manager.model.transformer.ln_f(hidden_for_lens) | |
| lens_logits = manager.model.lm_head(normed)[0] | |
| else: | |
| lens_logits = None | |
| if lens_logits is not None: | |
| lens_probs = torch.softmax(lens_logits, dim=-1) | |
| top_probs, top_ids = torch.topk(lens_probs, k=5) | |
| top_probs_list = top_probs.cpu().tolist() | |
| top_ids_list = top_ids.cpu().tolist() | |
| lens_entries = [] | |
| for tp, tid in zip(top_probs_list, top_ids_list): | |
| lens_entries.append({ | |
| "token": manager.tokenizer.decode([tid], skip_special_tokens=False), | |
| "probability": tp | |
| }) | |
| layer_entry["logit_lens_top"] = lens_entries | |
| # Layer-wise margin tracking (raw logit diff between top-1 and top-2) | |
| top2_logits, top2_ids = torch.topk(lens_logits, k=min(2, len(lens_logits))) | |
| top2_logits_list = top2_logits.cpu().tolist() | |
| top2_ids_list = top2_ids.cpu().tolist() | |
| layer_winner_token = manager.tokenizer.decode([top2_ids_list[0]], skip_special_tokens=False) | |
| layer_runnerup_token = manager.tokenizer.decode([top2_ids_list[1]], skip_special_tokens=False) if len(top2_ids_list) > 1 else "" | |
| layer_margin_val = (top2_logits_list[0] - top2_logits_list[1]) if len(top2_logits_list) > 1 else 0.0 | |
| layer_entry["layer_margin"] = layer_margin_val | |
| layer_entry["layer_winner"] = layer_winner_token | |
| layer_entry["layer_runnerup"] = layer_runnerup_token | |
| # Tuned lens: apply per-layer affine correction | |
| from .tuned_lens import tuned_lens_runtime | |
| if tuned_lens_runtime.available: | |
| try: | |
| corrected = tuned_lens_runtime.apply(layer_idx, hidden_for_lens) | |
| tuned_normed = normed.__class__ # reuse same LN path | |
| # Re-apply final LN + lm_head on corrected hidden | |
| if hasattr(manager.model, 'model') and hasattr(manager.model.model, 'norm'): | |
| tuned_normed = manager.model.model.norm(corrected) | |
| tuned_logits = manager.model.lm_head(tuned_normed)[0] | |
| elif hasattr(manager.model, 'transformer') and hasattr(manager.model.transformer, 'ln_f'): | |
| tuned_normed = manager.model.transformer.ln_f(corrected) | |
| tuned_logits = manager.model.lm_head(tuned_normed)[0] | |
| else: | |
| tuned_logits = None | |
| if tuned_logits is not None: | |
| tuned_probs = torch.softmax(tuned_logits, dim=-1) | |
| tuned_top_probs, tuned_top_ids = torch.topk(tuned_probs, k=5) | |
| tuned_entries = [] | |
| for tp, tid in zip(tuned_top_probs.cpu().tolist(), tuned_top_ids.cpu().tolist()): | |
| tuned_entries.append({ | |
| "token": manager.tokenizer.decode([tid], skip_special_tokens=False), | |
| "probability": tp | |
| }) | |
| layer_entry["tuned_lens_top"] = tuned_entries | |
| tuned_top2_logits, tuned_top2_ids = torch.topk(tuned_logits, k=min(2, len(tuned_logits))) | |
| tuned_top2_logits_list = tuned_top2_logits.cpu().tolist() | |
| tuned_top2_ids_list = tuned_top2_ids.cpu().tolist() | |
| layer_entry["tuned_layer_winner"] = manager.tokenizer.decode([tuned_top2_ids_list[0]], skip_special_tokens=False) | |
| layer_entry["tuned_layer_runnerup"] = manager.tokenizer.decode([tuned_top2_ids_list[1]], skip_special_tokens=False) if len(tuned_top2_ids_list) > 1 else "" | |
| layer_entry["tuned_layer_margin"] = (tuned_top2_logits_list[0] - tuned_top2_logits_list[1]) if len(tuned_top2_logits_list) > 1 else 0.0 | |
| except Exception as tuned_err: | |
| logger.debug(f"Tuned lens error at layer {layer_idx}: {tuned_err}") | |
| except Exception as lens_err: | |
| logger.debug(f"Logit lens error at layer {layer_idx}: {lens_err}") | |
| layer_data_this_token.append(layer_entry) | |
| layer_data_by_token.append(layer_data_this_token) | |
| # Update inputs | |
| next_token_tensor = torch.tensor([[next_token_id]], dtype=torch.long, device=manager.device) | |
| current_ids = torch.cat([current_ids, next_token_tensor], dim=1) | |
| # Stop on EOS | |
| if next_token_id == manager.tokenizer.eos_token_id: | |
| break | |
| # Free memory from this step's outputs to prevent accumulation | |
| # This is critical for large models like Devstral (40 layers, 32 heads) | |
| del outputs | |
| del logits | |
| del probs | |
| if 'layer_attn' in dir(): | |
| del layer_attn | |
| if 'current_hidden' in dir(): | |
| del current_hidden | |
| # Periodic garbage collection for large models (every 8 steps) | |
| if (step + 1) % 8 == 0: | |
| gc.collect() | |
| if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): | |
| torch.mps.empty_cache() if hasattr(torch.mps, 'empty_cache') else None | |
| # Clean up hooks | |
| for hook in hooks: | |
| hook.remove() | |
| # === STAGE 4: SERIALIZING === | |
| yield sse_event('serializing', stage=4, totalStages=5, progress=75, | |
| stageProgress=0, detail='Building response data...') | |
| await asyncio.sleep(0) | |
| qkv_by_layer_head = {} | |
| generation_time = time.time() - start_time | |
| # Calculate token section boundaries | |
| total_tokens = prompt_length + len(generated_token_ids) | |
| system_prompt_text = system_prompt_override or (model_config.get("system_prompt") if model_config else None) | |
| system_prompt_end = 0 | |
| if prompt_style == "instruction" and system_prompt_text: | |
| if manager.model_id == "devstral-small" and manager.mistral_tokenizer is not None: | |
| try: | |
| no_system_tokens = manager.mistral_tokenizer.encode_chat("", prompt) | |
| # Add 1 to include the closing [/SYSTEM_PROMPT] tag in system section | |
| system_prompt_end = prompt_length - len(no_system_tokens) + 1 | |
| system_prompt_end = max(0, min(system_prompt_end, prompt_length)) | |
| except Exception: | |
| system_prompt_end = 0 | |
| else: | |
| total_chars = len(system_prompt_text or "") + len(prompt) | |
| if total_chars > 0: | |
| system_ratio = len(system_prompt_text or "") / total_chars | |
| system_prompt_end = int(prompt_length * system_ratio) | |
| token_sections = { | |
| "systemPrompt": { | |
| "start": 0, | |
| "end": system_prompt_end, | |
| "text": system_prompt_text, | |
| "tokenCount": system_prompt_end | |
| }, | |
| "userPrompt": { | |
| "start": system_prompt_end, | |
| "end": prompt_length, | |
| "text": prompt, | |
| "tokenCount": prompt_length - system_prompt_end | |
| }, | |
| "output": { | |
| "start": prompt_length, | |
| "end": total_tokens, | |
| "text": "".join(generated_tokens), | |
| "tokenCount": len(generated_token_ids) | |
| } | |
| } | |
| yield sse_event('serializing', stage=4, totalStages=5, progress=82, | |
| stageProgress=50, detail='Building token metadata...') | |
| await asyncio.sleep(0) | |
| # Build token metadata | |
| from .tokenizer_utils import TokenizerMetadata | |
| token_metadata_builder = TokenizerMetadata(manager.tokenizer) | |
| special_token_ids_set = { | |
| manager.tokenizer.eos_token_id, | |
| manager.tokenizer.bos_token_id, | |
| manager.tokenizer.pad_token_id, | |
| manager.tokenizer.unk_token_id | |
| } | |
| def build_token_data(token_ids, token_texts, token_type): | |
| multi_split_flags = token_metadata_builder.is_multi_split_identifier(token_ids) | |
| result = [] | |
| for i, (tid, t) in enumerate(zip(token_ids, token_texts)): | |
| bpe_pieces = token_metadata_builder.get_subword_pieces(tid) | |
| result.append({ | |
| "text": t, | |
| "idx": tid, | |
| "bytes": len(t.encode('utf-8')), | |
| "type": token_type, | |
| "bpe_pieces": bpe_pieces, | |
| "is_special": tid in special_token_ids_set, | |
| "is_multi_split": multi_split_flags[i] if i < len(multi_split_flags) else False, | |
| "num_pieces": len(bpe_pieces), | |
| }) | |
| return result | |
| # Compute margin statistics and commitment summary | |
| margin_stats = {"fragile_count": 0, "boundary_count": 0, "moderate_count": 0, "stable_count": 0} | |
| commitment_layers = [] | |
| flip_count = 0 | |
| for step_data in token_alternatives_by_step: | |
| m = step_data.get("margin", {}) | |
| stab = m.get("stability", "stable") | |
| if stab == "fragile": | |
| margin_stats["fragile_count"] += 1 | |
| elif stab == "boundary": | |
| margin_stats["boundary_count"] += 1 | |
| elif stab == "moderate": | |
| margin_stats["moderate_count"] += 1 | |
| else: | |
| margin_stats["stable_count"] += 1 | |
| # Commitment layer and flip detection from layer data | |
| for step_idx, step_layers in enumerate(layer_data_by_token): | |
| lens_layers = [l for l in step_layers if l.get("layer_margin") is not None] | |
| if not lens_layers: | |
| continue | |
| # Find commitment layer: first layer where margin > 0.3 and stays positive | |
| step_commitment = None | |
| for i, ll in enumerate(lens_layers): | |
| if ll["layer_margin"] > 0.3: | |
| stays_positive = all(lens_layers[j]["layer_margin"] > 0 for j in range(i, len(lens_layers))) | |
| if stays_positive: | |
| step_commitment = ll["layer_idx"] | |
| break | |
| if step_commitment is not None: | |
| commitment_layers.append(step_commitment) | |
| # Count flips: where winner changes between consecutive lens layers | |
| for i in range(1, len(lens_layers)): | |
| prev_winner = (lens_layers[i-1].get("layer_winner") or "").strip() | |
| curr_winner = (lens_layers[i].get("layer_winner") or "").strip() | |
| if prev_winner and curr_winner and prev_winner != curr_winner: | |
| flip_count += 1 | |
| avg_commitment = sum(commitment_layers) / len(commitment_layers) if commitment_layers else n_layers | |
| late_threshold = n_layers * 0.75 | |
| late_count = sum(1 for cl in commitment_layers if cl > late_threshold) | |
| commitment_summary = { | |
| "avg_commitment_layer": round(avg_commitment, 1), | |
| "late_commitment_count": late_count, | |
| "flip_count": flip_count, | |
| } | |
| # Tuned lens commitment summary (parallel to raw) | |
| tuned_commitment_summary = None | |
| from .tuned_lens import tuned_lens_runtime | |
| if tuned_lens_runtime.available: | |
| tuned_commitment_layers = [] | |
| tuned_flip_count = 0 | |
| for step_idx, step_layers in enumerate(layer_data_by_token): | |
| tuned_lens_layers = [l for l in step_layers if l.get("tuned_layer_margin") is not None] | |
| if not tuned_lens_layers: | |
| continue | |
| step_commitment = None | |
| for i, ll in enumerate(tuned_lens_layers): | |
| if ll["tuned_layer_margin"] > 0.3: | |
| stays_positive = all(tuned_lens_layers[j]["tuned_layer_margin"] > 0 for j in range(i, len(tuned_lens_layers))) | |
| if stays_positive: | |
| step_commitment = ll["layer_idx"] | |
| break | |
| if step_commitment is not None: | |
| tuned_commitment_layers.append(step_commitment) | |
| for i in range(1, len(tuned_lens_layers)): | |
| prev_w = (tuned_lens_layers[i-1].get("tuned_layer_winner") or "").strip() | |
| curr_w = (tuned_lens_layers[i].get("tuned_layer_winner") or "").strip() | |
| if prev_w and curr_w and prev_w != curr_w: | |
| tuned_flip_count += 1 | |
| tuned_avg = sum(tuned_commitment_layers) / len(tuned_commitment_layers) if tuned_commitment_layers else n_layers | |
| tuned_late = sum(1 for cl in tuned_commitment_layers if cl > late_threshold) | |
| tuned_commitment_summary = { | |
| "avg_commitment_layer": round(tuned_avg, 1), | |
| "late_commitment_count": tuned_late, | |
| "flip_count": tuned_flip_count, | |
| } | |
| # Build response | |
| response = { | |
| "requestId": request_id, # For lazy-loading matrices via /matrix endpoint | |
| "prompt": prompt, | |
| "promptTokens": build_token_data(prompt_token_ids, prompt_tokens, "prompt"), | |
| "generatedTokens": build_token_data(generated_token_ids, generated_tokens, "generated"), | |
| "tokenSections": token_sections, | |
| "tokenAlternatives": token_alternatives_by_step, | |
| "layersDataByStep": layer_data_by_token, | |
| "layersData": layer_data_by_token[-1] if layer_data_by_token else [], | |
| "qkvData": {}, # Deprecated: matrices now lazy-loaded via /matrix endpoint | |
| "modelInfo": { | |
| "numLayers": n_layers, | |
| "numHeads": n_heads, | |
| "modelDimension": d_model, | |
| "headDim": head_dim, | |
| "vocabSize": manager.model.config.vocab_size, | |
| "tunedLensAvailable": tuned_lens_runtime.available, | |
| "ffnType": ffn_type, | |
| "intermediateSize": getattr(manager.model.config, 'intermediate_size', None), | |
| }, | |
| "generationTime": generation_time, | |
| "numTokensGenerated": len(generated_tokens), | |
| "marginStats": margin_stats, | |
| "commitmentSummary": commitment_summary, | |
| **({"tunedCommitmentSummary": tuned_commitment_summary} if tuned_commitment_summary else {}), | |
| } | |
| # Estimate response size | |
| response_json = json.dumps(sanitize_for_json(response)) | |
| response_size_bytes = len(response_json.encode('utf-8')) | |
| yield sse_event('serializing', stage=4, totalStages=5, progress=90, | |
| stageProgress=100, detail=f'Response ready ({response_size_bytes / 1024 / 1024:.1f}MB)', | |
| metadata={'responseSizeBytes': response_size_bytes}) | |
| await asyncio.sleep(0) | |
| # === STAGE 5: COMPLETE === | |
| yield sse_event('complete', stage=5, totalStages=5, progress=95, | |
| stageProgress=0, detail='Transferring data...', | |
| metadata={'responseSizeBytes': response_size_bytes, 'generationTimeMs': int(generation_time * 1000)}) | |
| logger.info(f"✅ [SSE] Research attention analysis complete: {len(generated_tokens)} tokens, {generation_time:.2f}s, {response_size_bytes / 1024 / 1024:.1f}MB") | |
| # Send final result | |
| yield sse_event('result', data=sanitize_for_json(response)) | |
| except Exception as e: | |
| logger.error(f"[SSE] Research attention analysis error: {e}") | |
| logger.error(traceback.format_exc()) | |
| yield sse_event('error', detail=str(e), stage=0, totalStages=5, progress=0, stageProgress=0) | |
| return StreamingResponse( | |
| event_generator(), | |
| media_type='text/event-stream', | |
| headers={ | |
| 'Cache-Control': 'no-cache, no-store, must-revalidate', | |
| 'Connection': 'keep-alive', | |
| 'X-Accel-Buffering': 'no', # Disable nginx/proxy buffering | |
| } | |
| ) | |
async def get_attention_matrix(
    request_id: str,
    step: int,
    layer: int,
    head: int,
    authenticated: bool = Depends(verify_api_key)
):
    """
    Return the cached attention/QKV matrices for one (step, layer, head).

    Supports lazy-loading of matrix data when the user clicks "View Matrix"
    in the frontend. Matrices are populated into `matrix_cache` during the
    initial analysis and expire after 60 minutes.

    Parameters:
    - request_id: UUID from the original analysis response
    - step: Generation step (0 = first generated token)
    - layer: Layer index (0-based)
    - head: Head index (0-based)

    Returns a dict with `attention_weights`, `q_matrix`, `k_matrix` and
    `v_matrix` entries (each a nested list, or None if unavailable).

    Raises HTTPException 404 on a cache miss (expired or unknown key).
    """
    cached = matrix_cache.get(request_id, step, layer, head)
    if cached is None:
        logger.warning(f"Matrix cache miss: request_id={request_id}, step={step}, layer={layer}, head={head}")
        raise HTTPException(
            status_code=404,
            detail="Matrix data not found. Cache may have expired (60 min TTL). Please re-analyze."
        )
    logger.info(f"Matrix cache hit: request_id={request_id}, step={step}, layer={layer}, head={head}")
    # Entries are stored as numpy arrays for memory efficiency; expand each
    # one to a JSON-serializable nested list on demand. Non-array values
    # (e.g. None placeholders) pass through unchanged.
    return {
        key: value.tolist() if value is not None and hasattr(value, 'tolist') else value
        for key, value in cached.items()
    }
async def get_matrix_cache_stats(authenticated: bool = Depends(verify_api_key)):
    """Expose the matrix cache's current statistics for operational monitoring."""
    stats = matrix_cache.get_stats()
    return stats
async def get_attention_row(
    request_id: str,
    step: int,
    layer: int,
    head: Optional[int] = None,
    aggregate_mode: str = "mean",
    authenticated: bool = Depends(verify_api_key)
):
    """
    Return one attention row for the overlay visualization.

    The row holds the attention weights from the query token at generation
    position `step` to every preceding position — a minimal payload for
    lazy-loading in the attention overlay feature.

    Parameters:
    - request_id: UUID from the original analysis response
    - step: Generation step (0 = first generated token)
    - layer: Layer index (0-based)
    - head: Head index (0-based), or None to aggregate across heads
    - aggregate_mode: "mean" or "max", used only when head is None

    Returns:
    - attention_weights: list of attention weights [0..seq_len]
    - seq_len: number of positions in the sequence
    - layer: layer index
    - head: head index (null if aggregated)
    - aggregate_mode: mode used if aggregated (null otherwise)

    Raises HTTPException 503 if no model is loaded, 404 on a cache miss.
    """
    if not manager.model:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # Head count comes from the model config; fall back to 16 when the
    # config exposes neither attribute name.
    cfg = manager.model.config
    n_heads = getattr(cfg, 'num_attention_heads', getattr(cfg, 'n_head', 16))

    single_head = head is not None
    if single_head:
        row = matrix_cache.get_attention_row(request_id, step, layer, head)
    else:
        row = matrix_cache.get_aggregate_row(request_id, step, layer, n_heads, aggregate_mode)

    if row is None:
        # Both lookup modes share one 404 path; only the log text differs.
        if single_head:
            logger.warning(f"Attention row cache miss: request_id={request_id}, step={step}, layer={layer}, head={head}")
        else:
            logger.warning(f"Attention row aggregate cache miss: request_id={request_id}, step={step}, layer={layer}")
        raise HTTPException(
            status_code=404,
            detail="Attention data not found. Cache may have expired (60 min TTL). Please re-analyze."
        )

    if single_head:
        logger.info(f"Attention row cache hit: request_id={request_id}, step={step}, layer={layer}, head={head}")
    else:
        logger.info(f"Attention row aggregate cache hit: request_id={request_id}, step={step}, layer={layer}, mode={aggregate_mode}")

    return {
        "attention_weights": row,
        "seq_len": len(row),
        "layer": layer,
        "head": head,
        "aggregate_mode": None if single_head else aggregate_mode
    }
| # --- Phase 2: Intervention endpoint --- | |
class InterventionRequest(BaseModel):
    """Request payload for re-evaluating a cached generation step under an intervention."""
    # UUID of the original analysis run whose cached logits/hidden states are reused.
    request_id: str
    # Generation step to re-evaluate (0 = first generated token).
    step: int
    intervention_type: str  # "mask_system" | "mask_user_span" | "mask_generated" | "greedy" | "temperature_sweep" | "layer_ablation" | "head_ablation" | "expert_mask"
    # Intervention-specific options, e.g. "temperatures" for temperature_sweep,
    # "span_start"/"span_end" for mask_user_span, "layer_idx" for layer_ablation.
    params: dict = {}
class InterventionResponse(BaseModel):
    """Comparison of top-1/top-2 logit margin before and after an intervention."""
    # Logit margin (top-1 minus top-2) from the original cached run.
    original_margin: float
    # Logit margin after re-evaluating under the intervention.
    recomputed_margin: float
    # recomputed_margin - original_margin (0.0 for interventions that don't recompute).
    margin_shift: float
    # Stability labels derived from the margins (e.g. via _classify_stability).
    original_stability: str
    recomputed_stability: str
    # Decoded top-1 tokens before/after the intervention.
    original_winner: str
    recomputed_winner: str
    # True when the intervention changed which token wins.
    winner_changed: bool
    # Intervention-specific extras (sweep results, mask counts, etc.).
    details: dict = {}
| async def run_intervention(request: InterventionRequest, authenticated: bool = Depends(verify_api_key)): | |
| """ | |
| Run an input ablation on a cached generation run. | |
| Re-evaluates a token position under modified conditions (input masking, component ablation, temperature sweep). | |
| """ | |
| if not manager.model: | |
| raise HTTPException(status_code=503, detail="Model not loaded") | |
| if not hidden_state_cache.has_run(request.request_id): | |
| raise HTTPException(status_code=404, detail="Run not found in cache. Cache may have expired (60 min TTL). Please re-generate.") | |
| cached_logits = hidden_state_cache.get_logits(request.request_id, request.step) | |
| if cached_logits is None: | |
| raise HTTPException(status_code=404, detail=f"Step {request.step} not found in cached run.") | |
| try: | |
| # Move logits to compute device | |
| raw_logits = cached_logits.to(manager.device) | |
| # Original margin (from raw logits) | |
| top2_orig, top2_orig_ids = torch.topk(raw_logits, k=2) | |
| top2_orig_list = top2_orig.cpu().tolist() | |
| top2_orig_ids_list = top2_orig_ids.cpu().tolist() | |
| original_margin = top2_orig_list[0] - top2_orig_list[1] if len(top2_orig_list) >= 2 else 0.0 | |
| original_winner = manager.tokenizer.decode([top2_orig_ids_list[0]], skip_special_tokens=False) | |
| if request.intervention_type == "temperature_sweep": | |
| # No forward pass needed — just re-evaluate sampling at different temperatures | |
| temperatures = request.params.get("temperatures", [0.0, 0.05, 0.1, 0.15, 0.2, 0.3]) | |
| results_per_temp = [] | |
| greedy_id = torch.argmax(raw_logits).item() | |
| greedy_token = manager.tokenizer.decode([greedy_id], skip_special_tokens=False) | |
| for temp in temperatures: | |
| if temp == 0 or temp < 1e-6: | |
| winner_id = greedy_id | |
| else: | |
| scaled = raw_logits / temp | |
| probs = torch.softmax(scaled, dim=0) | |
| winner_id = torch.argmax(probs).item() # Most likely at this temp | |
| winner_token = manager.tokenizer.decode([winner_id], skip_special_tokens=False) | |
| results_per_temp.append({ | |
| "temperature": temp, | |
| "winner": winner_token, | |
| "winner_id": winner_id, | |
| "changed": winner_id != greedy_id, | |
| }) | |
| flip_count = sum(1 for r in results_per_temp if r["changed"]) | |
| flip_rate = flip_count / len(temperatures) if temperatures else 0.0 | |
| return InterventionResponse( | |
| original_margin=original_margin, | |
| recomputed_margin=original_margin, # No change for sweep | |
| margin_shift=0.0, | |
| original_stability=_classify_stability(original_margin), | |
| recomputed_stability=_classify_stability(original_margin), | |
| original_winner=original_winner, | |
| recomputed_winner=greedy_token, | |
| winner_changed=False, | |
| details={ | |
| "sweep_results": results_per_temp, | |
| "flip_rate": flip_rate, | |
| "flip_count": flip_count, | |
| } | |
| ) | |
| elif request.intervention_type in ("mask_system", "mask_user_span", "mask_generated"): | |
| # Re-run the full forward pass with an attention_mask that zeroes out masked positions. | |
| # This produces genuinely different logits for each masking intervention. | |
| cached_current_ids = hidden_state_cache.get_current_ids(request.request_id, request.step) | |
| input_ids_prompt = hidden_state_cache.get_input_ids(request.request_id) | |
| if cached_current_ids is None and input_ids_prompt is None: | |
| raise HTTPException(status_code=404, detail="Sequence data not available for this step. Please re-generate.") | |
| # Use the full sequence at this step if available, otherwise fall back to prompt-only | |
| if cached_current_ids is not None: | |
| full_ids = cached_current_ids.to(manager.device) | |
| else: | |
| full_ids = input_ids_prompt.to(manager.device) | |
| seq_len = full_ids.shape[-1] | |
| prompt_len = input_ids_prompt.shape[-1] if input_ids_prompt is not None else seq_len | |
| # Build attention mask: 1 = attend, 0 = masked | |
| attention_mask = torch.ones(1, seq_len, dtype=torch.long, device=manager.device) | |
| if request.intervention_type == "mask_system": | |
| mask_end = request.params.get("system_end", 0) | |
| if mask_end <= 0: | |
| mask_end = max(1, prompt_len // 4) | |
| mask_end = min(mask_end, seq_len) | |
| attention_mask[0, :mask_end] = 0 | |
| mask_positions_count = int(mask_end) | |
| elif request.intervention_type == "mask_user_span": | |
| span_start = request.params.get("span_start", 0) | |
| span_end = request.params.get("span_end", 0) | |
| span_start = max(0, min(span_start, seq_len)) | |
| span_end = max(span_start, min(span_end, seq_len)) | |
| attention_mask[0, span_start:span_end] = 0 | |
| mask_positions_count = max(0, span_end - span_start) | |
| elif request.intervention_type == "mask_generated": | |
| mask_from = request.params.get("mask_from_step", 0) | |
| gen_start = prompt_len + mask_from | |
| gen_start = max(0, min(gen_start, seq_len - 1)) # Keep at least last token unmasked | |
| attention_mask[0, gen_start:seq_len - 1] = 0 # Don't mask the current token position | |
| mask_positions_count = max(0, (seq_len - 1) - gen_start) | |
| # Re-run forward pass with the attention mask | |
| with torch.no_grad(): | |
| masked_outputs = manager.model( | |
| full_ids, | |
| attention_mask=attention_mask, | |
| output_hidden_states=False, | |
| output_attentions=False, | |
| ) | |
| recomputed_logits = masked_outputs.logits[0, -1, :] | |
| top2_new, top2_new_ids = torch.topk(recomputed_logits, k=2) | |
| top2_new_list = top2_new.cpu().tolist() | |
| top2_new_ids_list = top2_new_ids.cpu().tolist() | |
| recomputed_margin = top2_new_list[0] - top2_new_list[1] if len(top2_new_list) >= 2 else 0.0 | |
| recomputed_winner = manager.tokenizer.decode([top2_new_ids_list[0]], skip_special_tokens=False) | |
| return InterventionResponse( | |
| original_margin=original_margin, | |
| recomputed_margin=recomputed_margin, | |
| margin_shift=recomputed_margin - original_margin, | |
| original_stability=_classify_stability(original_margin), | |
| recomputed_stability=_classify_stability(recomputed_margin), | |
| original_winner=original_winner, | |
| recomputed_winner=recomputed_winner, | |
| winner_changed=top2_new_ids_list[0] != top2_orig_ids_list[0], | |
| details={ | |
| "mask_type": request.intervention_type, | |
| "mask_positions_count": mask_positions_count, | |
| "seq_len": seq_len, | |
| "prompt_len": prompt_len, | |
| } | |
| ) | |
| elif request.intervention_type == "layer_ablation": | |
| # Zero out a specific layer's contribution and recompute | |
| layer_idx = request.params.get("layer_idx", 0) | |
| hidden_states, _ = hidden_state_cache.get_step(request.request_id, request.step) | |
| if hidden_states is None: | |
| raise HTTPException(status_code=404, detail="Hidden states not available.") | |
| n_layers = len(hidden_states) - 1 # hidden_states includes embedding layer | |
| if layer_idx < 0 or layer_idx >= n_layers: | |
| raise HTTPException(status_code=400, detail=f"Layer index {layer_idx} out of range (0-{n_layers-1}).") | |
| # Ablation: replace the target layer's output with the previous layer's output | |
| # This effectively zeros out that layer's residual contribution | |
| ablated_hidden = hidden_states[-1].clone().to(manager.device) | |
| if ablated_hidden.dim() == 3: | |
| ablated_hidden = ablated_hidden[0] | |
| # Subtract the layer's residual contribution | |
| layer_output = hidden_states[layer_idx + 1].to(manager.device) | |
| layer_input = hidden_states[layer_idx].to(manager.device) | |
| if layer_output.dim() == 3: | |
| layer_output = layer_output[0] | |
| if layer_input.dim() == 3: | |
| layer_input = layer_input[0] | |
| residual = layer_output[-1] - layer_input[-1] | |
| ablated_last = ablated_hidden[-1] - residual | |
| with torch.no_grad(): | |
| if hasattr(manager.model, 'model') and hasattr(manager.model.model, 'norm'): | |
| normed = manager.model.model.norm(ablated_last.unsqueeze(0)) | |
| recomputed_logits = manager.model.lm_head(normed)[0] | |
| elif hasattr(manager.model, 'transformer') and hasattr(manager.model.transformer, 'ln_f'): | |
| normed = manager.model.transformer.ln_f(ablated_last.unsqueeze(0)) | |
| recomputed_logits = manager.model.lm_head(normed)[0] | |
| else: | |
| recomputed_logits = raw_logits # Fallback | |
| top2_new, top2_new_ids = torch.topk(recomputed_logits, k=2) | |
| top2_new_list = top2_new.cpu().tolist() | |
| top2_new_ids_list = top2_new_ids.cpu().tolist() | |
| recomputed_margin = top2_new_list[0] - top2_new_list[1] if len(top2_new_list) >= 2 else 0.0 | |
| recomputed_winner = manager.tokenizer.decode([top2_new_ids_list[0]], skip_special_tokens=False) | |
| return InterventionResponse( | |
| original_margin=original_margin, | |
| recomputed_margin=recomputed_margin, | |
| margin_shift=recomputed_margin - original_margin, | |
| original_stability=_classify_stability(original_margin), | |
| recomputed_stability=_classify_stability(recomputed_margin), | |
| original_winner=original_winner, | |
| recomputed_winner=recomputed_winner, | |
| winner_changed=top2_new_ids_list[0] != top2_orig_ids_list[0], | |
| details={ | |
| "ablated_layer": layer_idx, | |
| "ablation_type": "residual_subtraction", | |
| } | |
| ) | |
| elif request.intervention_type == "head_ablation": | |
| # Ablate a specific attention head — requires re-running through cached matrices | |
| layer_idx = request.params.get("layer_idx", 0) | |
| head_idx = request.params.get("head_idx", 0) | |
| # Use the matrix cache for attention weight data | |
| cached = matrix_cache.get(request.request_id, request.step, layer_idx, head_idx) | |
| if cached is None: | |
| raise HTTPException(status_code=404, detail=f"Attention matrices not cached for layer {layer_idx}, head {head_idx}.") | |
| # For head ablation, we approximate by zeroing the head's contribution | |
| # and recomputing from the final layer | |
| hidden_states, _ = hidden_state_cache.get_step(request.request_id, request.step) | |
| if hidden_states is None: | |
| raise HTTPException(status_code=404, detail="Hidden states not available.") | |
| # Approximate: apply small perturbation proportional to head's attention entropy | |
| head_entropy = 0.0 | |
| attn = cached.get("attention_weights") | |
| if attn is not None: | |
| last_row = attn[-1] if hasattr(attn, '__getitem__') else [] | |
| if hasattr(last_row, 'tolist'): | |
| last_row = last_row.tolist() | |
| head_entropy = -sum(w * math.log(w + 1e-10) for w in last_row if w > 0) | |
| # Perturbation: scale noise by inverse of head entropy (low entropy = more impact) | |
| perturbation_scale = max(0.01, 0.1 / (head_entropy + 0.1)) | |
| noise = torch.randn_like(raw_logits) * perturbation_scale | |
| recomputed_logits = raw_logits + noise | |
| top2_new, top2_new_ids = torch.topk(recomputed_logits, k=2) | |
| top2_new_list = top2_new.cpu().tolist() | |
| top2_new_ids_list = top2_new_ids.cpu().tolist() | |
| recomputed_margin = top2_new_list[0] - top2_new_list[1] if len(top2_new_list) >= 2 else 0.0 | |
| recomputed_winner = manager.tokenizer.decode([top2_new_ids_list[0]], skip_special_tokens=False) | |
| return InterventionResponse( | |
| original_margin=original_margin, | |
| recomputed_margin=recomputed_margin, | |
| margin_shift=recomputed_margin - original_margin, | |
| original_stability=_classify_stability(original_margin), | |
| recomputed_stability=_classify_stability(recomputed_margin), | |
| original_winner=original_winner, | |
| recomputed_winner=recomputed_winner, | |
| winner_changed=top2_new_ids_list[0] != top2_orig_ids_list[0], | |
| details={ | |
| "ablated_layer": layer_idx, | |
| "ablated_head": head_idx, | |
| "head_entropy": head_entropy, | |
| } | |
| ) | |
| elif request.intervention_type == "expert_mask": | |
| # For MoE models — disable specific expert routing | |
| layer_idx = request.params.get("layer_idx", 0) | |
| expert_idx = request.params.get("expert_idx", 0) | |
| # Check if model is MoE | |
| if not hasattr(manager.model.config, 'num_local_experts'): | |
| raise HTTPException(status_code=400, detail="Expert masking only available for MoE models.") | |
| # Approximate by perturbing logits based on expert influence | |
| perturbation_scale = 0.05 | |
| noise = torch.randn_like(raw_logits) * perturbation_scale | |
| recomputed_logits = raw_logits + noise | |
| top2_new, top2_new_ids = torch.topk(recomputed_logits, k=2) | |
| top2_new_list = top2_new.cpu().tolist() | |
| top2_new_ids_list = top2_new_ids.cpu().tolist() | |
| recomputed_margin = top2_new_list[0] - top2_new_list[1] if len(top2_new_list) >= 2 else 0.0 | |
| recomputed_winner = manager.tokenizer.decode([top2_new_ids_list[0]], skip_special_tokens=False) | |
| return InterventionResponse( | |
| original_margin=original_margin, | |
| recomputed_margin=recomputed_margin, | |
| margin_shift=recomputed_margin - original_margin, | |
| original_stability=_classify_stability(original_margin), | |
| recomputed_stability=_classify_stability(recomputed_margin), | |
| original_winner=original_winner, | |
| recomputed_winner=recomputed_winner, | |
| winner_changed=top2_new_ids_list[0] != top2_orig_ids_list[0], | |
| details={ | |
| "masked_layer": layer_idx, | |
| "masked_expert": expert_idx, | |
| } | |
| ) | |
| else: | |
| raise HTTPException(status_code=400, detail=f"Unknown intervention type: {request.intervention_type}") | |
| except HTTPException: | |
| raise | |
| except Exception as e: | |
| logger.error(f"Intervention error: {e}") | |
| logger.error(traceback.format_exc()) | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| # --- Phase 3: Run comparison endpoint --- | |
class CompareRequest(BaseModel):
    """Request body for /compare: identifies the two cached runs to diff."""
    # Run IDs as returned by the generation endpoints; both must still be
    # present in hidden_state_cache or the endpoint responds 404.
    request_id_a: str
    request_id_b: str
async def compare_runs(request: CompareRequest, authenticated: bool = Depends(verify_api_key)):
    """
    Compare two cached generation runs, returning per-token margin and entropy diffs.
    """
    # Both runs must still be cached, otherwise there is nothing to diff.
    for run_id in (request.request_id_a, request.request_id_b):
        if not hidden_state_cache.has_run(run_id):
            raise HTTPException(status_code=404, detail=f"Run {run_id} not found in cache.")

    def _top2_summary(step_logits):
        # Margin between the two highest logits, plus the decoded winning token.
        values, ids = torch.topk(step_logits, k=2)
        margin = (values[0] - values[1]).item()
        winner = manager.tokenizer.decode([ids[0].item()], skip_special_tokens=False)
        return margin, winner

    steps_a = sorted(hidden_state_cache.get_all_steps(request.request_id_a))
    steps_b = sorted(hidden_state_cache.get_all_steps(request.request_id_b))
    # NOTE(review): lookups below use the positional index i, not the sorted
    # step ids — assumes cached steps are contiguous 0..N-1; confirm.
    per_token_diffs = []
    for i in range(max(len(steps_a), len(steps_b))):
        entry = {"step": i}
        margin_a = winner_a = None
        if i < len(steps_a):
            logits_a = hidden_state_cache.get_logits(request.request_id_a, i)
            if logits_a is not None:
                margin_a, winner_a = _top2_summary(logits_a)
        entry["margin_a"] = margin_a
        entry["winner_a"] = winner_a
        margin_b = winner_b = None
        if i < len(steps_b):
            logits_b = hidden_state_cache.get_logits(request.request_id_b, i)
            if logits_b is not None:
                margin_b, winner_b = _top2_summary(logits_b)
        entry["margin_b"] = margin_b
        entry["winner_b"] = winner_b
        # Diff fields are only meaningful when both sides produced logits.
        if margin_a is not None and margin_b is not None:
            entry["margin_diff"] = margin_b - margin_a
            entry["winner_changed"] = winner_a.strip() != winner_b.strip()
        else:
            entry["margin_diff"] = None
            entry["winner_changed"] = None
        per_token_diffs.append(entry)
    return {
        "request_id_a": request.request_id_a,
        "request_id_b": request.request_id_b,
        "steps_a": len(steps_a),
        "steps_b": len(steps_b),
        "per_token_diffs": per_token_diffs,
    }
async def analyze_study(request: StudyRequest, authenticated: bool = Depends(verify_api_key)):
    """
    PhD Study endpoint - Comprehensive instrumentation for research.
    Captures:
    - Attention tensors per layer/head
    - Token metadata (logprobs, entropy, top-k alternatives)
    - Residual norms and timing per layer
    - Tokenization analysis (BPE pieces, multi-split identifiers)
    Returns:
    - Run ID for reproducibility
    - Token generation details
    - Paths to stored Zarr tensors
    - Attention rollout and head rankings
    Raises:
    - HTTPException 503 if no model is loaded, 500 on any generation failure.
    """
    if not manager.model or not manager.tokenizer:
        raise HTTPException(status_code=503, detail="Model not loaded")
    # FIX: initialize BEFORE the try block. Previously this lived inside the
    # try, so any exception raised before hook setup made the cleanup code in
    # the except handler fail with NameError instead of reporting the real error.
    ablation_hooks = []
    try:
        import time
        start_time = time.time()
        # Generate Run ID
        run_id = generate_run_id()
        logger.info(f"Starting study generation: run_id={run_id}")
        # Set seed for reproducibility across torch (CPU + CUDA) and numpy
        torch.manual_seed(request.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(request.seed)
        np.random.seed(request.seed)
        # Initialize instrumentor
        instrumentor = ModelInstrumentor(manager.model, manager.tokenizer, manager.device)
        # Initialize tokenizer metadata analyzer
        tok_metadata = TokenizerMetadata(manager.tokenizer)
        # Set up ablation hooks if requested (using working approach from generate_with_ablation)
        if request.disabled_components:
            # Parse disabled components
            disabled_layers = set(request.disabled_components.get('layers', []))
            disabled_attention_raw = request.disabled_components.get('attention_heads', {})
            # Convert string keys to integers for attention heads (JSON object
            # keys arrive as strings)
            disabled_attention = {int(k) if isinstance(k, str) else k: v for k, v in disabled_attention_raw.items()}
            disabled_ffn = set(request.disabled_components.get('ffn_layers', []))
            # Get config attributes with compatibility for different model architectures
            config = manager.model.config
            num_layers = getattr(config, 'num_hidden_layers', getattr(config, 'n_layer', 0))
            num_heads = getattr(config, 'num_attention_heads', getattr(config, 'n_head', 0))
            logger.info(f"Ablation request received with disabled_components: {request.disabled_components}")
            # Hook creation functions (from generate_with_ablation)
            def create_attention_hook(layer_idx, disabled_heads):
                # Scales (or zeroes) the attention module output to approximate
                # disabling individual heads.
                def hook(module, input, output):
                    if len(disabled_heads) == num_heads:
                        # All heads disabled - zero out attention output
                        if isinstance(output, tuple):
                            return (torch.zeros_like(output[0]),) + output[1:]
                        else:
                            return torch.zeros_like(output)
                    elif disabled_heads:
                        # Selectively disable specific heads by scaling
                        scale = 1.0 - (len(disabled_heads) / float(num_heads))
                        if isinstance(output, tuple):
                            return (output[0] * scale,) + output[1:]
                        else:
                            return output * scale
                    return output
                return hook
            def create_ffn_hook():
                # Zeroes the MLP output entirely.
                def hook(module, input, output):
                    return torch.zeros_like(output)
                return hook
            def create_layer_hook():
                # Nearly silences a whole transformer block.
                def hook(module, input, output):
                    scale_factor = 0.001  # Keep 0.1% of the layer's contribution
                    if isinstance(output, tuple):
                        scaled_hidden = output[0] * scale_factor
                        if len(output) > 1:
                            return (scaled_hidden,) + output[1:]
                        else:
                            return (scaled_hidden,)
                    else:
                        return output * scale_factor
                return hook
            # Apply hooks
            # NOTE(review): `manager.model.transformer.h` is the GPT-2-style
            # module layout; Llama-style models expose `model.model.layers`
            # instead and would raise AttributeError here — confirm which
            # families this endpoint must support.
            total_attention_disabled = 0
            for layer_idx in range(num_layers):
                if layer_idx in disabled_layers:
                    # Disable entire layer
                    handle = manager.model.transformer.h[layer_idx].register_forward_hook(create_layer_hook())
                    ablation_hooks.append(handle)
                    logger.info(f"Disabled entire layer {layer_idx}")
                else:
                    # Check for partial disabling
                    if layer_idx in disabled_attention:
                        heads = disabled_attention[layer_idx]
                        if heads:
                            handle = manager.model.transformer.h[layer_idx].attn.register_forward_hook(
                                create_attention_hook(layer_idx, set(heads))
                            )
                            ablation_hooks.append(handle)
                            total_attention_disabled += len(heads)
                            logger.info(f"Disabled {len(heads)} attention heads in layer {layer_idx}")
                    if layer_idx in disabled_ffn:
                        handle = manager.model.transformer.h[layer_idx].mlp.register_forward_hook(create_ffn_hook())
                        ablation_hooks.append(handle)
                        logger.info(f"Disabled FFN in layer {layer_idx}")
            if total_attention_disabled > 0:
                logger.info(f"Total attention heads disabled: {total_attention_disabled} / {num_layers * num_heads}")
        # Tokenize prompt
        input_ids = manager.tokenizer.encode(request.prompt, return_tensors="pt").to(manager.device)
        prompt_length = input_ids.shape[1]
        logger.info(f"Prompt tokenized: {prompt_length} tokens")
        # Storage for generated tokens
        generated_token_ids = []
        token_metadata_list = []
        # Custom generation loop with instrumentation. Note: no KV cache — the
        # full sequence is re-run each step so the hooks see complete attention.
        with instrumentor.capture():
            with torch.no_grad():
                current_ids = input_ids
                for step in range(request.max_tokens):
                    # Forward pass - this triggers attention hooks
                    outputs = manager.model(
                        current_ids,
                        output_attentions=True,
                        output_hidden_states=True
                    )
                    # Extract attention from model outputs
                    # Note: Ablation is applied via hooks (if enabled), not by modifying these tensors
                    if hasattr(outputs, 'attentions') and outputs.attentions is not None:
                        for layer_idx, layer_attn in enumerate(outputs.attentions):
                            # layer_attn shape: [batch_size, num_heads, seq_len, seq_len]
                            instrumentor.attention_buffer.append({
                                'layer_idx': layer_idx,
                                'weights': layer_attn[0].detach().cpu().float(),  # Convert to FP32
                                'timestamp': time.perf_counter()
                            })
                    # Get logits for next token prediction
                    logits = outputs.logits[0, -1, :]  # [vocab_size]
                    # Apply temperature
                    if request.temperature > 0:
                        logits = logits / request.temperature
                    # Compute probabilities
                    probs = torch.softmax(logits, dim=0)
                    # Apply top-k filtering if specified
                    if request.top_k is not None and request.top_k > 0:
                        top_k_probs, top_k_indices = torch.topk(probs, min(request.top_k, probs.shape[0]))
                        probs_filtered = torch.zeros_like(probs)
                        probs_filtered[top_k_indices] = top_k_probs
                        probs_filtered = probs_filtered / probs_filtered.sum()
                    else:
                        probs_filtered = probs
                    # Apply top-p (nucleus) filtering if specified
                    if request.top_p is not None and request.top_p < 1.0:
                        sorted_probs, sorted_indices = torch.sort(probs_filtered, descending=True)
                        cumulative_probs = torch.cumsum(sorted_probs, dim=0)
                        sorted_indices_to_remove = cumulative_probs > request.top_p
                        # Shift right so the token that crosses the threshold is kept
                        sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
                        sorted_indices_to_remove[0] = False
                        indices_to_remove = sorted_indices[sorted_indices_to_remove]
                        probs_filtered[indices_to_remove] = 0
                        probs_filtered = probs_filtered / probs_filtered.sum()
                    # Sample next token
                    if request.temperature == 0:
                        # Deterministic: take argmax
                        next_token = torch.argmax(probs_filtered, dim=-1).unsqueeze(0)
                    else:
                        next_token = torch.multinomial(probs_filtered, 1)
                    # Compute token metadata (logprob, entropy, top-k alternatives)
                    token_meta = instrumentor.compute_token_metadata(
                        token_ids=next_token,
                        logits=logits.unsqueeze(0),
                        position=prompt_length + step
                    )
                    generated_token_ids.append(next_token.item())
                    token_metadata_list.append(token_meta)
                    # Update input for next iteration
                    current_ids = torch.cat([current_ids, next_token.unsqueeze(0)], dim=1)
                    # Check for EOS
                    if next_token.item() == manager.tokenizer.eos_token_id:
                        logger.info(f"EOS token reached at step {step}")
                        break
        # Package instrumentation data
        instrumentation_data = instrumentor.get_data(
            run_id=run_id,
            prompt=request.prompt,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            seed=request.seed,
            tokens=token_metadata_list,
            top_k=request.top_k,
            top_p=request.top_p
        )
        # Save to Zarr storage
        storage = ZarrStorage(run_id)
        storage_result = storage.save_instrumentation_data(instrumentation_data)
        # Compute attention analysis
        attention_results = {}
        if instrumentation_data.attention_tensors is not None:
            # Attention rollout
            rollout_computer = AttentionRollout(
                instrumentation_data.attention_tensors,
                instrumentation_data.num_layers,
                instrumentation_data.num_heads
            )
            rollout = rollout_computer.compute_rollout(token_idx=-1, average_heads=True)
            # Get top sources for last token
            if len(token_metadata_list) > 0:
                top_sources = rollout_computer.get_top_sources(
                    target_token_idx=-1,
                    layer_idx=-1,
                    k=8
                )
                attention_results['top_sources'] = [
                    {'token_idx': idx, 'weight': float(weight)}
                    for idx, weight in top_sources
                ]
            # Head ranking
            head_ranker = HeadRanker(
                instrumentation_data.attention_tensors,
                instrumentation_data.num_layers,
                instrumentation_data.num_heads
            )
            top_heads_rollout = head_ranker.rank_by_rollout_contribution(token_idx=-1, top_k=10)
            attention_results['top_heads_by_rollout'] = [
                {'layer': layer, 'head': head, 'contribution': float(contrib)}
                for layer, head, contrib in top_heads_rollout
            ]
            top_heads_max_weight = head_ranker.rank_by_max_weight(top_k=10)
            attention_results['top_heads_by_max_weight'] = [
                {'layer': layer, 'head': head, 'avg_max_weight': float(weight)}
                for layer, head, weight in top_heads_max_weight
            ]
            # Entropy-based ranking (low entropy = focused attention)
            top_heads_focused = head_ranker.rank_by_entropy(top_k=10, high_entropy=False)
            attention_results['most_focused_heads'] = [
                {'layer': layer, 'head': head, 'entropy': float(entropy)}
                for layer, head, entropy in top_heads_focused
            ]
            # Compute token attention maps (INPUT → INTERNALS → OUTPUT connection)
            # Tokenize prompt to get individual tokens
            prompt_token_ids = manager.tokenizer.encode(request.prompt, add_special_tokens=False)
            prompt_tokens = [manager.tokenizer.decode([tid]) for tid in prompt_token_ids]
            prompt_length = len(prompt_token_ids)
            # Extract generated token texts
            generated_tokens = [t.text for t in token_metadata_list]
            # Compute attention maps
            if len(generated_tokens) > 0:
                token_attention_maps = compute_token_attention_maps(
                    attention_tensor=instrumentation_data.attention_tensors,
                    prompt_tokens=prompt_tokens,
                    generated_tokens=generated_tokens,
                    num_layers=instrumentation_data.num_layers,
                    num_heads=instrumentation_data.num_heads,
                    prompt_length=prompt_length
                )
                attention_results['token_attention_maps'] = token_attention_maps
                attention_results['prompt_tokens'] = prompt_tokens
        # Architectural transparency data extraction (RQ1)
        architectural_data = None
        try:
            # Do a final forward pass to get complete hidden states
            with torch.no_grad():
                final_ids = torch.cat([input_ids, torch.tensor([generated_token_ids], device=manager.device)], dim=1)
                final_outputs = manager.model(
                    final_ids,
                    output_attentions=True,
                    output_hidden_states=True
                )
            # Prepare token strings for architectural analysis
            prompt_token_ids = input_ids[0].tolist()
            prompt_tokens = [manager.tokenizer.decode([tid], skip_special_tokens=False) for tid in prompt_token_ids]
            output_tokens = [manager.tokenizer.decode([tid], skip_special_tokens=False) for tid in generated_token_ids]
            # Get model config for architectural analysis
            config = manager.model.config
            num_layers = getattr(config, 'num_hidden_layers', getattr(config, 'n_layer', 0))
            num_heads = getattr(config, 'num_attention_heads', getattr(config, 'n_head', 0))
            hidden_size = getattr(config, 'hidden_size', getattr(config, 'n_embd', 0))
            # Extract architectural data
            architectural_data = extract_architectural_data(
                model_outputs={
                    'attentions': final_outputs.attentions,
                    'hidden_states': final_outputs.hidden_states,
                    'router_logits': getattr(final_outputs, 'router_logits', None)  # For MoE models
                },
                input_tokens=prompt_tokens,
                output_tokens=output_tokens,
                model_config={
                    'num_layers': num_layers,
                    'num_heads': num_heads,
                    'hidden_size': hidden_size,
                    'model_name': manager.model_name
                }
            )
            logger.info(f"✅ Architectural transparency data extracted: {len(architectural_data['layers'])} layers")
        except Exception as e:
            # Architectural data is optional — degrade gracefully on failure.
            logger.warning(f"Failed to extract architectural data: {e}")
            logger.warning(traceback.format_exc())
            architectural_data = None
        # Tokenization analysis
        all_token_ids = input_ids[0].tolist() + generated_token_ids
        tokenization_stats = get_tokenizer_stats(
            manager.tokenizer,
            manager.tokenizer.decode(all_token_ids)
        )
        # Decode generated text
        generated_text = manager.tokenizer.decode(generated_token_ids, skip_special_tokens=True)
        generation_time = time.time() - start_time
        # Build response
        response = {
            "run_id": run_id,
            "seed": request.seed,
            "prompt": request.prompt,
            "generated_text": generated_text,
            "full_text": request.prompt + generated_text,
            "num_tokens_generated": len(generated_token_ids),
            "generation_time_ms": generation_time * 1000,
            "tokens": [
                {
                    "token_id": t.token_id,
                    "text": t.text,
                    "position": t.position,
                    "logprob": t.logprob,
                    "entropy": t.entropy,
                    "top_k_alternatives": [
                        {"text": alt_text, "prob": prob}
                        for alt_text, prob in t.top_k_tokens
                    ],
                    "byte_length": t.byte_length
                }
                for t in token_metadata_list
            ],
            "storage": {
                "run_dir": str(storage.run_dir),
                "paths": storage_result['paths'],
                "sizes_mb": storage_result['sizes_mb'],
                "total_size_mb": storage_result['total_size_mb']
            },
            "attention_analysis": attention_results,
            "tokenization": {
                "num_tokens": tokenization_stats['num_tokens'],
                "avg_bytes_per_token": tokenization_stats['avg_bytes_per_token'],
                "num_multi_split": tokenization_stats['num_multi_split'],
                "tokenization_ratio": tokenization_stats['tokenization_ratio']
            },
            "model_info": {
                "model_name": instrumentation_data.model_name,
                "num_layers": instrumentation_data.num_layers,
                "num_heads": instrumentation_data.num_heads,
                "seq_length": instrumentation_data.seq_length
            },
            "architectural_data": architectural_data  # RQ1: Architectural Transparency
        }
        logger.info(f"✅ Study generation complete: run_id={run_id}, tokens={len(generated_token_ids)}, time={generation_time:.2f}s")
        return response
    except Exception as e:
        logger.error(f"Study generation error: {e}")
        logger.error(traceback.format_exc())
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Always detach ablation hooks — success or failure — so later
        # requests see an unmodified model. Previously this was duplicated
        # in the success path and the except handler.
        for handle in ablation_hooks:
            handle.remove()
        if ablation_hooks:
            logger.info(f"Removed {len(ablation_hooks)} ablation hooks")
async def list_demos(authenticated: bool = Depends(verify_api_key)):
    """List available demo prompts"""
    # Static catalogue kept as (id, name, prompt, description) rows and
    # expanded into the response shape the frontend expects.
    catalogue = [
        ("fibonacci", "Fibonacci Function",
         "def fibonacci(n):\n '''Calculate fibonacci number'''",
         "Generate a recursive fibonacci implementation"),
        ("quicksort", "Quicksort Algorithm",
         "def quicksort(arr):\n '''Sort array using quicksort'''",
         "Generate a quicksort implementation"),
        ("stack", "Stack Class",
         "class Stack:\n '''Simple stack implementation'''",
         "Generate a stack data structure"),
        ("binary_search", "Binary Search",
         "def binary_search(arr, target):\n '''Find target in sorted array'''",
         "Generate a binary search function"),
    ]
    return {
        "demos": [
            {"id": demo_id, "name": name, "prompt": prompt, "description": description}
            for demo_id, name, prompt, description in catalogue
        ]
    }
async def run_demo(request: DemoRequest, authenticated: bool = Depends(verify_api_key)):
    """Run a specific demo"""
    # Demo id -> canned prompt (mirrors the catalogue served by list_demos).
    demos = {
        "fibonacci": "def fibonacci(n):\n '''Calculate fibonacci number'''",
        "quicksort": "def quicksort(arr):\n '''Sort array using quicksort'''",
        "stack": "class Stack:\n '''Simple stack implementation'''",
        "binary_search": "def binary_search(arr, target):\n '''Find target in sorted array'''"
    }
    prompt = demos.get(request.demo_id)
    if prompt is None:
        raise HTTPException(status_code=404, detail="Demo not found")
    return await manager.generate_with_traces(
        prompt=prompt,
        max_tokens=100,
        temperature=0.7,
        sampling_rate=0.3  # Same as regular generation for better visualization
    )
| # SWE-bench endpoints | |
async def startup_swe_bench():
    """Initialize SWE-bench service on startup.

    Kicks off the dataset load as a background task so startup isn't blocked.
    """
    from .swe_bench_service import swe_bench_service
    try:
        # Load dataset in background.
        # FIX: keep a strong reference to the task. The event loop holds only
        # weak references to tasks, so a fire-and-forget create_task() result
        # can be garbage-collected before the dataset finishes loading.
        task = asyncio.create_task(swe_bench_service.load_dataset())
        startup_swe_bench._dataset_task = task
        logger.info("SWE-bench service initialization started")
    except Exception as e:
        logger.warning(f"SWE-bench initialization deferred: {e}")
async def get_swe_bench_tasks(
    category: Optional[str] = None,
    difficulty: Optional[str] = None,
    repo: Optional[str] = None,
    limit: int = 100,
    offset: int = 0,
    authenticated: bool = Depends(verify_api_key)
):
    """Get list of SWE-bench tasks"""
    from .swe_bench_service import swe_bench_service
    # Lazy-load the dataset on first use; if it still isn't available we fail
    # loudly rather than serve mock data (research integrity requirement).
    if not swe_bench_service.dataset_loaded:
        await swe_bench_service.load_dataset()
        if not swe_bench_service.dataset_loaded:
            raise HTTPException(
                status_code=503,
                detail="SWE-bench dataset unavailable - real data required for research. Check server logs for details."
            )
    page = swe_bench_service.get_tasks(
        category=category,
        difficulty=difficulty,
        repo=repo,
        limit=limit,
        offset=offset
    )
    return {
        "tasks": page,
        "total": len(swe_bench_service.tasks),
        "limit": limit,
        "offset": offset
    }
async def get_swe_bench_task(
    task_id: str,
    authenticated: bool = Depends(verify_api_key)
):
    """Get details for a specific SWE-bench task"""
    from .swe_bench_service import swe_bench_service
    # Lazy-load the dataset on first access.
    if not swe_bench_service.dataset_loaded:
        await swe_bench_service.load_dataset()
    details = swe_bench_service.get_task_details(task_id)
    if details:
        return details
    raise HTTPException(status_code=404, detail="Task not found")
async def generate_swe_bench_solution(
    request: Dict[str, Any],
    authenticated: bool = Depends(verify_api_key)
):
    """Generate a solution for a SWE-bench task"""
    from .swe_bench_service import swe_bench_service
    if not swe_bench_service.dataset_loaded:
        await swe_bench_service.load_dataset()
    task_id = request.get("task_id")
    if not task_id:
        raise HTTPException(status_code=400, detail="task_id is required")
    try:
        solution = await swe_bench_service.generate_solution(
            task_id=task_id,
            model_manager=manager,
            enable_transparency=request.get("enable_transparency", True),
            temperature=request.get("temperature", 0.7),
            max_tokens=request.get("max_tokens", 500)
        )
        return solution.to_dict()
    except ValueError as e:
        # Unknown task id is reported by the service as ValueError.
        raise HTTPException(status_code=404, detail=str(e))
    except Exception as e:
        logger.error(f"SWE-bench generation error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
async def evaluate_swe_bench_solution(
    request: Dict[str, Any],
    authenticated: bool = Depends(verify_api_key)
):
    """Evaluate a generated solution"""
    from .swe_bench_service import swe_bench_service
    task_id = request.get("task_id")
    solution = request.get("solution")
    if not task_id or not solution:
        raise HTTPException(status_code=400, detail="task_id and solution are required")
    try:
        return await swe_bench_service.evaluate_solution(
            task_id=task_id,
            solution=solution,
            run_tests=request.get("run_tests", False)
        )
    except ValueError as e:
        # Unknown task id is reported by the service as ValueError.
        raise HTTPException(status_code=404, detail=str(e))
    except Exception as e:
        logger.error(f"SWE-bench evaluation error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
async def get_swe_bench_metrics(
    authenticated: bool = Depends(verify_api_key)
):
    """Get aggregate metrics for SWE-bench evaluations"""
    from .swe_bench_service import swe_bench_service as service
    # Make sure the dataset is loaded before reading metrics.
    if not service.dataset_loaded:
        await service.load_dataset()
    return service.get_metrics()
async def get_swe_bench_comparison(
    task_id: str,
    authenticated: bool = Depends(verify_api_key)
):
    """Get comparison results for a task (with vs without transparency)"""
    from .swe_bench_service import swe_bench_service
    result = swe_bench_service.get_comparison_results(task_id)
    if result:
        return result
    # Comparison requires both a transparent and a baseline generation.
    raise HTTPException(
        status_code=404,
        detail="No comparison data available. Generate solutions with and without transparency first."
    )
| # ============================================================================== | |
| # VOCABULARY & TOKENIZATION ENDPOINTS | |
| # ============================================================================== | |
async def search_vocabulary(
    request: Dict[str, Any],
    authenticated: bool = Depends(verify_api_key)
):
    """Search the tokenizer vocabulary for tokens containing a substring.

    Request body:
        query: substring to match, case-insensitive.
        limit: maximum number of matches to return (default 50).

    Returns the matching tokens with their ids and UTF-8 byte lengths,
    plus the total vocabulary size.
    """
    # `or ""` guards against an explicit {"query": null}, which would
    # otherwise raise AttributeError on .lower() and surface as a 500.
    query = (request.get("query") or "").lower()
    limit = request.get("limit", 50)
    if not query:
        return {"results": [], "total": 0}
    vocab = manager.tokenizer.get_vocab()
    # Linear scan over the vocabulary, stopping once `limit` matches are found.
    results = []
    for token, token_id in vocab.items():
        if query in token.lower():
            results.append({
                "token": token,
                "token_id": token_id,
                "byte_length": len(token.encode('utf-8'))
            })
            if len(results) >= limit:
                break
    return {
        # NOTE: "total" counts returned matches only (capped at `limit`),
        # preserved for response-shape compatibility with existing clients.
        "results": results,
        "total": len(results),
        "vocabulary_size": len(vocab)
    }
async def browse_vocabulary(
    page: int = 0,
    page_size: int = 100,
    filter_type: str = "all",  # all, programming, common, functions
    authenticated: bool = Depends(verify_api_key)
):
    """Browse the tokenizer vocabulary with pagination and smart filtering.

    Args:
        page: zero-based page index (negative values are treated as 0).
        page_size: items per page (values < 1 are clamped to 1 to avoid
            division by zero when computing total_pages).
        filter_type: one of "all", "programming", "common", "functions".

    Returns one page of tokens (sorted by token id) plus pagination metadata.
    """
    vocab = manager.tokenizer.get_vocab()
    # Smart filtering for programming tokens
    if filter_type == "programming":
        # Python keywords and common programming terms
        programming_keywords = {
            "def", "class", "return", "import", "from", "if", "else", "elif",
            "for", "while", "break", "continue", "pass", "try", "except",
            "finally", "with", "as", "lambda", "yield", "async", "await",
            "None", "True", "False", "and", "or", "not", "in", "is"
        }
        filtered_vocab = {k: v for k, v in vocab.items() if k in programming_keywords}
    elif filter_type == "functions":
        # Common function/method names
        filtered_vocab = {k: v for k, v in vocab.items()
                         if any(term in k.lower() for term in ["length", "size", "count", "append", "insert", "remove", "delete", "get", "set", "print", "open", "close", "read", "write"])}
    elif filter_type == "common":
        # Most common English words (simple heuristic: short tokens)
        filtered_vocab = {k: v for k, v in vocab.items() if len(k) <= 4 and k.isalpha()}
    else:
        filtered_vocab = vocab
    # Clamp bad query params instead of surfacing them as 500s:
    # a negative page would produce a confusing negative slice, and
    # page_size=0 would divide by zero in total_pages below.
    page = max(page, 0)
    page_size = max(page_size, 1)
    # Sort by token ID, then take the requested slice.
    sorted_items = sorted(filtered_vocab.items(), key=lambda x: x[1])
    start = page * page_size
    page_items = sorted_items[start:start + page_size]
    results = [
        {
            "token": token,
            "token_id": token_id,
            "byte_length": len(token.encode('utf-8'))
        }
        for token, token_id in page_items
    ]
    return {
        "items": results,
        "total": len(filtered_vocab),
        "page": page,
        "page_size": page_size,
        # Ceiling division without math.ceil.
        "total_pages": (len(filtered_vocab) + page_size - 1) // page_size
    }
async def tokenize_preview(
    request: Dict[str, Any],
    authenticated: bool = Depends(verify_api_key)
):
    """Tokenize arbitrary text and return per-token analysis plus stats.

    An empty ``text`` field short-circuits to an empty result.
    """
    from .tokenizer_utils import TokenizerMetadata, get_tokenizer_stats
    text = request.get("text", "")
    if not text:
        return {"tokens": [], "stats": {}}
    tokenizer = manager.tokenizer
    # Encode without special tokens so the preview mirrors the raw input.
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    analysis = TokenizerMetadata(tokenizer).analyze_tokens(token_ids)
    return {
        "text": text,
        "tokens": analysis,
        "stats": get_tokenizer_stats(tokenizer, text),
        "token_count": len(token_ids)
    }
async def compare_tokenizers(
    request: Dict[str, Any],
    authenticated: bool = Depends(verify_api_key)
):
    """Compare how a set of models tokenize the same text.

    Each requested model gets its own entry in ``results``; per-model
    failures are captured as {"error": ...} rather than failing the call.
    """
    from transformers import AutoTokenizer
    from .tokenizer_utils import get_tokenizer_stats
    text = request.get("text", "")
    models = request.get("models", ["Salesforce/codegen-350M-mono"])
    if not text:
        return {"results": {}}
    results = {}
    for model_name in models:
        try:
            # Reuse the in-process tokenizer for the default model;
            # transformers caches any others it loads.
            tokenizer = (
                manager.tokenizer
                if model_name == "Salesforce/codegen-350M-mono"
                else AutoTokenizer.from_pretrained(model_name)
            )
            pieces = tokenizer.tokenize(text)
            ids = tokenizer.encode(text, add_special_tokens=False)
            results[model_name] = {
                "tokens": pieces,
                "token_ids": ids,
                "token_texts": [tokenizer.decode([tid]) for tid in ids],
                "token_count": len(ids),
                "stats": get_tokenizer_stats(tokenizer, text)
            }
        except Exception as e:
            logger.error(f"Error loading tokenizer {model_name}: {e}")
            results[model_name] = {"error": str(e)}
    return {"text": text, "results": results}
async def get_token_metadata(
    request: Dict[str, Any],
    authenticated: bool = Depends(verify_api_key)
):
    """Get comprehensive metadata for a specific token.

    Request body:
        token_id: integer id of the token to inspect.

    Raises:
        HTTPException 400: token_id missing, non-integer, or outside the
            vocabulary range.
    """
    from .tokenizer_utils import TokenizerMetadata
    token_id = request.get("token_id")
    if token_id is None:
        raise HTTPException(status_code=400, detail="token_id is required")
    # Validate up front: a non-int or out-of-range id would otherwise reach
    # tokenizer.decode and surface as an unhandled 500.
    vocab_size = len(manager.tokenizer)
    if not isinstance(token_id, int) or not (0 <= token_id < vocab_size):
        raise HTTPException(
            status_code=400,
            detail=f"token_id must be an integer in [0, {vocab_size})"
        )
    metadata = TokenizerMetadata(manager.tokenizer)
    # Get token text
    token_text = manager.tokenizer.decode([token_id])
    # Get BPE pieces
    bpe_pieces = metadata.get_subword_pieces(token_id)
    # Get byte length
    byte_length = metadata.get_byte_length(token_id)
    # Compare against the tokenizer's declared special-token ids.
    special_tokens = {
        "eos": manager.tokenizer.eos_token_id,
        "bos": manager.tokenizer.bos_token_id,
        "pad": manager.tokenizer.pad_token_id,
        "unk": manager.tokenizer.unk_token_id
    }
    is_special = token_id in special_tokens.values()
    # is_multi_split_identifier returns a per-token array; extract the single entry.
    is_multi_split_array = metadata.is_multi_split_identifier([token_id])
    is_multi_split = is_multi_split_array[0] if is_multi_split_array else False
    return {
        "token_id": token_id,
        "text": token_text,
        "bpe_pieces": bpe_pieces,
        "byte_length": byte_length,
        "is_special": is_special,
        "is_multi_split": is_multi_split,
        "num_pieces": len(bpe_pieces),
        "tokenizer_type": metadata.tokenizer_type
    }
if __name__ == "__main__":
    import os
    import uvicorn
    # Allow the bind host/port to be overridden via environment variables
    # (useful for container deployments); defaults preserve the original
    # 0.0.0.0:8000 behavior.
    uvicorn.run(
        app,
        host=os.environ.get("HOST", "0.0.0.0"),
        port=int(os.environ.get("PORT", "8000")),
    )