Spaces:
Paused
Paused
"""
Tuned Lens Runtime — load and apply per-layer affine probes for improved
intermediate-layer predictions.

Each probe applies a learned linear correction A_l(x) = x @ W_l^T + b_l
(initialised to identity + zero during training) that is trained to minimise
KL divergence between the corrected layer's predictions and the model's
final-layer predictions.

See scripts/train_tuned_lens.py for the training pipeline.
"""
| import json | |
| import logging | |
| import os | |
| from pathlib import Path | |
| from typing import Dict, List, Optional, Tuple | |
| import torch | |
| import torch.nn as nn | |
logger = logging.getLogger(__name__)

# Root directory for per-model tuned-lens checkpoints; overridable via env var.
TUNED_LENS_DIR = os.environ.get("TUNED_LENS_DIR", "./tuned_lens_weights")


class TunedLensRuntime:
    """Load, cache, and apply per-layer affine probes at inference time.

    Probes are stored as a mapping ``layer_idx -> (weight, bias)`` already
    moved to the target device/dtype, and applied as ``hidden @ W.T + b``.
    Loading is best-effort: any failure is logged and reported via a ``False``
    return so callers can fall back to the raw logit lens.
    """

    def __init__(self) -> None:
        # layer index -> (W, b), pre-cast to the serving device/dtype.
        self._probes: Dict[int, Tuple[torch.Tensor, torch.Tensor]] = {}
        self._metadata: Optional[dict] = None
        self._available = False

    def available(self) -> bool:
        """Return True if probes have been loaded successfully."""
        return self._available

    def _reset(self) -> None:
        """Drop all loaded state (probes, metadata, availability flag)."""
        self._probes = {}
        self._metadata = None
        self._available = False

    def load(self, model_id: str, device: torch.device, dtype: torch.dtype,
             weights_dir: Optional[str] = None) -> bool:
        """Load tuned lens checkpoint for *model_id*.

        Args:
            model_id: Subdirectory name under the weights directory.
            device: Device the probe tensors are moved to.
            dtype: Dtype the probe tensors are cast to.
            weights_dir: Optional override for ``TUNED_LENS_DIR``.

        Returns:
            True if weights were loaded successfully, False otherwise.
            Failure is non-fatal — the system falls back to raw logit lens.
        """
        # Bug fix: clear any previously loaded probes up front, so a failed
        # (re)load for a different model can never leave stale probes behind
        # with _available still True.
        self._reset()

        base_dir = Path(weights_dir or TUNED_LENS_DIR)
        model_dir = base_dir / model_id
        if not model_dir.exists():
            logger.info("Tuned lens: no weights directory for %s at %s",
                        model_id, model_dir)
            return False

        # Find the checkpoint — pick the first .pt file in sorted order
        # so the choice is deterministic across runs.
        pt_files = sorted(model_dir.glob("tuned_lens_*.pt"))
        if not pt_files:
            logger.info("Tuned lens: no .pt checkpoint found in %s", model_dir)
            return False
        checkpoint_path = pt_files[0]
        metadata_path = model_dir / "metadata.json"

        try:
            # Metadata is optional; absence is not an error.
            if metadata_path.exists():
                with open(metadata_path, "r") as f:
                    self._metadata = json.load(f)
                logger.info("Tuned lens: metadata loaded — %s layers, d_model=%s",
                            self._metadata.get("n_layers"),
                            self._metadata.get("d_model"))
            else:
                self._metadata = {}

            # weights_only=True keeps torch.load from executing pickled code.
            state_dict = torch.load(checkpoint_path, map_location="cpu",
                                    weights_only=True)

            # Parse layer_N.weight / layer_N.bias entries.
            self._probes = {}
            layer_indices = set()
            for key in state_dict:
                parts = key.split(".")
                if len(parts) == 2 and parts[0].startswith("layer_"):
                    suffix = parts[0].split("_", 1)[1]
                    # Bug fix: guard against non-numeric suffixes (e.g. a
                    # "layer_final" key) — previously int() would raise and
                    # abort the entire load via the broad except below.
                    if suffix.isdigit():
                        layer_indices.add(int(suffix))

            for idx in sorted(layer_indices):
                w_key = f"layer_{idx}.weight"
                b_key = f"layer_{idx}.bias"
                # Only install complete (weight, bias) pairs.
                if w_key in state_dict and b_key in state_dict:
                    self._probes[idx] = (
                        state_dict[w_key].to(device=device, dtype=dtype),
                        state_dict[b_key].to(device=device, dtype=dtype),
                    )

            if not self._probes:
                logger.warning("Tuned lens: checkpoint loaded but no layer probes found")
                return False

            self._available = True
            logger.info("Tuned lens: loaded %d layer probes from %s (device=%s, dtype=%s)",
                        len(self._probes), checkpoint_path, device, dtype)
            return True
        except Exception as e:
            logger.warning("Tuned lens: failed to load checkpoint — %s", e)
            self._reset()
            return False

    def apply(self, layer_idx: int, hidden_state: torch.Tensor) -> torch.Tensor:
        """Apply the affine probe for *layer_idx*: ``hidden @ W.T + b``.

        If no probe exists for this layer, returns the hidden state unchanged
        (identity fallback).
        """
        if layer_idx not in self._probes:
            return hidden_state
        weight, bias = self._probes[layer_idx]
        return hidden_state @ weight.T + bias

    def get_info(self) -> dict:
        """Return a status/metadata dict for health/debug endpoints."""
        return {
            "available": self._available,
            "num_probes": len(self._probes),
            "layer_indices": sorted(self._probes.keys()),
            "metadata": self._metadata or {},
        }


# Global singleton
tuned_lens_runtime = TunedLensRuntime()