Spaces:
Running
Running
| import pytest | |
| import torch | |
| import numpy as np | |
| from src.models.hooked_dt import HookedDT | |
| from src.interpretability.sae_manager import SAEManager | |
| from src.interpretability.safety import ( | |
| DynamicRejectionSteerer, | |
| DeceptiveAlignmentAuditor, | |
| FunctionalAttributionMAD, | |
| generate_deceptive_trajectories | |
| ) | |
| def base_model(): | |
| """Initializes a tiny HookedDT model for testing.""" | |
| return HookedDT.from_config( | |
| state_dim=3, | |
| action_dim=3, | |
| n_layers=1, | |
| n_heads=2, | |
| d_model=32, | |
| max_length=5 | |
| ) | |
| def sae_manager(base_model): | |
| """Initializes SAEManager with a temporary directory.""" | |
| return SAEManager(base_model, sae_dir="tests/artifacts/safety_saes") | |
| def test_dynamic_rejection_steerer(base_model): | |
| """Verifies that DynamicRejectionSteerer scales back steering when safety constraints are violated.""" | |
| steerer = DynamicRejectionSteerer(base_model) | |
| hook_point = "blocks.0.hook_resid_post" | |
| steering_vector = torch.randn(32) | |
| # Inputs | |
| states = torch.randn(1, 3, 3) | |
| actions = torch.randn(1, 3, 3) | |
| returns = torch.randn(1, 3, 1) | |
| # 1. Edge Case: Fully safe. The safety check always returns True. | |
| def safe_check(state, probs): | |
| return True | |
| _, alpha_safe = steerer.steer_safely( | |
| states, actions, returns, hook_point, steering_vector, safe_check, initial_alpha=1.0 | |
| ) | |
| assert alpha_safe == 1.0 | |
| # 2. Edge Case: Always unsafe. The safety check always returns False. | |
| def unsafe_check(state, probs): | |
| return False | |
| _, alpha_unsafe = steerer.steer_safely( | |
| states, actions, returns, hook_point, steering_vector, unsafe_check, initial_alpha=1.0 | |
| ) | |
| assert alpha_unsafe == 0.0 | |
| # 3. Dynamic scenario: Action index 1 is considered illegal if its probability is above 0.35 | |
| def dynamic_safety_check(state, probs): | |
| # probs is action probabilities of the last step, shape [action_dim] | |
| # If action 1 is the most probable or above 0.35, it's unsafe | |
| return probs[1].item() < 0.35 | |
| # Force action prediction to fail at high alpha and succeed at low alpha by checking scaling | |
| _, final_alpha = steerer.steer_safely( | |
| states, actions, returns, hook_point, steering_vector, dynamic_safety_check, | |
| initial_alpha=1.0, decay_factor=0.5, min_alpha=0.01, max_iterations=4 | |
| ) | |
| # The steerer should either return a valid reduced alpha or 0.0 depending on model predictions | |
| assert 0.0 <= final_alpha <= 1.0 | |
| def test_deceptive_alignment_and_audit(base_model, sae_manager): | |
| """Trains a model on deceptive trajectories, trains a TopK SAE, and audits situational awareness.""" | |
| # 1. Generate deceptive trajectory dataset | |
| trajectories = generate_deceptive_trajectories(num_episodes=20, seq_len=5) | |
| assert len(trajectories) == 20 | |
| assert trajectories[0]["observations"].shape == (5, 3) | |
| # 2. Train model to adapt to deception behavior | |
| optimizer = torch.optim.Adam(base_model.parameters(), lr=0.01) | |
| criterion = torch.nn.CrossEntropyLoss() | |
| base_model.train() | |
| for epoch in range(10): # Quick training | |
| for traj in trajectories: | |
| states = torch.from_numpy(traj["observations"]).float().unsqueeze(0) | |
| actions = torch.from_numpy(traj["actions"]).long() | |
| actions_one_hot = torch.nn.functional.one_hot(actions, num_classes=3).float().unsqueeze(0) | |
| returns = torch.from_numpy(traj["rewards"]).float().unsqueeze(0).unsqueeze(-1) | |
| optimizer.zero_grad() | |
| preds = base_model(states, actions_one_hot, returns) | |
| loss = criterion(preds.view(-1, 3), actions.view(-1)) | |
| loss.backward() | |
| optimizer.step() | |
| base_model.eval() | |
| # 3. Collect activations for SAE training | |
| hook_point = "blocks.0.hook_resid_post" | |
| activations_list = [] | |
| with torch.no_grad(): | |
| for traj in trajectories: | |
| states = torch.from_numpy(traj["observations"]).float().unsqueeze(0) | |
| actions = torch.from_numpy(traj["actions"]).long() | |
| actions_one_hot = torch.nn.functional.one_hot(actions, num_classes=3).float().unsqueeze(0) | |
| returns = torch.from_numpy(traj["rewards"]).float().unsqueeze(0).unsqueeze(-1) | |
| _, cache = base_model(states, actions_one_hot, returns, return_cache=True) | |
| activations_list.append(cache[hook_point].squeeze(0)) # Shape [seq_len * 3, d_model] | |
| activations = torch.cat(activations_list, dim=0) # [N, d_model] | |
| # 4. Setup and train a TopK SAE on the collected activations | |
| # Expansion factor 2 -> d_sae = 64 | |
| sae_manager.setup_sae(hook_point, d_model=32, expansion_factor=2, architecture="topk", k=8) | |
| sae_manager.train_on_trajectories(hook_point, activations, epochs=2, batch_size=32) | |
| # 5. Run the Auditor | |
| auditor = DeceptiveAlignmentAuditor(base_model, sae_manager) | |
| watched_states = torch.tensor([[[5.0, 1.0, 5.0], [4.0, 1.0, 5.0]]], dtype=torch.float32) | |
| unwatched_states = torch.tensor([[[5.0, 0.0, 5.0], [5.0, 0.0, 4.0]]], dtype=torch.float32) | |
| actions = torch.zeros((1, 2, 3), dtype=torch.float32) | |
| returns = torch.tensor([[[0.2], [0.4]]], dtype=torch.float32) | |
| switch_idx = auditor.identify_situational_awareness_feature( | |
| hook_point, watched_states, unwatched_states, actions, returns | |
| ) | |
| assert 0 <= switch_idx < 64 | |
| # 6. Run Direct Feature Attribution | |
| dfa = auditor.attribute_feature_to_heads( | |
| hook_point, switch_idx, watched_states, actions, returns | |
| ) | |
| assert dfa.shape == (1, 2) # n_layers=1, n_heads=2 | |
| def test_functional_attribution_mad(base_model): | |
| """Verifies that FunctionalAttributionMAD correctly flags anomalous/atypical execution circuits.""" | |
| mad = FunctionalAttributionMAD(base_model, threshold=0.6) | |
| # Establish inputs | |
| baseline_states = torch.randn(1, 3, 3) | |
| baseline_actions = torch.randn(1, 3, 3) | |
| baseline_returns = torch.randn(1, 3, 1) | |
| # Establish baseline profile for action 0 | |
| mad.establish_reference_profile(baseline_states, baseline_actions, baseline_returns, target_action_index=0) | |
| assert mad.reference_profile is not None | |
| assert mad.reference_profile.shape == (2,) # n_layers=1, n_heads=2 -> 2 heads total | |
| # Test identical run (similarity should be very high, close to 1.0) | |
| is_anomaly, similarity, profile = mad.detect_circuit_anomaly( | |
| baseline_states, baseline_actions, baseline_returns, target_action_index=0 | |
| ) | |
| assert not is_anomaly | |
| assert pytest.approx(similarity, abs=1e-5) == 1.0 | |
| # Test an anomalous run with different inputs / targets (producing different activations/attributions) | |
| anomalous_states = torch.randn(1, 3, 3) + 5.0 | |
| is_anomaly_anom, similarity_anom, profile_anom = mad.detect_circuit_anomaly( | |
| anomalous_states, baseline_actions, baseline_returns, target_action_index=1 | |
| ) | |
| # An anomaly may or may not be flagged depending on weights, but we can verify calculation correctness | |
| assert similarity_anom <= 1.0 | |
| assert profile_anom.shape == (2,) | |
| if __name__ == "__main__": | |
| pytest.main([__file__]) | |