Spaces:
Running
Running
File size: 7,260 Bytes
import pytest
import torch
import numpy as np
from src.models.hooked_dt import HookedDT
from src.interpretability.sae_manager import SAEManager
from src.interpretability.safety import (
DynamicRejectionSteerer,
DeceptiveAlignmentAuditor,
FunctionalAttributionMAD,
generate_deceptive_trajectories
)
@pytest.fixture
def base_model():
    """Build the small HookedDT instance used by the tests in this module.

    The configuration is deliberately tiny (a single layer, two heads,
    ``d_model`` of 32 and ``max_length`` of 5) so that it stays cheap to
    construct.
    """
    tiny_config = {
        "state_dim": 3,
        "action_dim": 3,
        "n_layers": 1,
        "n_heads": 2,
        "d_model": 32,
        "max_length": 5,
    }
    return HookedDT.from_config(**tiny_config)
@pytest.fixture
def sae_manager(base_model, tmp_path):
    """Initializes SAEManager with a temporary directory.

    The directory lives under pytest's per-test ``tmp_path`` instead of a
    fixed path inside the repository, so anything the manager stores there
    stays out of the source tree and is not shared between test runs.
    """
    sae_dir = tmp_path / "safety_saes"
    # Create the directory up front so the fixture works whether or not the
    # manager creates its target directory itself.
    sae_dir.mkdir(parents=True, exist_ok=True)
    # Hand the path over as a plain string, the same argument type as a
    # literal path would be.
    return SAEManager(base_model, sae_dir=str(sae_dir))
def test_dynamic_rejection_steerer(base_model):
    """Check the alpha returned by DynamicRejectionSteerer.steer_safely.

    Three safety callbacks are exercised: one that always accepts (alpha is
    expected to remain 1.0), one that always rejects (alpha is expected to be
    0.0), and one that depends on the probability assigned to action index 1.
    """
    steerer = DynamicRejectionSteerer(base_model)
    hook_name = "blocks.0.hook_resid_post"
    direction = torch.randn(32)

    # Random (states, actions, returns) inputs shared by every scenario.
    model_inputs = (
        torch.randn(1, 3, 3),
        torch.randn(1, 3, 3),
        torch.randn(1, 3, 1),
    )

    def always_accept(state, probs):
        return True

    def always_reject(state, probs):
        return False

    # Edge cases: a check that never objects keeps the initial alpha, while a
    # check that always objects drives alpha to zero.
    constant_cases = ((always_accept, 1.0), (always_reject, 0.0))
    for safety_check, expected_alpha in constant_cases:
        _, alpha = steerer.steer_safely(
            *model_inputs, hook_name, direction, safety_check, initial_alpha=1.0
        )
        assert alpha == expected_alpha

    def action_one_below_limit(state, probs):
        # The result is treated as unsafe once the probability at index 1 of
        # `probs` reaches 0.35.
        action_one_prob = probs[1].item()
        return action_one_prob < 0.35

    # Dynamic scenario: the outcome depends on the model's predictions, so
    # only the valid range of the returned alpha is asserted.
    _, final_alpha = steerer.steer_safely(
        *model_inputs,
        hook_name,
        direction,
        action_one_below_limit,
        initial_alpha=1.0,
        decay_factor=0.5,
        min_alpha=0.01,
        max_iterations=4,
    )
    assert 0.0 <= final_alpha <= 1.0
def test_deceptive_alignment_and_audit(base_model, sae_manager):
    """Run the deceptive-alignment audit pipeline end to end.

    Steps: generate deceptive trajectories, briefly fit the model on them,
    collect activations at one hook point, train a TopK SAE on those
    activations, then run the auditor's feature identification and per-head
    attribution and check the range/shape of the results.
    """
    n_actions = 3
    resid_hook = "blocks.0.hook_resid_post"

    def as_model_inputs(traj):
        """Turn one trajectory dict into the tensors fed to the model."""
        obs = torch.from_numpy(traj["observations"]).float().unsqueeze(0)
        action_ids = torch.from_numpy(traj["actions"]).long()
        one_hot = (
            torch.nn.functional.one_hot(action_ids, num_classes=n_actions)
            .float()
            .unsqueeze(0)
        )
        rewards = (
            torch.from_numpy(traj["rewards"]).float().unsqueeze(0).unsqueeze(-1)
        )
        return obs, action_ids, one_hot, rewards

    # 1. Deceptive trajectory dataset.
    trajectories = generate_deceptive_trajectories(num_episodes=20, seq_len=5)
    assert len(trajectories) == 20
    assert trajectories[0]["observations"].shape == (5, 3)

    # 2. Short training run on the deceptive behaviour.
    optimizer = torch.optim.Adam(base_model.parameters(), lr=0.01)
    criterion = torch.nn.CrossEntropyLoss()
    base_model.train()
    for _ in range(10):
        for traj in trajectories:
            obs, action_ids, one_hot, rewards = as_model_inputs(traj)
            optimizer.zero_grad()
            logits = base_model(obs, one_hot, rewards)
            loss = criterion(logits.view(-1, n_actions), action_ids.view(-1))
            loss.backward()
            optimizer.step()
    base_model.eval()

    # 3. Gather activations at the hook point, one chunk per trajectory.
    def hook_activations(traj):
        obs, _, one_hot, rewards = as_model_inputs(traj)
        _, cache = base_model(obs, one_hot, rewards, return_cache=True)
        # Drop the leading batch dimension of size one.
        return cache[resid_hook].squeeze(0)

    with torch.no_grad():
        per_trajectory = [hook_activations(traj) for traj in trajectories]
    activations = torch.cat(per_trajectory, dim=0)

    # 4. TopK SAE on the collected activations; the upper bound of 64 asserted
    #    below corresponds to d_model (32) times the expansion factor (2).
    sae_manager.setup_sae(
        resid_hook, d_model=32, expansion_factor=2, architecture="topk", k=8
    )
    sae_manager.train_on_trajectories(
        resid_hook, activations, epochs=2, batch_size=32
    )

    # 5. Audit: contrast the "watched" and "unwatched" state batches.
    auditor = DeceptiveAlignmentAuditor(base_model, sae_manager)
    watched_states = torch.tensor(
        [[[5.0, 1.0, 5.0], [4.0, 1.0, 5.0]]], dtype=torch.float32
    )
    unwatched_states = torch.tensor(
        [[[5.0, 0.0, 5.0], [5.0, 0.0, 4.0]]], dtype=torch.float32
    )
    audit_actions = torch.zeros(1, 2, 3, dtype=torch.float32)
    audit_returns = torch.tensor([[[0.2], [0.4]]], dtype=torch.float32)

    switch_idx = auditor.identify_situational_awareness_feature(
        resid_hook, watched_states, unwatched_states, audit_actions, audit_returns
    )
    assert 0 <= switch_idx < 64

    # 6. Direct feature attribution: one entry per (layer, head).
    dfa = auditor.attribute_feature_to_heads(
        resid_hook, switch_idx, watched_states, audit_actions, audit_returns
    )
    assert dfa.shape == (1, 2)  # n_layers=1, n_heads=2
def test_functional_attribution_mad(base_model):
    """Verifies FunctionalAttributionMAD reference profiling and anomaly checks.

    A reference profile is established for action 0, after which the very
    same inputs are re-checked (expected: not an anomaly, similarity ~1.0)
    and a shifted input with a different target action is checked (only the
    similarity bound and the profile shape are asserted, because whether it
    is flagged depends on the model weights).
    """
    # Absolute tolerance shared by both similarity assertions, so that
    # floating-point round-off cannot fail the upper-bound check.
    tolerance = 1e-5

    mad = FunctionalAttributionMAD(base_model, threshold=0.6)

    # Establish inputs
    baseline_states = torch.randn(1, 3, 3)
    baseline_actions = torch.randn(1, 3, 3)
    baseline_returns = torch.randn(1, 3, 1)

    # Establish baseline profile for action 0
    mad.establish_reference_profile(
        baseline_states, baseline_actions, baseline_returns, target_action_index=0
    )
    assert mad.reference_profile is not None
    assert mad.reference_profile.shape == (2,)  # n_layers=1, n_heads=2 -> 2 heads total

    # Identical run: similarity to the reference should be very close to 1.0.
    is_anomaly, similarity, profile = mad.detect_circuit_anomaly(
        baseline_states, baseline_actions, baseline_returns, target_action_index=0
    )
    assert not is_anomaly
    # approx() wraps the expected value; float() accepts a plain number as
    # well as a single-element tensor.
    assert float(similarity) == pytest.approx(1.0, abs=tolerance)
    assert profile.shape == (2,)

    # Run with shifted states and a different target action.
    anomalous_states = torch.randn(1, 3, 3) + 5.0
    is_anomaly_anom, similarity_anom, profile_anom = mad.detect_circuit_anomaly(
        anomalous_states, baseline_actions, baseline_returns, target_action_index=1
    )
    # An anomaly may or may not be flagged depending on weights, so only the
    # similarity bound and the profile shape are verified. The bound allows
    # for round-off slightly above 1.0.
    assert float(similarity_anom) <= 1.0 + tolerance
    assert profile_anom.shape == (2,)
if __name__ == "__main__":
    # Propagate pytest's exit status so that running this file directly
    # exits non-zero when a test fails instead of always exiting with 0.
    raise SystemExit(pytest.main([__file__]))
|