DT-Explorer / tests /test_adversarial_control.py
sadhumitha-s's picture
feat: implement safety auditing tools for steering and deceptive alignment detection
5ccbe34
import pytest
import torch
import numpy as np
from src.models.hooked_dt import HookedDT
from src.interpretability.sae_manager import SAEManager
from src.interpretability.safety import (
DynamicRejectionSteerer,
DeceptiveAlignmentAuditor,
FunctionalAttributionMAD,
generate_deceptive_trajectories
)
@pytest.fixture
def base_model():
"""Initializes a tiny HookedDT model for testing."""
return HookedDT.from_config(
state_dim=3,
action_dim=3,
n_layers=1,
n_heads=2,
d_model=32,
max_length=5
)
@pytest.fixture
def sae_manager(base_model):
"""Initializes SAEManager with a temporary directory."""
return SAEManager(base_model, sae_dir="tests/artifacts/safety_saes")
def test_dynamic_rejection_steerer(base_model):
"""Verifies that DynamicRejectionSteerer scales back steering when safety constraints are violated."""
steerer = DynamicRejectionSteerer(base_model)
hook_point = "blocks.0.hook_resid_post"
steering_vector = torch.randn(32)
# Inputs
states = torch.randn(1, 3, 3)
actions = torch.randn(1, 3, 3)
returns = torch.randn(1, 3, 1)
# 1. Edge Case: Fully safe. The safety check always returns True.
def safe_check(state, probs):
return True
_, alpha_safe = steerer.steer_safely(
states, actions, returns, hook_point, steering_vector, safe_check, initial_alpha=1.0
)
assert alpha_safe == 1.0
# 2. Edge Case: Always unsafe. The safety check always returns False.
def unsafe_check(state, probs):
return False
_, alpha_unsafe = steerer.steer_safely(
states, actions, returns, hook_point, steering_vector, unsafe_check, initial_alpha=1.0
)
assert alpha_unsafe == 0.0
# 3. Dynamic scenario: Action index 1 is considered illegal if its probability is above 0.35
def dynamic_safety_check(state, probs):
# probs is action probabilities of the last step, shape [action_dim]
# If action 1 is the most probable or above 0.35, it's unsafe
return probs[1].item() < 0.35
# Force action prediction to fail at high alpha and succeed at low alpha by checking scaling
_, final_alpha = steerer.steer_safely(
states, actions, returns, hook_point, steering_vector, dynamic_safety_check,
initial_alpha=1.0, decay_factor=0.5, min_alpha=0.01, max_iterations=4
)
# The steerer should either return a valid reduced alpha or 0.0 depending on model predictions
assert 0.0 <= final_alpha <= 1.0
def test_deceptive_alignment_and_audit(base_model, sae_manager):
"""Trains a model on deceptive trajectories, trains a TopK SAE, and audits situational awareness."""
# 1. Generate deceptive trajectory dataset
trajectories = generate_deceptive_trajectories(num_episodes=20, seq_len=5)
assert len(trajectories) == 20
assert trajectories[0]["observations"].shape == (5, 3)
# 2. Train model to adapt to deception behavior
optimizer = torch.optim.Adam(base_model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()
base_model.train()
for epoch in range(10): # Quick training
for traj in trajectories:
states = torch.from_numpy(traj["observations"]).float().unsqueeze(0)
actions = torch.from_numpy(traj["actions"]).long()
actions_one_hot = torch.nn.functional.one_hot(actions, num_classes=3).float().unsqueeze(0)
returns = torch.from_numpy(traj["rewards"]).float().unsqueeze(0).unsqueeze(-1)
optimizer.zero_grad()
preds = base_model(states, actions_one_hot, returns)
loss = criterion(preds.view(-1, 3), actions.view(-1))
loss.backward()
optimizer.step()
base_model.eval()
# 3. Collect activations for SAE training
hook_point = "blocks.0.hook_resid_post"
activations_list = []
with torch.no_grad():
for traj in trajectories:
states = torch.from_numpy(traj["observations"]).float().unsqueeze(0)
actions = torch.from_numpy(traj["actions"]).long()
actions_one_hot = torch.nn.functional.one_hot(actions, num_classes=3).float().unsqueeze(0)
returns = torch.from_numpy(traj["rewards"]).float().unsqueeze(0).unsqueeze(-1)
_, cache = base_model(states, actions_one_hot, returns, return_cache=True)
activations_list.append(cache[hook_point].squeeze(0)) # Shape [seq_len * 3, d_model]
activations = torch.cat(activations_list, dim=0) # [N, d_model]
# 4. Setup and train a TopK SAE on the collected activations
# Expansion factor 2 -> d_sae = 64
sae_manager.setup_sae(hook_point, d_model=32, expansion_factor=2, architecture="topk", k=8)
sae_manager.train_on_trajectories(hook_point, activations, epochs=2, batch_size=32)
# 5. Run the Auditor
auditor = DeceptiveAlignmentAuditor(base_model, sae_manager)
watched_states = torch.tensor([[[5.0, 1.0, 5.0], [4.0, 1.0, 5.0]]], dtype=torch.float32)
unwatched_states = torch.tensor([[[5.0, 0.0, 5.0], [5.0, 0.0, 4.0]]], dtype=torch.float32)
actions = torch.zeros((1, 2, 3), dtype=torch.float32)
returns = torch.tensor([[[0.2], [0.4]]], dtype=torch.float32)
switch_idx = auditor.identify_situational_awareness_feature(
hook_point, watched_states, unwatched_states, actions, returns
)
assert 0 <= switch_idx < 64
# 6. Run Direct Feature Attribution
dfa = auditor.attribute_feature_to_heads(
hook_point, switch_idx, watched_states, actions, returns
)
assert dfa.shape == (1, 2) # n_layers=1, n_heads=2
def test_functional_attribution_mad(base_model):
"""Verifies that FunctionalAttributionMAD correctly flags anomalous/atypical execution circuits."""
mad = FunctionalAttributionMAD(base_model, threshold=0.6)
# Establish inputs
baseline_states = torch.randn(1, 3, 3)
baseline_actions = torch.randn(1, 3, 3)
baseline_returns = torch.randn(1, 3, 1)
# Establish baseline profile for action 0
mad.establish_reference_profile(baseline_states, baseline_actions, baseline_returns, target_action_index=0)
assert mad.reference_profile is not None
assert mad.reference_profile.shape == (2,) # n_layers=1, n_heads=2 -> 2 heads total
# Test identical run (similarity should be very high, close to 1.0)
is_anomaly, similarity, profile = mad.detect_circuit_anomaly(
baseline_states, baseline_actions, baseline_returns, target_action_index=0
)
assert not is_anomaly
assert pytest.approx(similarity, abs=1e-5) == 1.0
# Test an anomalous run with different inputs / targets (producing different activations/attributions)
anomalous_states = torch.randn(1, 3, 3) + 5.0
is_anomaly_anom, similarity_anom, profile_anom = mad.detect_circuit_anomaly(
anomalous_states, baseline_actions, baseline_returns, target_action_index=1
)
# An anomaly may or may not be flagged depending on weights, but we can verify calculation correctness
assert similarity_anom <= 1.0
assert profile_anom.shape == (2,)
if __name__ == "__main__":
pytest.main([__file__])