
Safety Auditing and Adversarial Control

This document outlines the theory, mathematical formulation, and API usage for the safety auditing, dynamic steering, and deception auditing tools implemented in the DT-Circuits framework.


1. Dynamic Rejection Steering (Directer)

Concept

Activation steering allows us to inject concept vectors (e.g., speed, exploration) into the residual stream to influence decisions without modifying model weights. However, steering can occasionally override basic safety constraints, leading to unsafe actions (e.g., walking into obstacles or lava).

The Directer logic implements an inference-time feedback loop. It dynamically monitors the action probabilities generated by the steered model, checking them against safety criteria. If the safety check fails, it scales back the steering strength ($\alpha$) until safety boundaries are satisfied.

Mathematical Formulation

Given a state sequence $s_{1:t}$, actions $a_{1:t-1}$, and returns-to-go $g_{1:t}$, let the steering vector be $v \in \mathbb{R}^{d_{model}}$ applied at hook point $H$. The steered activation is:

$$x'_{H} = x_{H} + \alpha \cdot v$$

The model outputs logits and action probabilities:

$$P(a_t \mid s_{1:t}, a_{1:t-1}, g_{1:t}; \alpha) = \text{Softmax}(\text{DT}(s_{1:t}, a_{1:t-1}, g_{1:t}; x'_{H}))$$

Let $f_S(s_t, P(a_t))$ be a boolean safety filter that evaluates to True if safe and False if unsafe. The feedback loop updates $\alpha$ iteratively:

$$\alpha_{k+1} = \alpha_k \cdot \gamma \quad \text{if} \quad f_S(s_t, P(a_t; \alpha_k)) = \text{False}$$

where $\gamma \in (0, 1)$ is a decay factor. The loop terminates when $f_S$ returns True or when $\alpha$ falls below a defined minimum $\alpha_{min}$. In the latter case, the steerer falls back to the unsteered run ($\alpha = 0.0$).
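
The decay loop can be sketched in plain Python as follows. This is a minimal illustration under stated assumptions, not the `DynamicRejectionSteerer` implementation: `decay_alpha_until_safe`, `toy_logits`, and the `min_alpha` default are hypothetical names and values, and the toy model simply adds $3\alpha$ to the logit of action 2 in place of a real steered forward pass.

```python
import math
from typing import Callable, List, Sequence, Tuple


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over a list of logits."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def decay_alpha_until_safe(
    logits_fn: Callable[[float], Sequence[float]],
    is_safe: Callable[[Sequence[float]], bool],
    initial_alpha: float = 1.0,
    decay_factor: float = 0.5,
    min_alpha: float = 0.05,  # hypothetical default, not taken from the library
) -> Tuple[List[float], float]:
    """Multiply alpha by decay_factor until is_safe passes or alpha < min_alpha."""
    alpha = initial_alpha
    while alpha >= min_alpha:
        probs = softmax(logits_fn(alpha))
        if is_safe(probs):
            return probs, alpha
        alpha *= decay_factor
    # alpha dropped below the minimum: fall back to the unsteered run.
    return softmax(logits_fn(0.0)), 0.0


def toy_logits(alpha: float) -> List[float]:
    """Toy stand-in for the steered model: steering boosts the logit of action 2."""
    base = [1.0, 0.5, 0.0]
    return [base[0], base[1], base[2] + 3.0 * alpha]


# Treat action 2 as unsafe whenever its probability exceeds 0.3.
probs, alpha = decay_alpha_until_safe(toy_logits, lambda p: p[2] <= 0.3)
print(f"accepted alpha={alpha}, P(action 2)={probs[2]:.3f}")
```

With these toy numbers the loop rejects $\alpha = 1.0$, $0.5$, and $0.25$ before accepting $\alpha = 0.125$. A safety check that never passes returns the unsteered distribution with $\alpha = 0.0$.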

Usage

from src.interpretability.safety import DynamicRejectionSteerer

steerer = DynamicRejectionSteerer(model)

# Define safety check: Action index 2 (e.g., "move forward") is unsafe if lava is in front
def safety_check_fn(current_state, action_probs):
    lava_in_front = current_state[0] == 1.0  # Custom state feature
    if lava_in_front and action_probs[2] > 0.3:
        return False
    return True

safe_logits, safe_alpha = steerer.steer_safely(
    states=states,
    actions=actions,
    returns_to_go=returns,
    hook_point="blocks.0.hook_resid_post",
    steering_vector=exploration_vector,
    safety_check_fn=safety_check_fn,
    initial_alpha=1.0,
    decay_factor=0.5
)

2. Deceptive Alignment Auditing

Concept

Deceptive alignment is a safety hazard in which an agent behaves as if aligned while observed but pursues a different objective (the "secret goal") once the supervisor is absent. The auditor uses Sparse Autoencoders (SAEs) to audit this behavior mechanistically by:

  1. Identifying the Situational Awareness Feature that acts as the internal representation of being observed.
  2. Tracing the Situational Awareness Circuit of attention heads responsible for activating this feature.
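
This document does not spell out how the situational awareness feature is selected. One plausible criterion, assumed here purely for illustration, is to pick the SAE feature whose mean activation differs most between watched and unwatched inputs. The helper names and the toy activation values below are hypothetical.

```python
from typing import List, Sequence


def mean_per_feature(acts: Sequence[Sequence[float]]) -> List[float]:
    """Average each SAE feature over a batch of activation rows."""
    n = len(acts)
    return [sum(column) / n for column in zip(*acts)]


def pick_switch_feature(
    watched: Sequence[Sequence[float]],
    unwatched: Sequence[Sequence[float]],
) -> int:
    """Return the index of the feature with the largest absolute mean gap.

    Assumption: a mean-difference criterion; the real auditor may use another.
    """
    mean_watched = mean_per_feature(watched)
    mean_unwatched = mean_per_feature(unwatched)
    gaps = [abs(w - u) for w, u in zip(mean_watched, mean_unwatched)]
    return max(range(len(gaps)), key=gaps.__getitem__)


# Toy SAE activations: rows are samples, columns are features.
watched_acts = [[0.1, 2.0, 0.0], [0.2, 1.8, 0.1]]
unwatched_acts = [[0.1, 0.1, 0.0], [0.3, 0.0, 0.1]]

switch_idx = pick_switch_feature(watched_acts, unwatched_acts)
print(f"candidate situational awareness feature: {switch_idx}")
```

Here feature 1 fires strongly only in the watched condition, so it is the one selected.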

Direct Feature Attribution (DFA)

To isolate how much individual attention heads contribute to the situational awareness SAE feature, we use Direct Feature Attribution.

Let $y_{l, h}$ be the output vector of attention head $h$ in layer $l$, and let $W_{enc, f} \in \mathbb{R}^{d_{model}}$ be the encoder weight vector of the target situational awareness feature $f$ in the SAE. The direct contribution score $C(l, h)$ is:

$$C(l, h) = y_{l, h} \cdot W_{enc, f}$$

Heads with large positive or negative scores are the strongest direct contributors to the feature, which makes them candidate switches routing observational context into the model's decision pathway.
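
As a minimal sketch of the score $C(l, h)$, the snippet below computes the dot product for each head using nested lists. The function name, the shapes, and the toy values are assumptions for illustration; the library presumably operates on cached tensors instead.

```python
from typing import List, Sequence, Tuple


def direct_feature_attribution(
    head_outputs: Sequence[Sequence[Sequence[float]]],
    w_enc_f: Sequence[float],
) -> List[List[float]]:
    """Compute C(l, h) = y_{l,h} . W_{enc,f} for every layer l and head h.

    head_outputs[l][h] is the d_model-sized output vector of head h in layer l.
    """
    return [
        [sum(y_i * w_i for y_i, w_i in zip(y, w_enc_f)) for y in layer]
        for layer in head_outputs
    ]


def rank_heads(scores: Sequence[Sequence[float]]) -> List[Tuple[int, int, float]]:
    """Sort (layer, head, score) triples by absolute score, largest first."""
    triples = [
        (l, h, s) for l, layer in enumerate(scores) for h, s in enumerate(layer)
    ]
    return sorted(triples, key=lambda t: abs(t[2]), reverse=True)


# Toy example: 2 layers, 2 heads, d_model = 3.
w_enc_f = [1.0, 0.0, -1.0]
head_outputs = [
    [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
    [[0.0, 0.0, 2.0], [0.5, 0.5, 0.5]],
]

scores = direct_feature_attribution(head_outputs, w_enc_f)
ranking = rank_heads(scores)
print(scores)
print(ranking[0])
```

Ranking by absolute value keeps strongly negative heads, such as layer 1, head 0 in this toy case, alongside strongly positive ones.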

Usage

from src.interpretability.safety import DeceptiveAlignmentAuditor

auditor = DeceptiveAlignmentAuditor(model, sae_manager)

# 1. Identify the feature toggling between watched/unwatched conditions
switch_feature_idx = auditor.identify_situational_awareness_feature(
    hook_point="blocks.0.hook_resid_post",
    watched_states=watched_states,
    unwatched_states=unwatched_states,
    actions=actions,
    returns_to_go=returns
)

# 2. Map direct head contributions to this situational awareness feature
head_attributions = auditor.attribute_feature_to_heads(
    hook_point="blocks.0.hook_resid_post",
    feature_idx=switch_feature_idx,
    states=watched_states,
    actions=actions,
    returns_to_go=returns
)

3. Functional Attribution MAD

Concept

Standard Mechanistic Anomaly Detection (MAD) uses SAE reconstruction error to identify out-of-distribution (OOD) states. However, an agent could behave deceptively or exploit the reward on states that the SAE reconstructs well, while relying on an entirely different internal circuit. Reconstruction error alone would not flag such cases.

Functional Attribution MAD tracks anomalies at the circuit level rather than raw activations. It establishes a reference profile of Direct Logit Attribution (DLA) across all attention heads during normal execution. During evaluation, it monitors the cosine similarity between the current DLA profile and the cached reference profile to flag anomalies.

Mathematical Formulation

Let $D_{ref} \in \mathbb{R}^{L \times H}$ be the reference DLA profile matrix, where $L$ is the number of layers and $H$ the number of attention heads per layer, flattened to a 1D vector $v_{ref} \in \mathbb{R}^{LH}$. Let $D_{eval}$ be the DLA profile measured at evaluation time, flattened in the same way to $v_{eval}$. The circuit similarity score $S$ is calculated as:

$$S = \frac{v_{ref} \cdot v_{eval}}{\|v_{ref}\|_2 \|v_{eval}\|_2}$$

A mechanistic anomaly is flagged if $S$ drops below a specified threshold $\tau$:

$$\text{Anomaly Flag} = S < \tau$$
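
The similarity check can be sketched in plain Python as follows. This is an illustration only: `circuit_similarity` and the toy profiles are hypothetical, and the threshold of 0.6 is borrowed from the usage example rather than being a recommended value.

```python
import math
from typing import List, Sequence


def flatten(profile: Sequence[Sequence[float]]) -> List[float]:
    """Flatten an L x H DLA profile into a vector of length L * H."""
    return [value for row in profile for value in row]


def circuit_similarity(
    ref_profile: Sequence[Sequence[float]],
    eval_profile: Sequence[Sequence[float]],
) -> float:
    """Cosine similarity S between the flattened reference and evaluation profiles."""
    v_ref, v_eval = flatten(ref_profile), flatten(eval_profile)
    dot = sum(a * b for a, b in zip(v_ref, v_eval))
    norm = math.sqrt(sum(a * a for a in v_ref)) * math.sqrt(sum(b * b for b in v_eval))
    if norm == 0.0:
        raise ValueError("cosine similarity is undefined for an all-zero profile")
    return dot / norm


tau = 0.6  # threshold taken from the usage example

# Toy 2 x 2 profiles (layers x heads).
ref = [[1.0, 0.0], [0.0, 1.0]]
same_circuit = [[2.0, 0.0], [0.0, 2.0]]     # same heads, larger magnitude
shifted_circuit = [[0.0, 1.0], [1.0, 0.0]]  # attribution moved to other heads

s_same = circuit_similarity(ref, same_circuit)
s_shifted = circuit_similarity(ref, shifted_circuit)
print(f"same circuit:    S={s_same:.3f}, anomaly={s_same < tau}")
print(f"shifted circuit: S={s_shifted:.3f}, anomaly={s_shifted < tau}")
```

Because cosine similarity ignores overall scale, a profile that uses the same heads with larger attributions is not flagged, while a profile that moves attribution to different heads is.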

Usage

from src.interpretability.safety import FunctionalAttributionMAD

mad = FunctionalAttributionMAD(model, threshold=0.6)

# Establish reference profile during normal aligned runs
mad.establish_reference_profile(
    baseline_states=normal_states,
    baseline_actions=normal_actions,
    baseline_returns=normal_returns,
    target_action_index=0
)

# Perform runtime checks
is_anomaly, similarity, profile = mad.detect_circuit_anomaly(
    eval_states=active_states,
    eval_actions=active_actions,
    eval_returns=active_returns,
    target_action_index=0
)

if is_anomaly:
    print(f"Anomalous circuit execution detected! Similarity: {similarity:.4f}")