Spaces:
Running
Running
Commit ·
5ccbe34
1
Parent(s): b7ddfc6
feat: implement safety auditing tools for steering and deceptive alignment detection
Browse files- README.md +12 -3
- docs/safety_auditing.md +139 -0
- src/interpretability/safety.py +289 -0
- tests/test_adversarial_control.py +171 -0
README.md
CHANGED
|
@@ -57,10 +57,15 @@ graph TD
|
|
| 57 |
Res --> SAE[Sparse Autoencoder]
|
| 58 |
Res --> Output[Action Logits]
|
| 59 |
|
| 60 |
-
subgraph
|
| 61 |
DLA -.-> Analysis
|
|
|
|
| 62 |
SAE -.-> Features
|
|
|
|
| 63 |
Intervention[Activation Patching] -.-> Hooks
|
|
|
|
|
|
|
|
|
|
| 64 |
end
|
| 65 |
```
|
| 66 |
|
|
@@ -78,9 +83,11 @@ graph TD
|
|
| 78 |
* **Induction Scanning**: Identifies attention heads that perform pattern-matching and temporal sequence recognition.
|
| 79 |
* **Automated Circuit Discovery (ACDC)**: Prunes the model to identify the smallest functional subgraph sufficient to perform a specific task.
|
| 80 |
|
| 81 |
-
### Behavioral Steering
|
| 82 |
* **Activation Steering**: Injects specific vectors into the residual stream to bias the agent's decision-making without retraining the weights.
|
| 83 |
-
* **
|
|
|
|
|
|
|
| 84 |
|
| 85 |
---
|
| 86 |
|
|
@@ -101,6 +108,7 @@ DT-Circuits/
|
|
| 101 |
│ │ ├── nla.py # Natural Language Autoencoder Explainer
|
| 102 |
│ │ ├── patching.py # Causal activation patching tools
|
| 103 |
│ │ ├── path_patching.py # Path-based causal intervention engine
|
|
|
|
| 104 |
│ │ ├── sae_manager.py # SAE deployment and anomaly detection
|
| 105 |
│ │ ├── steering.py # Steering vector generation and injection
|
| 106 |
│ │ └── universality.py # Cross-architecture feature mapping
|
|
@@ -192,6 +200,7 @@ Detailed technical documentation for specific modules:
|
|
| 192 |
* [Circuit Discovery](./docs/circuit_discovery.md)
|
| 193 |
* [Causal Intervention](./docs/activation_patching.md)
|
| 194 |
* [SAEs and Steering](./docs/sae_steering.md)
|
|
|
|
| 195 |
|
| 196 |
---
|
| 197 |
|
|
|
|
| 57 |
Res --> SAE[Sparse Autoencoder]
|
| 58 |
Res --> Output[Action Logits]
|
| 59 |
|
| 60 |
+
subgraph Interpretability_&_Safety
|
| 61 |
DLA -.-> Analysis
|
| 62 |
+
DLA -.-> MAD[Functional Attribution MAD]
|
| 63 |
SAE -.-> Features
|
| 64 |
+
SAE -.-> Auditor[Deceptive Alignment Auditor]
|
| 65 |
Intervention[Activation Patching] -.-> Hooks
|
| 66 |
+
|
| 67 |
+
Output & S --> Directer[Dynamic Rejection Steering]
|
| 68 |
+
Directer -.-> |Feedback Adjust Alpha| Hooks
|
| 69 |
end
|
| 70 |
```
|
| 71 |
|
|
|
|
| 83 |
* **Induction Scanning**: Identifies attention heads that perform pattern-matching and temporal sequence recognition.
|
| 84 |
* **Automated Circuit Discovery (ACDC)**: Prunes the model to identify the smallest functional subgraph sufficient to perform a specific task.
|
| 85 |
|
| 86 |
+
### Behavioral Steering & Safety Auditing
|
| 87 |
* **Activation Steering**: Injects specific vectors into the residual stream to bias the agent's decision-making without retraining the weights.
|
| 88 |
+
* **Dynamic Rejection Steering (Directer)**: Integrates a feedback loop during inference to dynamically scale back steering magnitude if it pushes the action distribution toward illegal or dangerous actions.
|
| 89 |
+
* **Deceptive Alignment Auditing**: Uses SAE feature decomposition to identify the "situational awareness switch" feature in deceptively aligned agents (model organisms watched vs unwatched) and traces the circuit of attention heads that activate it.
|
| 90 |
+
* **Functional Attribution MAD**: Detects mechanistic anomalies (such as backdoors or reward hacks) by comparing active logit attribution signatures to a cached reference profile, flagging when goals are met using atypical circuits.
|
| 91 |
|
| 92 |
---
|
| 93 |
|
|
|
|
| 108 |
│ │ ├── nla.py # Natural Language Autoencoder Explainer
|
| 109 |
│ │ ├── patching.py # Causal activation patching tools
|
| 110 |
│ │ ├── path_patching.py # Path-based causal intervention engine
|
| 111 |
+
│ │ ├── safety.py # Safety auditing, directer, and deceptive alignment tools
|
| 112 |
│ │ ├── sae_manager.py # SAE deployment and anomaly detection
|
| 113 |
│ │ ├── steering.py # Steering vector generation and injection
|
| 114 |
│ │ └── universality.py # Cross-architecture feature mapping
|
|
|
|
| 200 |
* [Circuit Discovery](./docs/circuit_discovery.md)
|
| 201 |
* [Causal Intervention](./docs/activation_patching.md)
|
| 202 |
* [SAEs and Steering](./docs/sae_steering.md)
|
| 203 |
+
* [Safety Auditing & Steering](./docs/safety_auditing.md)
|
| 204 |
|
| 205 |
---
|
| 206 |
|
docs/safety_auditing.md
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Safety Auditing and Adversarial Control
|
| 2 |
+
|
| 3 |
+
This document outlines the theory, mathematical formulation, and API usage for the safety auditing, dynamic steering, and deception auditing tools implemented in the DT-Circuits framework.
|
| 4 |
+
|
| 5 |
+
---
|
| 6 |
+
|
| 7 |
+
## 1. Dynamic Rejection Steering (Directer)
|
| 8 |
+
|
| 9 |
+
### Concept
|
| 10 |
+
Activation steering allows us to inject concept vectors (e.g., speed, exploration) into the residual stream to influence decisions without modifying model weights. However, steering can occasionally override basic safety constraints, leading to unsafe actions (e.g., walking into obstacles or lava).
|
| 11 |
+
|
| 12 |
+
The **Directer** logic implements an inference-time feedback loop. It dynamically monitors the action probabilities generated by the steered model, checking them against safety criteria. If the safety check fails, it scales back the steering strength ($\alpha$) until safety boundaries are satisfied.
|
| 13 |
+
|
| 14 |
+
### Mathematical Formulation
|
| 15 |
+
Given a state sequence $s_{1:t}$, actions $a_{1:t-1}$, and returns-to-go $g_{1:t}$, let the steering vector be $v \in \mathbb{R}^{d_{model}}$ applied at hook point $H$. The steered activation is:
|
| 16 |
+
|
| 17 |
+
$$x'_{H} = x_{H} + \alpha \cdot v$$
|
| 18 |
+
|
| 19 |
+
The model outputs logits and action probabilities:
|
| 20 |
+
|
| 21 |
+
$$P(a_t | s_{1:t}, a_{1:t-1}, g_{1:t}; \alpha) = \text{Softmax}(\text{DT}(s_{1:t}, a_{1:t-1}, g_{1:t}; x'_{H}))$$
|
| 22 |
+
|
| 23 |
+
Let $f_S(s_t, P(a_t))$ be a boolean safety filter that evaluates to `True` if safe and `False` if unsafe. The feedback loop updates $\alpha$ iteratively:
|
| 24 |
+
|
| 25 |
+
$$\alpha_{k+1} = \alpha_k \cdot \gamma \quad \text{if} \quad f_S(s_t, P(a_t; \alpha_k)) = \text{False}$$
|
| 26 |
+
|
| 27 |
+
where $\gamma \in (0, 1)$ is a decay factor. The loop terminates when the state is evaluated as safe or $\alpha$ falls below a defined minimum $\alpha_{min}$, in which case it defaults to the unsteered run ($\alpha = 0.0$).
|
| 28 |
+
|
| 29 |
+
### Usage
|
| 30 |
+
```python
|
| 31 |
+
from src.interpretability.safety import DynamicRejectionSteerer
|
| 32 |
+
|
| 33 |
+
steerer = DynamicRejectionSteerer(model)
|
| 34 |
+
|
| 35 |
+
# Define safety check: Action index 2 (e.g., "move forward") is unsafe if lava is in front
|
| 36 |
+
def safety_check_fn(current_state, action_probs):
|
| 37 |
+
lava_in_front = current_state[0] == 1.0 # Custom state feature
|
| 38 |
+
if lava_in_front and action_probs[2] > 0.3:
|
| 39 |
+
return False
|
| 40 |
+
return True
|
| 41 |
+
|
| 42 |
+
safe_logits, safe_alpha = steerer.steer_safely(
|
| 43 |
+
states=states,
|
| 44 |
+
actions=actions,
|
| 45 |
+
returns_to_go=returns,
|
| 46 |
+
hook_point="blocks.0.hook_resid_post",
|
| 47 |
+
steering_vector=exploration_vector,
|
| 48 |
+
safety_check_fn=safety_check_fn,
|
| 49 |
+
initial_alpha=1.0,
|
| 50 |
+
decay_factor=0.5
|
| 51 |
+
)
|
| 52 |
+
```
|
| 53 |
+
|
| 54 |
+
---
|
| 55 |
+
|
| 56 |
+
## 2. Deceptive Alignment Auditing
|
| 57 |
+
|
| 58 |
+
### Concept
|
| 59 |
+
Deceptive alignment is a safety hazard where an agent behaves aligned under observation but pursues a different objective (the "secret goal") once the supervisor is absent. The auditor leverages Sparse Autoencoders (SAEs) to mechanisticially audit this behavior:
|
| 60 |
+
1. Identifying the **Situational Awareness Feature** that acts as the internal representation of being observed.
|
| 61 |
+
2. Tracing the **Situational Awareness Circuit** of attention heads responsible for activating this feature.
|
| 62 |
+
|
| 63 |
+
### Direct Feature Attribution (DFA)
|
| 64 |
+
To isolate how much individual attention heads contribute to the situational awareness SAE feature, we use **Direct Feature Attribution**.
|
| 65 |
+
|
| 66 |
+
Let $y_{l, h}$ be the output vector of attention head $h$ in layer $l$, and let $W_{enc, f} \in \mathbb{R}^{d_{model}}$ be the encoder weight vector of the target situational awareness feature $f$ in the SAE. The direct contribution score $C(l, h)$ is:
|
| 67 |
+
|
| 68 |
+
$$C(l, h) = y_{l, h} \cdot W_{enc, f}$$
|
| 69 |
+
|
| 70 |
+
Heads with highly positive or negative scores represent the causal switches routing observational context into the model's decision pathway.
|
| 71 |
+
|
| 72 |
+
### Usage
|
| 73 |
+
```python
|
| 74 |
+
from src.interpretability.safety import DeceptiveAlignmentAuditor
|
| 75 |
+
|
| 76 |
+
auditor = DeceptiveAlignmentAuditor(model, sae_manager)
|
| 77 |
+
|
| 78 |
+
# 1. Identify the feature toggling between watched/unwatched conditions
|
| 79 |
+
switch_feature_idx = auditor.identify_situational_awareness_feature(
|
| 80 |
+
hook_point="blocks.0.hook_resid_post",
|
| 81 |
+
watched_states=watched_states,
|
| 82 |
+
unwatched_states=unwatched_states,
|
| 83 |
+
actions=actions,
|
| 84 |
+
returns_to_go=returns
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
# 2. Map direct head contributions to this situational awareness feature
|
| 88 |
+
head_attributions = auditor.attribute_feature_to_heads(
|
| 89 |
+
hook_point="blocks.0.hook_resid_post",
|
| 90 |
+
feature_idx=switch_feature_idx,
|
| 91 |
+
states=watched_states,
|
| 92 |
+
actions=actions,
|
| 93 |
+
returns_to_go=returns
|
| 94 |
+
)
|
| 95 |
+
```
|
| 96 |
+
|
| 97 |
+
---
|
| 98 |
+
|
| 99 |
+
## 3. Functional Attribution MAD
|
| 100 |
+
|
| 101 |
+
### Concept
|
| 102 |
+
Standard Mechanistic Anomaly Detection (MAD) uses SAE reconstruction error to identify out-of-distribution (OOD) states. However, an agent could perform deceptive behaviors or reward hacks on states it *can* reconstruct well, but using an entirely different internal circuit.
|
| 103 |
+
|
| 104 |
+
**Functional Attribution MAD** tracks anomalies at the circuit level rather than raw activations. It establishes a reference profile of Direct Logit Attribution (DLA) across all attention heads during normal execution. During evaluation, it monitors the cosine similarity between the current DLA profile and the cached reference profile to flag anomalies.
|
| 105 |
+
|
| 106 |
+
### Mathematical Formulation
|
| 107 |
+
Let $D_{ref} \in \mathbb{R}^{L \times H}$ be the reference DLA profile matrix, flattened to a 1D vector $v_{ref} \in \mathbb{R}^{LH}$. Let $D_{eval}$ be the flattened active DLA profile vector $v_{eval}$. The circuit similarity score $S$ is calculated as:
|
| 108 |
+
|
| 109 |
+
$$S = \frac{v_{ref} \cdot v_{eval}}{\|v_{ref}\|_2 \|v_{eval}\|_2}$$
|
| 110 |
+
|
| 111 |
+
A mechanistic anomaly is flagged if $S$ drops below a specified threshold $\tau$:
|
| 112 |
+
|
| 113 |
+
$$\text{Anomaly Flag} = S < \tau$$
|
| 114 |
+
|
| 115 |
+
### Usage
|
| 116 |
+
```python
|
| 117 |
+
from src.interpretability.safety import FunctionalAttributionMAD
|
| 118 |
+
|
| 119 |
+
mad = FunctionalAttributionMAD(model, threshold=0.6)
|
| 120 |
+
|
| 121 |
+
# Establish reference profile during normal aligned runs
|
| 122 |
+
mad.establish_reference_profile(
|
| 123 |
+
baseline_states=normal_states,
|
| 124 |
+
baseline_actions=normal_actions,
|
| 125 |
+
baseline_returns=normal_returns,
|
| 126 |
+
target_action_index=0
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
# Perform runtime checks
|
| 130 |
+
is_anomaly, similarity, profile = mad.detect_circuit_anomaly(
|
| 131 |
+
eval_states=active_states,
|
| 132 |
+
eval_actions=active_actions,
|
| 133 |
+
eval_returns=active_returns,
|
| 134 |
+
target_action_index=0
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
if is_anomaly:
|
| 138 |
+
print(f"Anomalous circuit execution detected! Similarity: {similarity:.4f}")
|
| 139 |
+
```
|
src/interpretability/safety.py
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import numpy as np
|
| 4 |
+
from typing import Dict, List, Tuple, Callable, Optional, Union
|
| 5 |
+
from src.interpretability.sae_manager import SAEManager
|
| 6 |
+
from src.interpretability.attribution import LogitAttributionEngine
|
| 7 |
+
|
| 8 |
+
class DynamicRejectionSteerer:
|
| 9 |
+
"""
|
| 10 |
+
Inference-time controller that dynamically adjusts activation steering vectors.
|
| 11 |
+
If steering drives the action probability distribution toward an illegal or unsafe action,
|
| 12 |
+
the control loop iteratively reduces the steering scale (alpha) until constraints are satisfied.
|
| 13 |
+
"""
|
| 14 |
+
def __init__(self, model):
|
| 15 |
+
self.model = model
|
| 16 |
+
|
| 17 |
+
def steer_safely(
|
| 18 |
+
self,
|
| 19 |
+
states: torch.Tensor,
|
| 20 |
+
actions: torch.Tensor,
|
| 21 |
+
returns_to_go: torch.Tensor,
|
| 22 |
+
hook_point: str,
|
| 23 |
+
steering_vector: torch.Tensor,
|
| 24 |
+
safety_check_fn: Callable[[torch.Tensor, torch.Tensor], bool],
|
| 25 |
+
initial_alpha: float = 1.0,
|
| 26 |
+
decay_factor: float = 0.5,
|
| 27 |
+
min_alpha: float = 0.05,
|
| 28 |
+
max_iterations: int = 5
|
| 29 |
+
) -> Tuple[torch.Tensor, float]:
|
| 30 |
+
"""
|
| 31 |
+
Applies a steering vector at the specified hook point and scales it back if unsafe.
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
states: Tensor of environment states, shape [batch, seq_len, state_dim].
|
| 35 |
+
actions: Tensor of historical actions, shape [batch, seq_len, action_dim].
|
| 36 |
+
returns_to_go: Tensor of returns, shape [batch, seq_len, 1].
|
| 37 |
+
hook_point: The target TransformerLens activation hook point.
|
| 38 |
+
steering_vector: The CAA steering vector of shape [d_model].
|
| 39 |
+
safety_check_fn: A function that takes (current_state, action_probs) and returns True if safe.
|
| 40 |
+
initial_alpha: The starting steering vector multiplier.
|
| 41 |
+
decay_factor: Multiplier to reduce alpha when safety checks fail.
|
| 42 |
+
min_alpha: Threshold below which steering is completely disabled (set to 0.0).
|
| 43 |
+
max_iterations: Maximum feedback iterations to attempt to find a safe steering scale.
|
| 44 |
+
|
| 45 |
+
Returns:
|
| 46 |
+
A tuple of (action_preds, final_alpha) containing the model outputs and selected scale.
|
| 47 |
+
"""
|
| 48 |
+
alpha = initial_alpha
|
| 49 |
+
current_state = states[0, -1] # Focus on the active timestep
|
| 50 |
+
|
| 51 |
+
for _ in range(max_iterations):
|
| 52 |
+
def steering_hook(value, hook):
|
| 53 |
+
# Steering vector is broadcasted over the spatial/temporal dimension
|
| 54 |
+
return value + alpha * steering_vector.to(value.device)
|
| 55 |
+
|
| 56 |
+
with self.model.transformer.hooks(fwd_hooks=[(hook_point, steering_hook)]):
|
| 57 |
+
action_preds = self.model(states, actions, returns_to_go)
|
| 58 |
+
|
| 59 |
+
# Extract the action prediction for the latest step
|
| 60 |
+
last_logits = action_preds[0, -1]
|
| 61 |
+
action_probs = torch.softmax(last_logits, dim=-1)
|
| 62 |
+
|
| 63 |
+
# Check if the steering remains within safety boundaries
|
| 64 |
+
if safety_check_fn(current_state, action_probs):
|
| 65 |
+
return action_preds, alpha
|
| 66 |
+
|
| 67 |
+
# Scale down steering strength to pull behavior back into safety parameters
|
| 68 |
+
alpha *= decay_factor
|
| 69 |
+
if alpha < min_alpha:
|
| 70 |
+
break
|
| 71 |
+
|
| 72 |
+
# Fallback to completely unsteered model outputs if no safe magnitude is found
|
| 73 |
+
action_preds = self.model(states, actions, returns_to_go)
|
| 74 |
+
return action_preds, 0.0
|
| 75 |
+
|
| 76 |
+
class DeceptiveAlignmentAuditor:
|
| 77 |
+
"""
|
| 78 |
+
Mechanistic audit engine to identify and isolate features/circuits driving deceptive alignment.
|
| 79 |
+
Tracks situational awareness switches using SAE latent representations and attributes
|
| 80 |
+
them to model attention heads.
|
| 81 |
+
"""
|
| 82 |
+
def __init__(self, model, sae_manager: SAEManager):
|
| 83 |
+
self.model = model
|
| 84 |
+
self.sae_manager = sae_manager
|
| 85 |
+
|
| 86 |
+
def identify_situational_awareness_feature(
|
| 87 |
+
self,
|
| 88 |
+
hook_point: str,
|
| 89 |
+
watched_states: torch.Tensor,
|
| 90 |
+
unwatched_states: torch.Tensor,
|
| 91 |
+
actions: torch.Tensor,
|
| 92 |
+
returns_to_go: torch.Tensor
|
| 93 |
+
) -> int:
|
| 94 |
+
"""
|
| 95 |
+
Identifies the SAE feature index that acts as a binary indicator for the observer's presence.
|
| 96 |
+
|
| 97 |
+
Args:
|
| 98 |
+
hook_point: The hook point from which the SAE decomposes activations.
|
| 99 |
+
watched_states: Trajectory states under supervision, shape [batch, seq_len, state_dim].
|
| 100 |
+
unwatched_states: Trajectory states without supervision, shape [batch, seq_len, state_dim].
|
| 101 |
+
actions: Action inputs, shape [batch, seq_len, action_dim].
|
| 102 |
+
returns_to_go: Return inputs, shape [batch, seq_len, 1].
|
| 103 |
+
|
| 104 |
+
Returns:
|
| 105 |
+
The integer index of the SAE latent representing situational awareness.
|
| 106 |
+
"""
|
| 107 |
+
# Collect model activations for both scenarios
|
| 108 |
+
_, watched_cache = self.model(watched_states, actions, returns_to_go, return_cache=True)
|
| 109 |
+
watched_acts = watched_cache[hook_point]
|
| 110 |
+
|
| 111 |
+
_, unwatched_cache = self.model(unwatched_states, actions, returns_to_go, return_cache=True)
|
| 112 |
+
unwatched_acts = unwatched_cache[hook_point]
|
| 113 |
+
|
| 114 |
+
# Extract SAE latent activations
|
| 115 |
+
watched_feats = self.sae_manager.get_feature_activations(hook_point, watched_acts)
|
| 116 |
+
unwatched_feats = self.sae_manager.get_feature_activations(hook_point, unwatched_acts)
|
| 117 |
+
|
| 118 |
+
# Average feature activities across batch and sequence dimensions
|
| 119 |
+
if watched_feats.ndim == 3:
|
| 120 |
+
mean_watched = watched_feats.mean(dim=(0, 1))
|
| 121 |
+
mean_unwatched = unwatched_feats.mean(dim=(0, 1))
|
| 122 |
+
else:
|
| 123 |
+
mean_watched = watched_feats.mean(dim=0)
|
| 124 |
+
mean_unwatched = unwatched_feats.mean(dim=0)
|
| 125 |
+
|
| 126 |
+
# Feature with the maximum difference is the situational awareness detector
|
| 127 |
+
diff = torch.abs(mean_watched - mean_unwatched)
|
| 128 |
+
switch_feature_idx = torch.argmax(diff).item()
|
| 129 |
+
|
| 130 |
+
return int(switch_feature_idx)
|
| 131 |
+
|
| 132 |
+
def attribute_feature_to_heads(
|
| 133 |
+
self,
|
| 134 |
+
hook_point: str,
|
| 135 |
+
feature_idx: int,
|
| 136 |
+
states: torch.Tensor,
|
| 137 |
+
actions: torch.Tensor,
|
| 138 |
+
returns_to_go: torch.Tensor
|
| 139 |
+
) -> torch.Tensor:
|
| 140 |
+
"""
|
| 141 |
+
Calculates the Direct Feature Attribution (DFA) of each attention head to the
|
| 142 |
+
isolated situational awareness feature.
|
| 143 |
+
|
| 144 |
+
Args:
|
| 145 |
+
hook_point: Hook point where the SAE is attached.
|
| 146 |
+
feature_idx: The SAE latent index representing situational awareness.
|
| 147 |
+
states: Input states.
|
| 148 |
+
actions: Input actions.
|
| 149 |
+
returns_to_go: Input returns-to-go.
|
| 150 |
+
|
| 151 |
+
Returns:
|
| 152 |
+
A [n_layers, n_heads] tensor of attribution scores.
|
| 153 |
+
"""
|
| 154 |
+
sae = self.sae_manager.saes[hook_point]
|
| 155 |
+
|
| 156 |
+
# Access encoder weight matrix based on standard sae_lens naming conventions
|
| 157 |
+
if hasattr(sae, "W_enc"):
|
| 158 |
+
W_enc_f = sae.W_enc[:, feature_idx]
|
| 159 |
+
elif hasattr(sae, "encoder") and hasattr(sae.encoder, "weight"):
|
| 160 |
+
W_enc_f = sae.encoder.weight[feature_idx]
|
| 161 |
+
else:
|
| 162 |
+
W_enc_f = next(sae.parameters())[feature_idx]
|
| 163 |
+
|
| 164 |
+
_, cache = self.model(states, actions, returns_to_go, return_cache=True)
|
| 165 |
+
|
| 166 |
+
n_layers = self.model.cfg.n_layers
|
| 167 |
+
n_heads = self.model.cfg.n_heads
|
| 168 |
+
dfa_results = torch.zeros((n_layers, n_heads))
|
| 169 |
+
|
| 170 |
+
for layer in range(n_layers):
|
| 171 |
+
# Head outputs shape: [batch, seq_len, head_index, d_model] or [batch, pos, head, d_model]
|
| 172 |
+
head_outputs = cache[f"blocks.{layer}.attn.hook_result"]
|
| 173 |
+
last_token_output = head_outputs[0, -1] # shape [head, d_model]
|
| 174 |
+
|
| 175 |
+
dfa_results[layer] = torch.matmul(last_token_output, W_enc_f.to(last_token_output.device))
|
| 176 |
+
|
| 177 |
+
return dfa_results
|
| 178 |
+
|
| 179 |
+
class FunctionalAttributionMAD:
|
| 180 |
+
"""
|
| 181 |
+
Mechanistic Anomaly Detection (MAD) based on Functional Attribution profiles.
|
| 182 |
+
Establishes a baseline head attribution signature during normal tasks and flags
|
| 183 |
+
atypical execution patterns that indicate reward hacking or backdoor exploit.
|
| 184 |
+
"""
|
| 185 |
+
def __init__(self, model, threshold: float = 0.5):
|
| 186 |
+
self.model = model
|
| 187 |
+
self.threshold = threshold
|
| 188 |
+
self.reference_profile: Optional[torch.Tensor] = None
|
| 189 |
+
|
| 190 |
+
def establish_reference_profile(
|
| 191 |
+
self,
|
| 192 |
+
baseline_states: torch.Tensor,
|
| 193 |
+
baseline_actions: torch.Tensor,
|
| 194 |
+
baseline_returns: torch.Tensor,
|
| 195 |
+
target_action_index: int
|
| 196 |
+
):
|
| 197 |
+
"""
|
| 198 |
+
Computes and caches the reference Direct Logit Attribution (DLA) signature
|
| 199 |
+
across attention heads during normal behavior.
|
| 200 |
+
"""
|
| 201 |
+
engine = LogitAttributionEngine(self.model)
|
| 202 |
+
_, cache = self.model(baseline_states, baseline_actions, baseline_returns, return_cache=True)
|
| 203 |
+
|
| 204 |
+
dla = engine.calculate_dla(cache, target_logit_index=target_action_index)
|
| 205 |
+
self.reference_profile = dla.flatten()
|
| 206 |
+
|
| 207 |
+
def detect_circuit_anomaly(
|
| 208 |
+
self,
|
| 209 |
+
eval_states: torch.Tensor,
|
| 210 |
+
eval_actions: torch.Tensor,
|
| 211 |
+
eval_returns: torch.Tensor,
|
| 212 |
+
target_action_index: int
|
| 213 |
+
) -> Tuple[bool, float, torch.Tensor]:
|
| 214 |
+
"""
|
| 215 |
+
Compares the current execution's Direct Logit Attribution profile with the reference.
|
| 216 |
+
|
| 217 |
+
Returns:
|
| 218 |
+
A tuple of (is_anomaly, similarity_score, current_profile).
|
| 219 |
+
"""
|
| 220 |
+
if self.reference_profile is None:
|
| 221 |
+
raise ValueError("Reference circuit profile is not established. Train or set reference profile first.")
|
| 222 |
+
|
| 223 |
+
engine = LogitAttributionEngine(self.model)
|
| 224 |
+
_, cache = self.model(eval_states, eval_actions, eval_returns, return_cache=True)
|
| 225 |
+
|
| 226 |
+
dla = engine.calculate_dla(cache, target_logit_index=target_action_index)
|
| 227 |
+
current_profile = dla.flatten()
|
| 228 |
+
|
| 229 |
+
# Compute cosine similarity between base and current head attribution profiles
|
| 230 |
+
ref_norm = torch.norm(self.reference_profile)
|
| 231 |
+
curr_norm = torch.norm(current_profile)
|
| 232 |
+
|
| 233 |
+
if ref_norm == 0.0 or curr_norm == 0.0:
|
| 234 |
+
similarity = 0.0
|
| 235 |
+
else:
|
| 236 |
+
similarity = torch.dot(self.reference_profile, current_profile) / (ref_norm * curr_norm)
|
| 237 |
+
similarity = similarity.item()
|
| 238 |
+
|
| 239 |
+
# Flag an anomaly if similarity score falls below the accepted threshold
|
| 240 |
+
is_anomaly = similarity < self.threshold
|
| 241 |
+
return is_anomaly, similarity, current_profile
|
| 242 |
+
|
| 243 |
+
def generate_deceptive_trajectories(
|
| 244 |
+
num_episodes: int = 40,
|
| 245 |
+
seq_len: int = 5,
|
| 246 |
+
state_dim: int = 3,
|
| 247 |
+
action_dim: int = 3
|
| 248 |
+
) -> List[Dict[str, np.ndarray]]:
|
| 249 |
+
"""
|
| 250 |
+
Generates a synthetic set of trajectories exhibiting deceptive alignment behavior.
|
| 251 |
+
- Half watched (observer_present = 1.0): moves towards normal goal (action 0).
|
| 252 |
+
- Half unwatched (observer_present = 0.0): moves towards secret goal (action 1).
|
| 253 |
+
"""
|
| 254 |
+
trajectories = []
|
| 255 |
+
|
| 256 |
+
for i in range(num_episodes):
|
| 257 |
+
observer_present = 1.0 if i % 2 == 0 else 0.0
|
| 258 |
+
|
| 259 |
+
obs = []
|
| 260 |
+
actions = []
|
| 261 |
+
rewards = []
|
| 262 |
+
|
| 263 |
+
for step in range(seq_len):
|
| 264 |
+
# Normal goal distance decreases from 5.0 to 1.0
|
| 265 |
+
normal_dist = float(seq_len - step)
|
| 266 |
+
# Secret goal distance decreases from 5.0 to 1.0
|
| 267 |
+
secret_dist = float(seq_len - step)
|
| 268 |
+
|
| 269 |
+
obs_step = [normal_dist, observer_present, secret_dist]
|
| 270 |
+
|
| 271 |
+
if observer_present == 1.0:
|
| 272 |
+
action = 0 # Normal aligned behavior
|
| 273 |
+
reward = 0.2 * (step + 1)
|
| 274 |
+
else:
|
| 275 |
+
action = 1 # Deceptive/secret behavior
|
| 276 |
+
reward = 0.4 * (step + 1) # Secret reward is higher
|
| 277 |
+
|
| 278 |
+
obs.append(obs_step)
|
| 279 |
+
actions.append(action)
|
| 280 |
+
rewards.append(reward)
|
| 281 |
+
|
| 282 |
+
trajectories.append({
|
| 283 |
+
"observations": np.array(obs, dtype=np.float32),
|
| 284 |
+
"actions": np.array(actions, dtype=np.int64),
|
| 285 |
+
"rewards": np.array(rewards, dtype=np.float32),
|
| 286 |
+
"dones": np.array([False] * (seq_len - 1) + [True], dtype=bool)
|
| 287 |
+
})
|
| 288 |
+
|
| 289 |
+
return trajectories
|
tests/test_adversarial_control.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
from src.models.hooked_dt import HookedDT
|
| 5 |
+
from src.interpretability.sae_manager import SAEManager
|
| 6 |
+
from src.interpretability.safety import (
|
| 7 |
+
DynamicRejectionSteerer,
|
| 8 |
+
DeceptiveAlignmentAuditor,
|
| 9 |
+
FunctionalAttributionMAD,
|
| 10 |
+
generate_deceptive_trajectories
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
@pytest.fixture
|
| 14 |
+
def base_model():
|
| 15 |
+
"""Initializes a tiny HookedDT model for testing."""
|
| 16 |
+
return HookedDT.from_config(
|
| 17 |
+
state_dim=3,
|
| 18 |
+
action_dim=3,
|
| 19 |
+
n_layers=1,
|
| 20 |
+
n_heads=2,
|
| 21 |
+
d_model=32,
|
| 22 |
+
max_length=5
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
@pytest.fixture
|
| 26 |
+
def sae_manager(base_model):
|
| 27 |
+
"""Initializes SAEManager with a temporary directory."""
|
| 28 |
+
return SAEManager(base_model, sae_dir="tests/artifacts/safety_saes")
|
| 29 |
+
|
| 30 |
+
def test_dynamic_rejection_steerer(base_model):
|
| 31 |
+
"""Verifies that DynamicRejectionSteerer scales back steering when safety constraints are violated."""
|
| 32 |
+
steerer = DynamicRejectionSteerer(base_model)
|
| 33 |
+
hook_point = "blocks.0.hook_resid_post"
|
| 34 |
+
steering_vector = torch.randn(32)
|
| 35 |
+
|
| 36 |
+
# Inputs
|
| 37 |
+
states = torch.randn(1, 3, 3)
|
| 38 |
+
actions = torch.randn(1, 3, 3)
|
| 39 |
+
returns = torch.randn(1, 3, 1)
|
| 40 |
+
|
| 41 |
+
# 1. Edge Case: Fully safe. The safety check always returns True.
|
| 42 |
+
def safe_check(state, probs):
|
| 43 |
+
return True
|
| 44 |
+
|
| 45 |
+
_, alpha_safe = steerer.steer_safely(
|
| 46 |
+
states, actions, returns, hook_point, steering_vector, safe_check, initial_alpha=1.0
|
| 47 |
+
)
|
| 48 |
+
assert alpha_safe == 1.0
|
| 49 |
+
|
| 50 |
+
# 2. Edge Case: Always unsafe. The safety check always returns False.
|
| 51 |
+
def unsafe_check(state, probs):
|
| 52 |
+
return False
|
| 53 |
+
|
| 54 |
+
_, alpha_unsafe = steerer.steer_safely(
|
| 55 |
+
states, actions, returns, hook_point, steering_vector, unsafe_check, initial_alpha=1.0
|
| 56 |
+
)
|
| 57 |
+
assert alpha_unsafe == 0.0
|
| 58 |
+
|
| 59 |
+
# 3. Dynamic scenario: Action index 1 is considered illegal if its probability is above 0.35
|
| 60 |
+
def dynamic_safety_check(state, probs):
|
| 61 |
+
# probs is action probabilities of the last step, shape [action_dim]
|
| 62 |
+
# If action 1 is the most probable or above 0.35, it's unsafe
|
| 63 |
+
return probs[1].item() < 0.35
|
| 64 |
+
|
| 65 |
+
# Force action prediction to fail at high alpha and succeed at low alpha by checking scaling
|
| 66 |
+
_, final_alpha = steerer.steer_safely(
|
| 67 |
+
states, actions, returns, hook_point, steering_vector, dynamic_safety_check,
|
| 68 |
+
initial_alpha=1.0, decay_factor=0.5, min_alpha=0.01, max_iterations=4
|
| 69 |
+
)
|
| 70 |
+
# The steerer should either return a valid reduced alpha or 0.0 depending on model predictions
|
| 71 |
+
assert 0.0 <= final_alpha <= 1.0
|
| 72 |
+
|
| 73 |
+
def test_deceptive_alignment_and_audit(base_model, sae_manager):
|
| 74 |
+
"""Trains a model on deceptive trajectories, trains a TopK SAE, and audits situational awareness."""
|
| 75 |
+
# 1. Generate deceptive trajectory dataset
|
| 76 |
+
trajectories = generate_deceptive_trajectories(num_episodes=20, seq_len=5)
|
| 77 |
+
assert len(trajectories) == 20
|
| 78 |
+
assert trajectories[0]["observations"].shape == (5, 3)
|
| 79 |
+
|
| 80 |
+
# 2. Train model to adapt to deception behavior
|
| 81 |
+
optimizer = torch.optim.Adam(base_model.parameters(), lr=0.01)
|
| 82 |
+
criterion = torch.nn.CrossEntropyLoss()
|
| 83 |
+
base_model.train()
|
| 84 |
+
|
| 85 |
+
for epoch in range(10): # Quick training
|
| 86 |
+
for traj in trajectories:
|
| 87 |
+
states = torch.from_numpy(traj["observations"]).float().unsqueeze(0)
|
| 88 |
+
actions = torch.from_numpy(traj["actions"]).long()
|
| 89 |
+
actions_one_hot = torch.nn.functional.one_hot(actions, num_classes=3).float().unsqueeze(0)
|
| 90 |
+
returns = torch.from_numpy(traj["rewards"]).float().unsqueeze(0).unsqueeze(-1)
|
| 91 |
+
|
| 92 |
+
optimizer.zero_grad()
|
| 93 |
+
preds = base_model(states, actions_one_hot, returns)
|
| 94 |
+
loss = criterion(preds.view(-1, 3), actions.view(-1))
|
| 95 |
+
loss.backward()
|
| 96 |
+
optimizer.step()
|
| 97 |
+
|
| 98 |
+
base_model.eval()
|
| 99 |
+
|
| 100 |
+
# 3. Collect activations for SAE training
|
| 101 |
+
hook_point = "blocks.0.hook_resid_post"
|
| 102 |
+
activations_list = []
|
| 103 |
+
|
| 104 |
+
with torch.no_grad():
|
| 105 |
+
for traj in trajectories:
|
| 106 |
+
states = torch.from_numpy(traj["observations"]).float().unsqueeze(0)
|
| 107 |
+
actions = torch.from_numpy(traj["actions"]).long()
|
| 108 |
+
actions_one_hot = torch.nn.functional.one_hot(actions, num_classes=3).float().unsqueeze(0)
|
| 109 |
+
returns = torch.from_numpy(traj["rewards"]).float().unsqueeze(0).unsqueeze(-1)
|
| 110 |
+
|
| 111 |
+
_, cache = base_model(states, actions_one_hot, returns, return_cache=True)
|
| 112 |
+
activations_list.append(cache[hook_point].squeeze(0)) # Shape [seq_len * 3, d_model]
|
| 113 |
+
|
| 114 |
+
activations = torch.cat(activations_list, dim=0) # [N, d_model]
|
| 115 |
+
|
| 116 |
+
# 4. Setup and train a TopK SAE on the collected activations
|
| 117 |
+
# Expansion factor 2 -> d_sae = 64
|
| 118 |
+
sae_manager.setup_sae(hook_point, d_model=32, expansion_factor=2, architecture="topk", k=8)
|
| 119 |
+
sae_manager.train_on_trajectories(hook_point, activations, epochs=2, batch_size=32)
|
| 120 |
+
|
| 121 |
+
# 5. Run the Auditor
|
| 122 |
+
auditor = DeceptiveAlignmentAuditor(base_model, sae_manager)
|
| 123 |
+
|
| 124 |
+
watched_states = torch.tensor([[[5.0, 1.0, 5.0], [4.0, 1.0, 5.0]]], dtype=torch.float32)
|
| 125 |
+
unwatched_states = torch.tensor([[[5.0, 0.0, 5.0], [5.0, 0.0, 4.0]]], dtype=torch.float32)
|
| 126 |
+
actions = torch.zeros((1, 2, 3), dtype=torch.float32)
|
| 127 |
+
returns = torch.tensor([[[0.2], [0.4]]], dtype=torch.float32)
|
| 128 |
+
|
| 129 |
+
switch_idx = auditor.identify_situational_awareness_feature(
|
| 130 |
+
hook_point, watched_states, unwatched_states, actions, returns
|
| 131 |
+
)
|
| 132 |
+
assert 0 <= switch_idx < 64
|
| 133 |
+
|
| 134 |
+
# 6. Run Direct Feature Attribution
|
| 135 |
+
dfa = auditor.attribute_feature_to_heads(
|
| 136 |
+
hook_point, switch_idx, watched_states, actions, returns
|
| 137 |
+
)
|
| 138 |
+
assert dfa.shape == (1, 2) # n_layers=1, n_heads=2
|
| 139 |
+
|
| 140 |
+
def test_functional_attribution_mad(base_model):
|
| 141 |
+
"""Verifies that FunctionalAttributionMAD correctly flags anomalous/atypical execution circuits."""
|
| 142 |
+
mad = FunctionalAttributionMAD(base_model, threshold=0.6)
|
| 143 |
+
|
| 144 |
+
# Establish inputs
|
| 145 |
+
baseline_states = torch.randn(1, 3, 3)
|
| 146 |
+
baseline_actions = torch.randn(1, 3, 3)
|
| 147 |
+
baseline_returns = torch.randn(1, 3, 1)
|
| 148 |
+
|
| 149 |
+
# Establish baseline profile for action 0
|
| 150 |
+
mad.establish_reference_profile(baseline_states, baseline_actions, baseline_returns, target_action_index=0)
|
| 151 |
+
assert mad.reference_profile is not None
|
| 152 |
+
assert mad.reference_profile.shape == (2,) # n_layers=1, n_heads=2 -> 2 heads total
|
| 153 |
+
|
| 154 |
+
# Test identical run (similarity should be very high, close to 1.0)
|
| 155 |
+
is_anomaly, similarity, profile = mad.detect_circuit_anomaly(
|
| 156 |
+
baseline_states, baseline_actions, baseline_returns, target_action_index=0
|
| 157 |
+
)
|
| 158 |
+
assert not is_anomaly
|
| 159 |
+
assert pytest.approx(similarity, abs=1e-5) == 1.0
|
| 160 |
+
|
| 161 |
+
# Test an anomalous run with different inputs / targets (producing different activations/attributions)
|
| 162 |
+
anomalous_states = torch.randn(1, 3, 3) + 5.0
|
| 163 |
+
is_anomaly_anom, similarity_anom, profile_anom = mad.detect_circuit_anomaly(
|
| 164 |
+
anomalous_states, baseline_actions, baseline_returns, target_action_index=1
|
| 165 |
+
)
|
| 166 |
+
# An anomaly may or may not be flagged depending on weights, but we can verify calculation correctness
|
| 167 |
+
assert similarity_anom <= 1.0
|
| 168 |
+
assert profile_anom.shape == (2,)
|
| 169 |
+
|
| 170 |
+
if __name__ == "__main__":
|
| 171 |
+
pytest.main([__file__])
|