import torch
import torch.nn as nn
from typing import Dict, List, Optional


class SteeringLibrary:
    """
    An in-memory registry of pre-calculated steering vectors (CAA), keyed by name.

    Intended to hold vectors such as exploration, safety, or goal-directedness
    directions; it starts empty and is populated via ``add_vector``. Every stored
    vector's last dimension must equal ``d_model``.
    """

    def __init__(self, d_model: int):
        self.d_model = d_model
        # name -> steering vector (stored by reference, not copied)
        self.vectors: Dict[str, torch.Tensor] = {}

    def add_vector(self, name: str, vector: torch.Tensor) -> None:
        """Register ``vector`` under ``name``, replacing any existing entry.

        Raises:
            ValueError: If ``vector`` is 0-dimensional or its last dimension
                does not equal ``d_model``.
        """
        # A 0-d tensor has no shape[-1]; reject it explicitly instead of
        # letting an IndexError escape.
        if vector.dim() == 0:
            raise ValueError("Steering vector must have at least one dimension.")
        if vector.shape[-1] != self.d_model:
            raise ValueError(f"Vector dimension {vector.shape[-1]} does not match d_model {self.d_model}")
        self.vectors[name] = vector

    def get_vector(self, name: str) -> torch.Tensor:
        """Return the vector registered under ``name``.

        Raises:
            KeyError: If no vector with that name has been added.
        """
        if name not in self.vectors:
            raise KeyError(f"Vector '{name}' not found in library.")
        return self.vectors[name]

    def list_vectors(self) -> List[str]:
        """Return the registered vector names in insertion order."""
        return list(self.vectors.keys())


class RTGSteerer:
    """
    Manages Reward-to-Go (RTG) and activation steering using CAA.

    NOTE(review): assumes ``model`` exposes ``cfg.d_model`` and an
    ``embed_return`` callable (decision-transformer style) — confirm against
    the model class.
    """

    def __init__(self, model, library: Optional[SteeringLibrary] = None):
        self.model = model
        # Explicit None check so a caller-supplied library is always kept,
        # regardless of how that object evaluates for truthiness.
        self.library = library if library is not None else SteeringLibrary(model.cfg.d_model)

    def steer_rtg(
        self,
        base_rtg: torch.Tensor,
        vector_name: Optional[str] = None,
        custom_vector: Optional[torch.Tensor] = None,
        alpha: float = 1.0,
    ) -> torch.Tensor:
        """Embed ``base_rtg`` and add a scaled steering vector to the embedding.

        Args:
            base_rtg: Return-to-go values passed to ``model.embed_return``.
            vector_name: Name of a library vector; used only when
                ``custom_vector`` is None.
            custom_vector: Explicit steering vector; takes precedence over
                ``vector_name``.
            alpha: Scale applied to the steering vector.

        Returns:
            ``embed_return(base_rtg) + alpha * vector``, in the embedding's
            device and dtype.

        Raises:
            KeyError: If ``custom_vector`` is None and ``vector_name`` is not in
                the library (including when ``vector_name`` is None).
        """
        vector = custom_vector if custom_vector is not None else self.library.get_vector(vector_name)
        # Only the embedding lookup is gradient-free; the addition stays
        # outside so a steering vector that requires grad can still be trained.
        with torch.no_grad():
            rtg_emb = self.model.embed_return(base_rtg)
        # Align the vector with the embedding so a CPU/float64 vector can
        # steer a model running on another device or in another dtype.
        vector = vector.to(device=rtg_emb.device, dtype=rtg_emb.dtype)
        return rtg_emb + alpha * vector

    def generate_caa_vector(
        self,
        positive_activations: torch.Tensor,
        negative_activations: torch.Tensor,
        method: str = "mean_diff",
    ) -> torch.Tensor:
        """Build a steering vector via Contrastive Activation Addition.

        With ``method="mean_diff"`` (the only supported method) this returns
        ``positive_activations.mean(0) - negative_activations.mean(0)``; samples
        are taken along dim 0.

        Raises:
            NotImplementedError: For any ``method`` other than ``"mean_diff"``.
            ValueError: If either input has no elements (the mean would be
                NaN), or if the two inputs differ in their non-sample
                dimensions (which would otherwise broadcast silently).
        """
        if method != "mean_diff":
            raise NotImplementedError(f"Method {method} not implemented.")
        if positive_activations.numel() == 0 or negative_activations.numel() == 0:
            raise ValueError(
                "positive_activations and negative_activations must each contain at least one sample."
            )
        if positive_activations.shape[1:] != negative_activations.shape[1:]:
            raise ValueError(
                f"Activation shapes {tuple(positive_activations.shape)} and "
                f"{tuple(negative_activations.shape)} differ beyond the sample dimension."
            )
        pos_mean = positive_activations.mean(dim=0)
        neg_mean = negative_activations.mean(dim=0)
        return pos_mean - neg_mean

    def apply_steering_hook(self, hook_point: str, vector_name: str, alpha: float = 1.0):
        """Return a TransformerLens compatible steering hook.

        The returned ``steering_hook(activations, hook)`` adds
        ``alpha * vector`` to ``activations``. ``hook_point`` is accepted but
        not used here; the caller registers the hook at the desired point.

        Raises:
            KeyError: If ``vector_name`` is not in the library (raised
                immediately, not when the hook fires).
        """
        vector = self.library.get_vector(vector_name)

        def steering_hook(activations, hook):
            # presumably activations: [batch, pos, d_model] — TODO confirm per hook point.
            # Cast on every call: the hook runs on whatever device/dtype the
            # model uses, which may differ from where the vector was stored.
            return activations + alpha * vector.to(device=activations.device, dtype=activations.dtype)

        return steering_hook