sadhumitha-s's picture
feat: implement path-causal microscopy
11dbbc6
import torch
import torch.nn as nn
from typing import Dict, List, Optional
class SteeringLibrary:
    """
    An in-memory registry of pre-calculated steering vectors (CAA), keyed by name.

    The library starts empty; populate it with ``add_vector``. It is intended to
    hold vectors such as exploration, safety, or goal-directedness directions,
    but none are provided by default.
    """

    def __init__(self, d_model: int):
        """
        Args:
            d_model: Required size of the last dimension of every stored vector.
        """
        self.d_model = d_model
        # name -> steering vector; last dimension is always ``d_model``.
        self.vectors: Dict[str, torch.Tensor] = {}

    def add_vector(self, name: str, vector: torch.Tensor):
        """Store ``vector`` under ``name``, replacing any existing entry.

        Only the last dimension is checked, so batched vectors of shape
        ``[..., d_model]`` are accepted as well.

        Raises:
            ValueError: If ``vector`` is 0-dimensional or its last dimension
                is not ``d_model``.
        """
        # A 0-d tensor has no ``shape[-1]``; report it as a ValueError like any
        # other dimension mismatch instead of leaking an IndexError.
        if vector.dim() == 0:
            raise ValueError(
                f"Vector must have at least one dimension of size d_model {self.d_model}; got a 0-d tensor"
            )
        if vector.shape[-1] != self.d_model:
            raise ValueError(f"Vector dimension {vector.shape[-1]} does not match d_model {self.d_model}")
        self.vectors[name] = vector

    def get_vector(self, name: str) -> torch.Tensor:
        """Return the vector stored under ``name``.

        Raises:
            KeyError: If no vector with that name has been added.
        """
        try:
            return self.vectors[name]
        except KeyError:
            raise KeyError(f"Vector '{name}' not found in library.") from None

    def list_vectors(self) -> List[str]:
        """Return the stored vector names in insertion order."""
        return list(self.vectors)
class RTGSteerer:
    """
    Manages Reward-to-Go (RTG) and activation steering using CAA.

    Steering vectors come either from a steering library (looked up by name)
    or are passed in directly. ``model`` must expose ``cfg.d_model`` (read only
    when no library is supplied) and ``embed_return`` (used by ``steer_rtg``).
    """

    def __init__(self, model, library: Optional["SteeringLibrary"] = None):
        """
        Args:
            model: Model exposing ``cfg.d_model`` and ``embed_return``.
            library: Library to draw named vectors from. When ``None``, a new
                empty ``SteeringLibrary`` sized to ``model.cfg.d_model`` is created.
        """
        self.model = model
        # Explicit ``is None`` check: ``library or ...`` would silently replace a
        # caller-supplied library whose truthiness is False (e.g. one defining
        # ``__len__`` that is currently empty).
        self.library = library if library is not None else SteeringLibrary(model.cfg.d_model)

    def steer_rtg(
        self,
        base_rtg: torch.Tensor,
        vector_name: Optional[str] = None,
        custom_vector: Optional[torch.Tensor] = None,
        alpha: float = 1.0,
    ) -> torch.Tensor:
        """Add a scaled steering vector to the RTG embeddings of ``base_rtg``.

        Args:
            base_rtg: Reward-to-go values passed to ``model.embed_return``.
            vector_name: Name of a library vector; ignored when ``custom_vector``
                is given.
            custom_vector: Explicit steering vector; takes precedence over
                ``vector_name``.
            alpha: Scale applied to the steering vector.

        Returns:
            ``model.embed_return(base_rtg) + alpha * vector``. The vector is first
            moved to the embedding's device and dtype.

        Raises:
            ValueError: If neither ``vector_name`` nor ``custom_vector`` is given.
            KeyError: If ``vector_name`` is not in the library.
        """
        if custom_vector is not None:
            vector = custom_vector
        elif vector_name is not None:
            vector = self.library.get_vector(vector_name)
        else:
            raise ValueError("Provide either vector_name or custom_vector to steer_rtg.")
        # The embedding is computed without gradient tracking; only the
        # addition below can carry gradients (from the steering vector).
        with torch.no_grad():
            rtg_emb = self.model.embed_return(base_rtg)
        # Align device/dtype so a CPU/fp32 library vector can steer a
        # GPU/half-precision embedding without errors or silent upcasting.
        vector = vector.to(device=rtg_emb.device, dtype=rtg_emb.dtype)
        return rtg_emb + alpha * vector

    def generate_caa_vector(
        self,
        positive_activations: torch.Tensor,
        negative_activations: torch.Tensor,
        method: str = "mean_diff",
    ) -> torch.Tensor:
        """Build a steering vector via Contrastive Activation Addition.

        With ``method="mean_diff"`` (the only supported method) this returns the
        mean of ``positive_activations`` over dim 0 minus the mean of
        ``negative_activations`` over dim 0.

        Raises:
            NotImplementedError: If ``method`` is not ``"mean_diff"``.
            ValueError: If either activation set is 0-d or has no samples along
                dim 0 (its mean would be NaN).
        """
        if method != "mean_diff":
            raise NotImplementedError(f"Method {method} not implemented.")
        for label, acts in (
            ("positive_activations", positive_activations),
            ("negative_activations", negative_activations),
        ):
            if acts.dim() == 0 or acts.shape[0] == 0:
                raise ValueError(f"{label} must contain at least one sample along dim 0.")
        return positive_activations.mean(dim=0) - negative_activations.mean(dim=0)

    def apply_steering_hook(self, hook_point: str, vector_name: str, alpha: float = 1.0):
        """Return a TransformerLens-compatible hook that adds a steering vector.

        The vector is looked up from the library once, when this method is
        called; later changes to the library do not affect the returned hook.
        ``hook_point`` is not used here. Presumably the caller pairs it with the
        returned hook when registering; confirm against callers.

        Raises:
            KeyError: If ``vector_name`` is not in the library.
        """
        vector = self.library.get_vector(vector_name)

        def steering_hook(activations, hook):
            # activations: expected [batch, pos, d_model]; the vector broadcasts
            # over the leading dimensions.
            # Match the activations' device/dtype so the hook does not change
            # the dtype flowing through the model (e.g. fp16 -> fp32).
            steer = vector.to(device=activations.device, dtype=activations.dtype)
            return activations + alpha * steer

        return steering_hook