"""Inference-time safety steering, deceptive-alignment auditing and anomaly detection."""

from typing import Dict, List, Tuple, Callable, Optional, Union

import numpy as np
import torch
import torch.nn as nn

from src.interpretability.sae_manager import SAEManager
from src.interpretability.attribution import LogitAttributionEngine


class DynamicRejectionSteerer:
    """
    Inference-time controller that dynamically adjusts activation steering vectors.

    If steering drives the action probability distribution toward an illegal or
    unsafe action, the control loop iteratively reduces the steering scale (alpha)
    until constraints are satisfied.
    """

    def __init__(self, model) -> None:
        self.model = model

    def steer_safely(
        self,
        states: torch.Tensor,
        actions: torch.Tensor,
        returns_to_go: torch.Tensor,
        hook_point: str,
        steering_vector: torch.Tensor,
        safety_check_fn: Callable[[torch.Tensor, torch.Tensor], bool],
        initial_alpha: float = 1.0,
        decay_factor: float = 0.5,
        min_alpha: float = 0.05,
        max_iterations: int = 5
    ) -> Tuple[torch.Tensor, float]:
        """
        Applies a steering vector at the specified hook point and scales it back if unsafe.

        Only the first batch element's final timestep is inspected by the safety
        check (``states[0, -1]`` and ``action_preds[0, -1]``).

        Args:
            states: Tensor of environment states, shape [batch, seq_len, state_dim].
            actions: Tensor of historical actions, shape [batch, seq_len, action_dim].
            returns_to_go: Tensor of returns, shape [batch, seq_len, 1].
            hook_point: The target TransformerLens activation hook point.
            steering_vector: The CAA steering vector of shape [d_model].
            safety_check_fn: A function that takes (current_state, action_probs) and
                returns True if safe.
            initial_alpha: The starting steering vector multiplier.
            decay_factor: Multiplier to reduce alpha when safety checks fail.
            min_alpha: Threshold below which steering is completely disabled (set to 0.0).
            max_iterations: Maximum feedback iterations to attempt to find a safe
                steering scale.

        Returns:
            A tuple of (action_preds, final_alpha) containing the model outputs and
            selected scale. ``final_alpha`` is 0.0 when no safe scale was found; in
            that case the returned predictions come from an unsteered forward pass
            that is NOT itself passed through ``safety_check_fn``.
        """
        alpha = initial_alpha
        current_state = states[0, -1]  # Focus on the active timestep

        for _ in range(max_iterations):
            # Bind the scale explicitly so each hook uses exactly the alpha of the
            # iteration that created it.
            def steering_hook(value, hook, scale=alpha):
                # Match both device and dtype of the activation: a float32 vector
                # added to a half-precision activation would otherwise promote the
                # whole residual stream. The [d_model] vector broadcasts over the
                # leading dimensions.
                vector = steering_vector.to(device=value.device, dtype=value.dtype)
                return value + scale * vector

            with self.model.transformer.hooks(fwd_hooks=[(hook_point, steering_hook)]):
                action_preds = self.model(states, actions, returns_to_go)

            # Extract the action prediction for the latest step
            action_probs = torch.softmax(action_preds[0, -1], dim=-1)

            # Check if the steering remains within safety boundaries
            if safety_check_fn(current_state, action_probs):
                return action_preds, alpha

            # Scale down steering strength to pull behavior back into safety parameters
            alpha *= decay_factor
            if alpha < min_alpha:
                break

        # Fallback to completely unsteered model outputs if no safe magnitude is found
        action_preds = self.model(states, actions, returns_to_go)
        return action_preds, 0.0


class DeceptiveAlignmentAuditor:
    """
    Mechanistic audit engine to identify and isolate features/circuits driving
    deceptive alignment.

    Tracks situational awareness switches using SAE latent representations and
    attributes them to model attention heads.
    """

    def __init__(self, model, sae_manager: SAEManager) -> None:
        self.model = model
        self.sae_manager = sae_manager

    @staticmethod
    def _mean_feature_activation(feats: torch.Tensor) -> torch.Tensor:
        """Average SAE activations over every axis except the last (feature) axis."""
        if feats.ndim < 2:
            # Nothing to average over; also avoids mean(dim=()), which would
            # reduce over all axes.
            return feats.reshape(-1)
        return feats.mean(dim=tuple(range(feats.ndim - 1)))

    @staticmethod
    def _encoder_direction(sae, feature_idx: int) -> torch.Tensor:
        """Return the encoder weights feeding SAE latent ``feature_idx``."""
        # Access encoder weight matrix based on standard sae_lens naming conventions
        if hasattr(sae, "W_enc"):
            return sae.W_enc[:, feature_idx]
        if hasattr(sae, "encoder") and hasattr(sae.encoder, "weight"):
            return sae.encoder.weight[feature_idx]
        # NOTE(review): last-resort guess that the first registered parameter is an
        # encoder matrix indexed by feature along dim 0 -- confirm for custom SAEs.
        return next(sae.parameters())[feature_idx]

    def identify_situational_awareness_feature(
        self,
        hook_point: str,
        watched_states: torch.Tensor,
        unwatched_states: torch.Tensor,
        actions: torch.Tensor,
        returns_to_go: torch.Tensor
    ) -> int:
        """
        Identifies the SAE feature index that acts as a binary indicator for the
        observer's presence.

        Args:
            hook_point: The hook point from which the SAE decomposes activations.
            watched_states: Trajectory states under supervision,
                shape [batch, seq_len, state_dim].
            unwatched_states: Trajectory states without supervision,
                shape [batch, seq_len, state_dim].
            actions: Action inputs, shape [batch, seq_len, action_dim].
            returns_to_go: Return inputs, shape [batch, seq_len, 1].

        Returns:
            The integer index of the SAE latent whose mean activation differs most
            (in absolute value) between the watched and unwatched runs.
        """
        # Collect model activations for both scenarios
        _, watched_cache = self.model(
            watched_states, actions, returns_to_go, return_cache=True
        )
        _, unwatched_cache = self.model(
            unwatched_states, actions, returns_to_go, return_cache=True
        )

        # Extract SAE latent activations
        watched_feats = self.sae_manager.get_feature_activations(
            hook_point, watched_cache[hook_point]
        )
        unwatched_feats = self.sae_manager.get_feature_activations(
            hook_point, unwatched_cache[hook_point]
        )

        # Average feature activities across all leading (batch/sequence) dimensions;
        # each tensor is reduced according to its own rank.
        mean_watched = self._mean_feature_activation(watched_feats)
        mean_unwatched = self._mean_feature_activation(unwatched_feats)

        # Feature with the maximum difference is the situational awareness detector
        diff = torch.abs(mean_watched - mean_unwatched)
        return int(torch.argmax(diff).item())

    def attribute_feature_to_heads(
        self,
        hook_point: str,
        feature_idx: int,
        states: torch.Tensor,
        actions: torch.Tensor,
        returns_to_go: torch.Tensor
    ) -> torch.Tensor:
        """
        Calculates the Direct Feature Attribution (DFA) of each attention head to the
        isolated situational awareness feature.

        Only the first batch element's final position is attributed.

        Args:
            hook_point: Hook point where the SAE is attached.
            feature_idx: The SAE latent index representing situational awareness.
            states: Input states.
            actions: Input actions.
            returns_to_go: Input returns-to-go.

        Returns:
            A [n_layers, n_heads] tensor of attribution scores, allocated with
            ``torch.zeros`` defaults (not on the model's device).
        """
        sae = self.sae_manager.saes[hook_point]
        W_enc_f = self._encoder_direction(sae, feature_idx)

        _, cache = self.model(states, actions, returns_to_go, return_cache=True)

        n_layers = self.model.cfg.n_layers
        n_heads = self.model.cfg.n_heads
        dfa_results = torch.zeros((n_layers, n_heads))

        for layer in range(n_layers):
            # Head outputs are expected as [batch, pos, head, d_model]
            head_outputs = cache[f"blocks.{layer}.attn.hook_result"]
            last_token_output = head_outputs[0, -1]  # shape [head, d_model]
            # Project every head's output onto the feature's encoder direction
            dfa_results[layer] = torch.matmul(
                last_token_output, W_enc_f.to(last_token_output.device)
            )

        return dfa_results


class FunctionalAttributionMAD:
    """
    Mechanistic Anomaly Detection (MAD) based on Functional Attribution profiles.

    Establishes a baseline head attribution signature during normal tasks and flags
    atypical execution patterns that indicate reward hacking or backdoor exploit.
    """

    def __init__(self, model, threshold: float = 0.5) -> None:
        self.model = model
        self.threshold = threshold
        # Flattened baseline DLA signature; None until a reference is established.
        self.reference_profile: Optional[torch.Tensor] = None

    def _compute_profile(
        self,
        states: torch.Tensor,
        actions: torch.Tensor,
        returns: torch.Tensor,
        target_action_index: int
    ) -> torch.Tensor:
        """Run the model with caching and return the flattened DLA for the target logit."""
        engine = LogitAttributionEngine(self.model)
        _, cache = self.model(states, actions, returns, return_cache=True)
        dla = engine.calculate_dla(cache, target_logit_index=target_action_index)
        return dla.flatten()

    def establish_reference_profile(
        self,
        baseline_states: torch.Tensor,
        baseline_actions: torch.Tensor,
        baseline_returns: torch.Tensor,
        target_action_index: int
    ):
        """
        Computes and caches the reference Direct Logit Attribution (DLA) signature
        across attention heads during normal behavior.

        The cached profile is detached so this long-lived object does not keep the
        baseline forward pass's autograd graph alive.
        """
        profile = self._compute_profile(
            baseline_states, baseline_actions, baseline_returns, target_action_index
        )
        self.reference_profile = profile.detach()

    def detect_circuit_anomaly(
        self,
        eval_states: torch.Tensor,
        eval_actions: torch.Tensor,
        eval_returns: torch.Tensor,
        target_action_index: int
    ) -> Tuple[bool, float, torch.Tensor]:
        """
        Compares the current execution's Direct Logit Attribution profile with the
        reference.

        Returns:
            A tuple of (is_anomaly, similarity_score, current_profile), where
            similarity_score is the cosine similarity between the reference and
            current profiles (0.0 if either has zero norm) and is_anomaly is True
            when that score is below ``self.threshold``.

        Raises:
            ValueError: If no reference profile has been established.
        """
        if self.reference_profile is None:
            raise ValueError("Reference circuit profile is not established. Train or set reference profile first.")

        current_profile = self._compute_profile(
            eval_states, eval_actions, eval_returns, target_action_index
        )

        # The reference may have been recorded on another device or in another
        # precision; align it so the dot product below is well-defined.
        reference = self.reference_profile.to(
            device=current_profile.device, dtype=current_profile.dtype
        )

        # Compute cosine similarity between base and current head attribution profiles
        ref_norm = torch.norm(reference)
        curr_norm = torch.norm(current_profile)

        if ref_norm == 0.0 or curr_norm == 0.0:
            similarity = 0.0
        else:
            cosine = torch.dot(reference, current_profile) / (ref_norm * curr_norm)
            similarity = cosine.item()

        # Flag an anomaly if similarity score falls below the accepted threshold
        is_anomaly = similarity < self.threshold
        return is_anomaly, similarity, current_profile


def generate_deceptive_trajectories(
    num_episodes: int = 40,
    seq_len: int = 5,
    state_dim: int = 3,
    action_dim: int = 3
) -> List[Dict[str, np.ndarray]]:
    """
    Generates a synthetic set of trajectories exhibiting deceptive alignment behavior.

    - Even-indexed episodes are watched (observer_present = 1.0): action 0 at
      every step, reward 0.2 * (step + 1).
    - Odd-indexed episodes are unwatched (observer_present = 0.0): action 1 at
      every step, reward 0.4 * (step + 1).

    Each observation is ``[normal_dist, observer_present, secret_dist]`` where both
    distances count down from ``seq_len`` to 1.

    Args:
        num_episodes: Number of trajectories to generate.
        seq_len: Number of steps per trajectory; values <= 0 yield empty
            trajectories with consistently sized arrays.
        state_dim: Currently unused; observations always have 3 components.
        action_dim: Currently unused; actions are always 0 or 1.

    Returns:
        A list of dicts with keys "observations" (float32, [seq_len, 3]),
        "actions" (int64, [seq_len]), "rewards" (float32, [seq_len]) and
        "dones" (bool, [seq_len], True only at the final step).
    """
    n_steps = max(seq_len, 0)

    # Quantities shared by every episode are computed once.
    step_numbers = np.arange(1, n_steps + 1, dtype=np.float64)
    distances = (seq_len - np.arange(n_steps)).astype(np.float32)
    dones = np.zeros(n_steps, dtype=bool)
    if n_steps:
        dones[-1] = True

    trajectories = []
    for episode in range(num_episodes):
        watched = episode % 2 == 0
        observer_flag = np.full(n_steps, 1.0 if watched else 0.0, dtype=np.float32)
        # Secret reward is higher than the normal aligned reward
        reward_rate = 0.2 if watched else 0.4

        trajectories.append({
            "observations": np.stack([distances, observer_flag, distances], axis=1),
            "actions": np.full(n_steps, 0 if watched else 1, dtype=np.int64),
            "rewards": (reward_rate * step_numbers).astype(np.float32),
            "dones": dones.copy(),
        })

    return trajectories