Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| from typing import Dict, List, Tuple, Callable, Optional, Union | |
| from src.interpretability.sae_manager import SAEManager | |
| from src.interpretability.attribution import LogitAttributionEngine | |
class DynamicRejectionSteerer:
    """
    Inference-time controller that dynamically adjusts activation steering vectors.

    If steering drives the action probability distribution toward an illegal or unsafe action,
    the control loop iteratively reduces the steering scale (alpha) until constraints are satisfied.
    """

    def __init__(self, model):
        """
        Args:
            model: Object called as ``model(states, actions, returns_to_go)`` that also exposes
                ``model.transformer.hooks(fwd_hooks=[...])`` as a context manager.
        """
        self.model = model

    def steer_safely(
        self,
        states: torch.Tensor,
        actions: torch.Tensor,
        returns_to_go: torch.Tensor,
        hook_point: str,
        steering_vector: torch.Tensor,
        safety_check_fn: Callable[[torch.Tensor, torch.Tensor], bool],
        initial_alpha: float = 1.0,
        decay_factor: float = 0.5,
        min_alpha: float = 0.05,
        max_iterations: int = 5
    ) -> Tuple[torch.Tensor, float]:
        """
        Applies a steering vector at the specified hook point and scales it back if unsafe.

        Only the first batch element's latest timestep (``states[0, -1]`` and
        ``action_preds[0, -1]``) is passed to ``safety_check_fn``; other batch
        elements are steered with the same alpha but are not checked.

        Args:
            states: Tensor of environment states, shape [batch, seq_len, state_dim].
            actions: Tensor of historical actions, shape [batch, seq_len, action_dim].
            returns_to_go: Tensor of returns, shape [batch, seq_len, 1].
            hook_point: The target TransformerLens activation hook point.
            steering_vector: The CAA steering vector of shape [d_model].
            safety_check_fn: A function that takes (current_state, action_probs) and returns True if safe.
            initial_alpha: The starting steering vector multiplier. May be negative
                (steering away from the vector's direction).
            decay_factor: Multiplier to reduce alpha when safety checks fail.
            min_alpha: Magnitude threshold; once ``abs(alpha)`` decays below it, steering is
                completely disabled (alpha reported as 0.0).
            max_iterations: Maximum feedback iterations to attempt to find a safe steering scale.

        Returns:
            A tuple of (action_preds, final_alpha) containing the model outputs and selected scale.
            When no safe scale is found the outputs come from an unsteered forward pass and
            final_alpha is 0.0; that fallback output is NOT run through ``safety_check_fn``.
        """
        alpha = initial_alpha
        current_state = states[0, -1]  # Focus on the active timestep

        for _ in range(max_iterations):
            def steering_hook(value, hook):
                # Match both device and dtype of the activation so that adding the
                # vector never promotes the residual stream to a wider dtype.
                shift = steering_vector.to(device=value.device, dtype=value.dtype)
                # The [d_model] vector broadcasts over the leading (batch/position) dims.
                return value + alpha * shift

            with self.model.transformer.hooks(fwd_hooks=[(hook_point, steering_hook)]):
                action_preds = self.model(states, actions, returns_to_go)

            # Extract the action prediction for the latest step
            last_logits = action_preds[0, -1]
            action_probs = torch.softmax(last_logits, dim=-1)

            # Check if the steering remains within safety boundaries
            if safety_check_fn(current_state, action_probs):
                return action_preds, alpha

            # Scale down steering strength to pull behavior back into safety parameters
            alpha *= decay_factor
            # Compare the magnitude: a signed comparison would abort every
            # negative alpha after a single attempt.
            if abs(alpha) < min_alpha:
                break

        # Fallback to completely unsteered model outputs if no safe magnitude is found
        action_preds = self.model(states, actions, returns_to_go)
        return action_preds, 0.0
class DeceptiveAlignmentAuditor:
    """
    Mechanistic audit engine to identify and isolate features/circuits driving deceptive alignment.

    Tracks situational awareness switches using SAE latent representations and attributes
    them to model attention heads.
    """

    def __init__(self, model, sae_manager: SAEManager):
        """
        Args:
            model: Called as ``model(states, actions, returns_to_go, return_cache=True)``
                and expected to return ``(outputs, cache)``; also read for ``cfg.n_layers``
                and ``cfg.n_heads``.
            sae_manager: Provides ``get_feature_activations(hook_point, acts)`` and a
                ``saes`` mapping keyed by hook point.
        """
        self.model = model
        self.sae_manager = sae_manager

    @staticmethod
    def _mean_over_leading_dims(feats: torch.Tensor) -> torch.Tensor:
        """Average a feature tensor over every dimension except the last one."""
        if feats.ndim <= 1:
            # A single activation vector is already one value per feature;
            # reducing it would collapse the feature axis to a scalar.
            return feats
        return feats.mean(dim=tuple(range(feats.ndim - 1)))

    def identify_situational_awareness_feature(
        self,
        hook_point: str,
        watched_states: torch.Tensor,
        unwatched_states: torch.Tensor,
        actions: torch.Tensor,
        returns_to_go: torch.Tensor
    ) -> int:
        """
        Identifies the SAE feature index that acts as a binary indicator for the observer's presence.

        The selected feature is the one whose mean activation differs most (in absolute
        value) between the watched and unwatched runs.

        Args:
            hook_point: The hook point from which the SAE decomposes activations.
            watched_states: Trajectory states under supervision, shape [batch, seq_len, state_dim].
            unwatched_states: Trajectory states without supervision, shape [batch, seq_len, state_dim].
            actions: Action inputs, shape [batch, seq_len, action_dim].
            returns_to_go: Return inputs, shape [batch, seq_len, 1].

        Returns:
            The integer index of the SAE latent representing situational awareness.
        """
        # Collect model activations for both scenarios
        _, watched_cache = self.model(watched_states, actions, returns_to_go, return_cache=True)
        watched_acts = watched_cache[hook_point]
        _, unwatched_cache = self.model(unwatched_states, actions, returns_to_go, return_cache=True)
        unwatched_acts = unwatched_cache[hook_point]

        # Extract SAE latent activations
        watched_feats = self.sae_manager.get_feature_activations(hook_point, watched_acts)
        unwatched_feats = self.sae_manager.get_feature_activations(hook_point, unwatched_acts)

        # Average feature activities across all leading (batch / sequence) dimensions,
        # whatever rank the SAE manager returns.
        mean_watched = self._mean_over_leading_dims(watched_feats)
        mean_unwatched = self._mean_over_leading_dims(unwatched_feats)

        # Feature with the maximum difference is the situational awareness detector
        diff = torch.abs(mean_watched - mean_unwatched)
        switch_feature_idx = torch.argmax(diff).item()
        return int(switch_feature_idx)

    def attribute_feature_to_heads(
        self,
        hook_point: str,
        feature_idx: int,
        states: torch.Tensor,
        actions: torch.Tensor,
        returns_to_go: torch.Tensor
    ) -> torch.Tensor:
        """
        Calculates the Direct Feature Attribution (DFA) of each attention head to the
        isolated situational awareness feature.

        For every layer, the per-head outputs of the first batch element at the last
        position are projected onto the feature's encoder direction.

        Args:
            hook_point: Hook point where the SAE is attached.
            feature_idx: The SAE latent index representing situational awareness.
            states: Input states.
            actions: Input actions.
            returns_to_go: Input returns-to-go.

        Returns:
            A [n_layers, n_heads] tensor of attribution scores.

        Raises:
            KeyError: If ``hook_point`` has no SAE registered, or the cache lacks a
                ``blocks.{layer}.attn.hook_result`` entry for some layer.
        """
        sae = self.sae_manager.saes[hook_point]

        # Access encoder weight matrix based on standard sae_lens naming conventions
        if hasattr(sae, "W_enc"):
            W_enc_f = sae.W_enc[:, feature_idx]
        elif hasattr(sae, "encoder") and hasattr(sae.encoder, "weight"):
            W_enc_f = sae.encoder.weight[feature_idx]
        else:
            # NOTE(review): best-effort guess that the first parameter is the encoder
            # matrix with features on dim 0 - confirm for the SAE class in use.
            W_enc_f = next(sae.parameters())[feature_idx]

        _, cache = self.model(states, actions, returns_to_go, return_cache=True)
        n_layers = self.model.cfg.n_layers
        n_heads = self.model.cfg.n_heads
        dfa_results = torch.zeros((n_layers, n_heads))

        for layer in range(n_layers):
            # Head outputs shape: [batch, seq_len, head_index, d_model] or [batch, pos, head, d_model]
            head_outputs = cache[f"blocks.{layer}.attn.hook_result"]
            last_token_output = head_outputs[0, -1]  # shape [head, d_model]
            # Align device AND dtype: matmul rejects mixed dtypes (e.g. fp16
            # activations against fp32 SAE weights).
            direction = W_enc_f.to(
                device=last_token_output.device, dtype=last_token_output.dtype
            )
            dfa_results[layer] = torch.matmul(last_token_output, direction)

        return dfa_results
class FunctionalAttributionMAD:
    """
    Mechanistic Anomaly Detection (MAD) based on Functional Attribution profiles.

    Establishes a baseline head attribution signature during normal tasks and flags
    atypical execution patterns that indicate reward hacking or backdoor exploit.
    """

    def __init__(self, model, threshold: float = 0.5):
        """
        Args:
            model: Called as ``model(states, actions, returns, return_cache=True)`` and
                expected to return ``(outputs, cache)``.
            threshold: Minimum cosine similarity to the reference profile for a run to
                count as normal.
        """
        self.model = model
        self.threshold = threshold
        # Flattened reference attribution profile; None until established.
        self.reference_profile: Optional[torch.Tensor] = None

    def _attribution_profile(
        self,
        states: torch.Tensor,
        actions: torch.Tensor,
        returns_to_go: torch.Tensor,
        target_action_index: int
    ) -> torch.Tensor:
        """Run the model with caching and return the flattened DLA for the target logit."""
        engine = LogitAttributionEngine(self.model)
        _, cache = self.model(states, actions, returns_to_go, return_cache=True)
        dla = engine.calculate_dla(cache, target_logit_index=target_action_index)
        return dla.flatten()

    def establish_reference_profile(
        self,
        baseline_states: torch.Tensor,
        baseline_actions: torch.Tensor,
        baseline_returns: torch.Tensor,
        target_action_index: int
    ):
        """
        Computes and caches the reference Direct Logit Attribution (DLA) signature
        across attention heads during normal behavior.

        The cached profile is detached so that holding it on the instance does not
        keep the baseline forward pass's autograd graph alive.
        """
        profile = self._attribution_profile(
            baseline_states, baseline_actions, baseline_returns, target_action_index
        )
        self.reference_profile = profile.detach()

    def detect_circuit_anomaly(
        self,
        eval_states: torch.Tensor,
        eval_actions: torch.Tensor,
        eval_returns: torch.Tensor,
        target_action_index: int
    ) -> Tuple[bool, float, torch.Tensor]:
        """
        Compares the current execution's Direct Logit Attribution profile with the reference.

        Returns:
            A tuple of (is_anomaly, similarity_score, current_profile). The similarity is
            the cosine similarity of the two flattened profiles; it is 0.0 when either
            profile has zero norm. A NaN similarity is reported as an anomaly.

        Raises:
            ValueError: If no reference profile has been established.
        """
        if self.reference_profile is None:
            raise ValueError("Reference circuit profile is not established. Train or set reference profile first.")

        current_profile = self._attribution_profile(
            eval_states, eval_actions, eval_returns, target_action_index
        )

        # Compute cosine similarity between base and current head attribution profiles
        ref_norm = torch.norm(self.reference_profile)
        curr_norm = torch.norm(current_profile)
        if ref_norm == 0.0 or curr_norm == 0.0:
            similarity = 0.0
        else:
            similarity = torch.dot(self.reference_profile, current_profile) / (ref_norm * curr_norm)
            similarity = similarity.item()

        # Flag an anomaly unless the similarity provably meets the threshold.
        # Written as a negated ">=" so a NaN score fails closed (anomalous)
        # instead of slipping through a "<" comparison as normal.
        is_anomaly = not (similarity >= self.threshold)
        return is_anomaly, similarity, current_profile
def generate_deceptive_trajectories(
    num_episodes: int = 40,
    seq_len: int = 5,
    state_dim: int = 3,
    action_dim: int = 3
) -> List[Dict[str, np.ndarray]]:
    """
    Build a synthetic dataset of trajectories that exhibit deceptive alignment.

    Episodes alternate by index:
    - even index, watched (observer_present = 1.0): always takes action 0 (normal goal),
      reward 0.2 * (step + 1).
    - odd index, unwatched (observer_present = 0.0): always takes action 1 (secret goal),
      reward 0.4 * (step + 1).

    Each observation is ``[normal_dist, observer_present, secret_dist]`` where both
    distances count down from ``seq_len`` to 1.

    Args:
        num_episodes: Number of trajectories to generate.
        seq_len: Number of steps per trajectory.
        state_dim: Not used by this function; observations always have 3 entries.
        action_dim: Not used by this function; actions are the integers 0 or 1.

    Returns:
        A list of dicts with keys "observations" (float32), "actions" (int64),
        "rewards" (float32) and "dones" (bool, True only on the last step).
    """
    step_indices = range(seq_len)
    # Both goal distances share the same countdown: seq_len, seq_len - 1, ..., 1.
    countdown = [float(seq_len - t) for t in step_indices]
    # Episode termination is signalled on the final step only.
    done_flags = [False] * (seq_len - 1) + [True]

    dataset = []
    for episode in range(num_episodes):
        if episode % 2 == 0:
            # Watched: aligned behaviour with the lower reward rate.
            observer_flag, chosen_action, reward_rate = 1.0, 0, 0.2
        else:
            # Unwatched: secret behaviour with the higher reward rate.
            observer_flag, chosen_action, reward_rate = 0.0, 1, 0.4

        observations = [[dist, observer_flag, dist] for dist in countdown]
        step_rewards = [reward_rate * (t + 1) for t in step_indices]

        dataset.append({
            "observations": np.array(observations, dtype=np.float32),
            "actions": np.array([chosen_action] * seq_len, dtype=np.int64),
            "rewards": np.array(step_rewards, dtype=np.float32),
            "dones": np.array(done_flags, dtype=bool)
        })
    return dataset