# api/backend/tuned_lens.py
# gary-boon — Add tuned lens as supplementary projection mode for logit lens (6f48db0)
"""
Tuned Lens Runtime β€” load and apply per-layer affine probes for improved
intermediate-layer predictions.
Each probe applies a learned linear correction A_l(x) = x @ W_l^T + b_l
(initialised to identity + zero during training) that is trained to minimise
KL divergence between the corrected layer's predictions and the model's
final-layer predictions.
See scripts/train_tuned_lens.py for the training pipeline.
"""
import json
import logging
import os
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
logger = logging.getLogger(__name__)
TUNED_LENS_DIR = os.environ.get("TUNED_LENS_DIR", "./tuned_lens_weights")
class TunedLensRuntime:
    """Load, cache, and apply per-layer affine probes at inference time.

    Probes are stored as ``layer index -> (weight, bias)`` tensors already
    moved to the target device/dtype, so ``apply`` is a single matmul + add.
    """

    def __init__(self):
        # layer index -> (weight, bias), populated by load()
        self._probes: Dict[int, Tuple[torch.Tensor, torch.Tensor]] = {}
        self._metadata: Optional[dict] = None
        self._available = False

    @property
    def available(self) -> bool:
        """True once a checkpoint has been loaded successfully."""
        return self._available

    def load(self, model_id: str, device: torch.device, dtype: torch.dtype,
             weights_dir: Optional[str] = None) -> bool:
        """Load tuned lens checkpoint for *model_id*.

        Args:
            model_id: name of the per-model subdirectory under the weights dir.
            device: device the probe tensors are moved to.
            dtype: dtype the probe tensors are cast to.
            weights_dir: optional override for TUNED_LENS_DIR.

        Returns True if weights were loaded successfully, False otherwise.
        Failure is non-fatal — the system falls back to raw logit lens.
        """
        # FIX: reset state up front so a failed re-load cannot leave the
        # runtime "available" with stale probes from a previous load().
        self._probes = {}
        self._metadata = None
        self._available = False

        base_dir = Path(weights_dir or TUNED_LENS_DIR)
        model_dir = base_dir / model_id
        if not model_dir.exists():
            logger.info("Tuned lens: no weights directory for %s at %s",
                        model_id, model_dir)
            return False

        # Find the checkpoint — pick the first .pt file (sorted = deterministic)
        pt_files = sorted(model_dir.glob("tuned_lens_*.pt"))
        if not pt_files:
            logger.info("Tuned lens: no .pt checkpoint found in %s", model_dir)
            return False
        checkpoint_path = pt_files[0]

        try:
            self._metadata = self._read_metadata(model_dir / "metadata.json")
            # weights_only=True restricts unpickling to tensor payloads.
            state_dict = torch.load(checkpoint_path, map_location="cpu",
                                    weights_only=True)
            self._probes = self._parse_probes(state_dict, device, dtype)
            if not self._probes:
                # FIX: was an f-string with no placeholders; include the path.
                logger.warning("Tuned lens: checkpoint %s loaded but no layer "
                               "probes found", checkpoint_path)
                return False
            self._available = True
            logger.info("Tuned lens: loaded %d layer probes from %s "
                        "(device=%s, dtype=%s)",
                        len(self._probes), checkpoint_path, device, dtype)
            return True
        except Exception as e:
            # Broad by design: any load failure degrades to raw logit lens.
            logger.warning("Tuned lens: failed to load checkpoint — %s", e)
            self._probes = {}
            self._metadata = None
            self._available = False
            return False

    @staticmethod
    def _read_metadata(metadata_path: Path) -> dict:
        """Read the sidecar metadata.json; returns {} when absent."""
        if not metadata_path.exists():
            return {}
        with open(metadata_path, "r", encoding="utf-8") as f:
            metadata = json.load(f)
        logger.info("Tuned lens: metadata loaded — %s layers, d_model=%s",
                    metadata.get("n_layers"), metadata.get("d_model"))
        return metadata

    @staticmethod
    def _parse_probes(state_dict: dict, device: torch.device,
                      dtype: torch.dtype) -> Dict[int, Tuple[torch.Tensor, torch.Tensor]]:
        """Extract layer_N.weight / layer_N.bias pairs, moved to device/dtype."""
        layer_indices = set()
        for key in state_dict:
            parts = key.split(".")
            if len(parts) == 2 and parts[0].startswith("layer_"):
                suffix = parts[0][len("layer_"):]
                # FIX: skip malformed keys instead of letting int() raise and
                # abort the whole load via the caller's except.
                if suffix.isdigit():
                    layer_indices.add(int(suffix))

        probes: Dict[int, Tuple[torch.Tensor, torch.Tensor]] = {}
        for idx in sorted(layer_indices):
            w_key = f"layer_{idx}.weight"
            b_key = f"layer_{idx}.bias"
            if w_key in state_dict and b_key in state_dict:
                probes[idx] = (
                    state_dict[w_key].to(device=device, dtype=dtype),
                    state_dict[b_key].to(device=device, dtype=dtype),
                )
            else:
                # FIX: previously dropped silently.
                logger.warning("Tuned lens: layer %d is missing weight or bias "
                               "— probe skipped", idx)
        return probes

    def apply(self, layer_idx: int, hidden_state: torch.Tensor) -> torch.Tensor:
        """Apply the affine probe for *layer_idx*: hidden @ W^T + b.

        If no probe exists for this layer, returns the hidden state unchanged
        (identity fallback).
        """
        probe = self._probes.get(layer_idx)
        if probe is None:
            return hidden_state
        weight, bias = probe
        return hidden_state @ weight.T + bias

    def get_info(self) -> dict:
        """Return metadata dict for health/debug endpoints."""
        return {
            "available": self._available,
            "num_probes": len(self._probes),
            "layer_indices": sorted(self._probes.keys()),
            "metadata": self._metadata or {},
        }
# Global singleton — stays unavailable (identity fallback in apply()) until
# load() succeeds; presumably initialised once at startup by the API layer.
tuned_lens_runtime = TunedLensRuntime()