Spaces:
Running
Running
Commit ·
0346604
1
Parent(s): 4d7c719
feat: implement SAE manager for latent decomposition and steering library for contrastive activation addition
Browse files- README.md +74 -29
- src/__init__.py +0 -0
- src/dashboard/__init__.py +0 -0
- src/interpretability/__init__.py +0 -0
- src/interpretability/sae_manager.py +155 -0
- src/interpretability/steering.py +59 -24
- src/models/__init__.py +0 -0
- src/utils/__init__.py +0 -0
- tests/test_sae_and_steering.py +75 -0
README.md
CHANGED
|
@@ -1,44 +1,89 @@
|
|
| 1 |
-
# DT-Circuits
|
| 2 |
|
| 3 |
-
|
| 4 |
|
| 5 |
-
|
| 6 |
-
- **Data**: PPO Trajectory Harvester for high-quality teacher data.
|
| 7 |
-
- **Model**: `HookedDT` - A custom Decision Transformer wrapped in `TransformerLens` for full activation visibility.
|
| 8 |
-
- **Interpretability**: Tools for Direct Logit Attribution (DLA), Activation Patching, and Induction Head detection.
|
| 9 |
-
- **Dashboard**: Streamlit-based UI for real-time causal interventions.
|
| 10 |
|
| 11 |
-
##
|
| 12 |
|
| 13 |
-
### 1.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
```bash
|
| 15 |
pip install -r requirements.txt
|
| 16 |
```
|
| 17 |
|
| 18 |
-
###
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
```
|
|
|
|
|
|
|
| 22 |
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
|
|
|
| 27 |
|
| 28 |
## Testing
|
| 29 |
-
|
| 30 |
```bash
|
| 31 |
-
pytest tests/
|
| 32 |
```
|
| 33 |
|
| 34 |
-
##
|
| 35 |
-
- `src/data/harvester.py`: Collects trajectories from MiniGrid.
|
| 36 |
-
- `src/models/hooked_dt.py`: Hookable transformer implementation.
|
| 37 |
-
- `src/interpretability/`:
|
| 38 |
-
- `attribution.py`: Direct Logit Attribution logic.
|
| 39 |
-
- `patching.py`: Activation patching interface.
|
| 40 |
-
- `induction_scan.py`: Automated circuit discovery.
|
| 41 |
-
|
| 42 |
-
## License
|
| 43 |
|
| 44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# DT-Circuits: Mechanistic Interpretability for Decision Transformers
|
| 2 |
|
| 3 |
+
DT-Circuits is a research-grade framework designed for the rigorous mechanistic interpretability of Decision Transformers (DT). By leveraging the TransformerLens paradigm, this platform enables researchers to map internal neural circuits, decompose activations using Sparse Autoencoders, and perform causal interventions on agent decision-making.
|
| 4 |
|
| 5 |
+
The primary objective is to move beyond behavioral observation and saliency maps toward a quantitative understanding of how Reward-to-Go, State, and Action tokens are processed within the residual stream.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
+
## Core Capabilities
|
| 8 |
|
| 9 |
+
### 1. Circuit Foundation
|
| 10 |
+
- **Hooked-DT Architecture**: A custom Decision Transformer implementation wrapped in TransformerLens, providing full access to internal activations, weights, and the residual stream.
|
| 11 |
+
- **Direct Logit Attribution (DLA)**: Quantitative mapping of individual attention heads and MLP layers to the final action logits.
|
| 12 |
+
- **Induction Head Discovery**: Automated scanning tools to identify heads responsible for temporal pattern recognition and "memory" in RL tasks.
|
| 13 |
+
|
| 14 |
+
### 2. Causal Interventions
|
| 15 |
+
- **Activation Patching**: Surgical replacement of activations between "clean" and "corrupted" runs to identify bottleneck features and causal paths.
|
| 16 |
+
- **Contrastive Activation Addition (CAA)**: Generation of steering vectors by calculating the mean difference between positive and negative activation sets.
|
| 17 |
+
- **Steering Library**: A persistent library of pre-calculated vectors (e.g., success_vector, exploration_vector) that can be injected at inference time to manipulate agent behavior without retraining.
|
| 18 |
+
|
| 19 |
+
### 3. Deep Discovery & Safety
|
| 20 |
+
- **Sparse Autoencoder (SAE) Integration**: Tools to train and deploy SAEs on the residual stream, decomposing polysemantic neurons into monosemantic latents.
|
| 21 |
+
- **Mechanistic Anomaly Detection**: Utilizing SAE reconstruction error as a high-fidelity proxy for detecting out-of-distribution (OOD) states.
|
| 22 |
+
|
| 23 |
+
## Technical Architecture
|
| 24 |
+
|
| 25 |
+
The platform is divided into four primary layers:
|
| 26 |
+
- **Data Layer**: PPO Trajectory Harvester for generating high-quality expert demonstrations in Gymnasium environments (e.g., MiniGrid).
|
| 27 |
+
- **Model Layer**: The HookedDT implementation which maintains compatibility with standard DT architectures while adding hook-based visibility.
|
| 28 |
+
- **Interpretability Layer**: A suite of modules for attribution, patching, SAE management, and steering.
|
| 29 |
+
- **Visualization Layer**: A Streamlit-based dashboard for real-time activation monitoring and interactive steering.
|
| 30 |
+
|
| 31 |
+
## Getting Started
|
| 32 |
+
|
| 33 |
+
### Prerequisites
|
| 34 |
+
- Python 3.9+
|
| 35 |
+
- PyTorch 2.x
|
| 36 |
+
- TransformerLens
|
| 37 |
+
- SAE-Lens
|
| 38 |
+
|
| 39 |
+
### Installation
|
| 40 |
```bash
|
| 41 |
pip install -r requirements.txt
|
| 42 |
```
|
| 43 |
|
| 44 |
+
### Basic Workflow
|
| 45 |
+
1. **Generate Trajectories**:
|
| 46 |
+
Use the harvester to collect teacher data for model training or SAE feature extraction.
|
| 47 |
+
```bash
|
| 48 |
+
python scripts/train_dt.py
|
| 49 |
+
```
|
| 50 |
|
| 51 |
+
2. **Run Interpretability Dashboard**:
|
| 52 |
+
Launch the interactive UI to perform real-time patching and steering interventions.
|
| 53 |
+
```bash
|
| 54 |
+
streamlit run src/dashboard/app.py
|
| 55 |
+
```
|
| 56 |
|
| 57 |
## Testing
|
| 58 |
+
|
| 59 |
```bash
|
| 60 |
+
PYTHONPATH=. pytest tests/
|
| 61 |
```
|
| 62 |
|
| 63 |
+
## Project Structure
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
|
| 65 |
+
```text
|
| 66 |
+
DT-Circuits/
|
| 67 |
+
├── scripts/ # Training and harvesting entry points
|
| 68 |
+
│ ├── train_dt.py # Decision Transformer training pipeline
|
| 69 |
+
│ └── train_sae.py # Sparse Autoencoder (SAE) training script
|
| 70 |
+
├── src/
|
| 71 |
+
│ ├── dashboard/
|
| 72 |
+
│ │ └── app.py # Streamlit-based visualization UI
|
| 73 |
+
│ ├── data/
|
| 74 |
+
│ │ └── harvester.py # PPO-based expert trajectory harvester
|
| 75 |
+
│ ├── interpretability/
|
| 76 |
+
│ │ ├── attribution.py # Direct Logit Attribution (DLA)
|
| 77 |
+
│ │ ├── induction_scan.py # Induction head detection logic
|
| 78 |
+
��� │ ├── patching.py # Causal activation patching tools
|
| 79 |
+
│ │ ├── sae_manager.py # SAE deployment and anomaly detection
|
| 80 |
+
│ │ └── steering.py # Steering vector generation and injection
|
| 81 |
+
│ ├── models/
|
| 82 |
+
│ │ └── hooked_dt.py # TransformerLens-wrapped Decision Transformer
|
| 83 |
+
│ └── utils/
|
| 84 |
+
├── tests/ # Unit and integration test suite
|
| 85 |
+
│ ├── test_components.py
|
| 86 |
+
│ └── test_sae_and_steering.py
|
| 87 |
+
├── config.yaml # Experiment and environment configuration
|
| 88 |
+
└── requirements.txt # Environment dependencies
|
| 89 |
+
```
|
src/__init__.py
ADDED
|
File without changes
|
src/dashboard/__init__.py
ADDED
|
File without changes
|
src/interpretability/__init__.py
ADDED
|
File without changes
|
src/interpretability/sae_manager.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import os
|
| 4 |
+
from typing import Dict, List, Optional, Tuple, Union
|
| 5 |
+
from sae_lens import StandardSAE, StandardSAEConfig
|
| 6 |
+
from jaxtyping import Float
|
| 7 |
+
|
| 8 |
+
class SAEManager:
|
| 9 |
+
"""
|
| 10 |
+
Research-grade manager for Sparse Autoencoders (SAEs) integrated with Decision Transformers.
|
| 11 |
+
Handles training, decomposition into monosemantic latents, and mechanistic anomaly detection.
|
| 12 |
+
"""
|
| 13 |
+
def __init__(self, model: nn.Module, sae_dir: str = "artifacts/saes"):
|
| 14 |
+
self.model = model
|
| 15 |
+
self.sae_dir = sae_dir
|
| 16 |
+
self.saes: Dict[str, StandardSAE] = {}
|
| 17 |
+
os.makedirs(sae_dir, exist_ok=True)
|
| 18 |
+
|
| 19 |
+
def setup_sae(
|
| 20 |
+
self,
|
| 21 |
+
hook_point: str,
|
| 22 |
+
d_model: int,
|
| 23 |
+
expansion_factor: int = 8,
|
| 24 |
+
) -> StandardSAE:
|
| 25 |
+
"""
|
| 26 |
+
Initializes an SAE for a specific hook point in the transformer.
|
| 27 |
+
"""
|
| 28 |
+
cfg = StandardSAEConfig(
|
| 29 |
+
d_in=d_model,
|
| 30 |
+
d_sae=d_model * expansion_factor,
|
| 31 |
+
device=str(next(self.model.parameters()).device)
|
| 32 |
+
)
|
| 33 |
+
sae = StandardSAE(cfg)
|
| 34 |
+
self.saes[hook_point] = sae
|
| 35 |
+
return sae
|
| 36 |
+
|
| 37 |
+
def train_on_trajectories(
|
| 38 |
+
self,
|
| 39 |
+
hook_point: str,
|
| 40 |
+
activations: Float[torch.Tensor, "n_samples d_model"],
|
| 41 |
+
l1_coefficient: float = 0.0001,
|
| 42 |
+
batch_size: int = 1024,
|
| 43 |
+
epochs: int = 10,
|
| 44 |
+
):
|
| 45 |
+
"""
|
| 46 |
+
Trains the SAE on collected trajectory activations.
|
| 47 |
+
"""
|
| 48 |
+
if hook_point not in self.saes:
|
| 49 |
+
self.setup_sae(hook_point, activations.shape[-1])
|
| 50 |
+
|
| 51 |
+
sae = self.saes[hook_point]
|
| 52 |
+
optimizer = torch.optim.Adam(sae.parameters(), lr=0.0004)
|
| 53 |
+
|
| 54 |
+
sae.train()
|
| 55 |
+
n_samples = activations.shape[0]
|
| 56 |
+
|
| 57 |
+
for epoch in range(epochs):
|
| 58 |
+
permutation = torch.randperm(n_samples)
|
| 59 |
+
epoch_loss = 0
|
| 60 |
+
for i in range(0, n_samples, batch_size):
|
| 61 |
+
indices = permutation[i:i+batch_size]
|
| 62 |
+
batch_acts = activations[indices].to(sae.device)
|
| 63 |
+
|
| 64 |
+
optimizer.zero_grad()
|
| 65 |
+
|
| 66 |
+
# Manual forward pass for training
|
| 67 |
+
feature_acts = sae.encode(batch_acts)
|
| 68 |
+
sae_out = sae.decode(feature_acts)
|
| 69 |
+
|
| 70 |
+
mse_loss = torch.nn.functional.mse_loss(sae_out, batch_acts)
|
| 71 |
+
l1_loss = l1_coefficient * feature_acts.abs().sum()
|
| 72 |
+
loss = mse_loss + l1_loss
|
| 73 |
+
|
| 74 |
+
loss.backward()
|
| 75 |
+
optimizer.step()
|
| 76 |
+
epoch_loss += loss.item()
|
| 77 |
+
|
| 78 |
+
print(f"Epoch {epoch+1}/{epochs} - Loss: {epoch_loss / (n_samples / batch_size):.4f}")
|
| 79 |
+
|
| 80 |
+
def get_feature_activations(
|
| 81 |
+
self,
|
| 82 |
+
hook_point: str,
|
| 83 |
+
activations: Float[torch.Tensor, "... d_model"]
|
| 84 |
+
) -> Float[torch.Tensor, "... d_sae"]:
|
| 85 |
+
"""
|
| 86 |
+
Decomposes activations into monosemantic features.
|
| 87 |
+
"""
|
| 88 |
+
if hook_point not in self.saes:
|
| 89 |
+
raise ValueError(f"SAE for {hook_point} not found. Train or load it first.")
|
| 90 |
+
|
| 91 |
+
sae = self.saes[hook_point]
|
| 92 |
+
sae.eval()
|
| 93 |
+
with torch.no_grad():
|
| 94 |
+
feature_acts = sae.encode(activations.to(sae.device))
|
| 95 |
+
return feature_acts
|
| 96 |
+
|
| 97 |
+
def reconstruct(
|
| 98 |
+
self,
|
| 99 |
+
hook_point: str,
|
| 100 |
+
activations: Float[torch.Tensor, "... d_model"]
|
| 101 |
+
) -> Float[torch.Tensor, "... d_model"]:
|
| 102 |
+
"""
|
| 103 |
+
Reconstructs the original activations using the SAE.
|
| 104 |
+
"""
|
| 105 |
+
if hook_point not in self.saes:
|
| 106 |
+
raise ValueError(f"SAE for {hook_point} not found.")
|
| 107 |
+
|
| 108 |
+
sae = self.saes[hook_point]
|
| 109 |
+
sae.eval()
|
| 110 |
+
with torch.no_grad():
|
| 111 |
+
feature_acts = sae.encode(activations.to(sae.device))
|
| 112 |
+
sae_out = sae.decode(feature_acts)
|
| 113 |
+
return sae_out
|
| 114 |
+
|
| 115 |
+
def compute_anomaly_score(
|
| 116 |
+
self,
|
| 117 |
+
hook_point: str,
|
| 118 |
+
activations: Float[torch.Tensor, "... d_model"]
|
| 119 |
+
) -> Float[torch.Tensor, "..."]:
|
| 120 |
+
"""
|
| 121 |
+
Calculates reconstruction error as a proxy for mechanistic anomaly detection.
|
| 122 |
+
Formula: ||x - x_hat|| / ||x||
|
| 123 |
+
"""
|
| 124 |
+
if hook_point not in self.saes:
|
| 125 |
+
raise ValueError(f"SAE for {hook_point} not found.")
|
| 126 |
+
|
| 127 |
+
sae = self.saes[hook_point]
|
| 128 |
+
sae.eval()
|
| 129 |
+
with torch.no_grad():
|
| 130 |
+
x = activations.to(sae.device)
|
| 131 |
+
feature_acts = sae.encode(x)
|
| 132 |
+
x_hat = sae.decode(feature_acts)
|
| 133 |
+
|
| 134 |
+
error = torch.norm(x - x_hat, dim=-1) / (torch.norm(x, dim=-1) + 1e-8)
|
| 135 |
+
return error
|
| 136 |
+
|
| 137 |
+
def save_all_saes(self):
|
| 138 |
+
for hook, sae in self.saes.items():
|
| 139 |
+
path = os.path.join(self.sae_dir, f"{hook.replace('.', '_')}_sae.pt")
|
| 140 |
+
torch.save({
|
| 141 |
+
'state_dict': sae.state_dict(),
|
| 142 |
+
'cfg': sae.cfg
|
| 143 |
+
}, path)
|
| 144 |
+
print(f"Saved SAE for {hook} to {path}")
|
| 145 |
+
|
| 146 |
+
def load_sae(self, hook_point: str):
|
| 147 |
+
path = os.path.join(self.sae_dir, f"{hook_point.replace('.', '_')}_sae.pt")
|
| 148 |
+
if not os.path.exists(path):
|
| 149 |
+
raise FileNotFoundError(f"No saved SAE found at {path}")
|
| 150 |
+
|
| 151 |
+
checkpoint = torch.load(path, map_location=str(next(self.model.parameters()).device))
|
| 152 |
+
sae = StandardSAE(checkpoint['cfg'])
|
| 153 |
+
sae.load_state_dict(checkpoint['state_dict'])
|
| 154 |
+
self.saes[hook_point] = sae
|
| 155 |
+
return sae
|
src/interpretability/steering.py
CHANGED
|
@@ -1,43 +1,78 @@
|
|
| 1 |
import torch
|
| 2 |
import torch.nn as nn
|
| 3 |
-
from typing import Optional
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
class RTGSteerer:
|
| 6 |
"""
|
| 7 |
-
Enables 'Behavioral Steering' by manipulating Reward-to-Go (RTG) tokens.
|
|
|
|
| 8 |
"""
|
| 9 |
-
def __init__(self, model):
|
| 10 |
self.model = model
|
|
|
|
| 11 |
|
| 12 |
-
def
|
| 13 |
self,
|
| 14 |
-
states: torch.Tensor,
|
| 15 |
-
actions: torch.Tensor,
|
| 16 |
base_rtg: torch.Tensor,
|
| 17 |
-
|
|
|
|
| 18 |
alpha: float = 1.0
|
| 19 |
-
):
|
| 20 |
"""
|
| 21 |
Adds a steering vector to the RTG embeddings.
|
| 22 |
-
RTG_new = RTG_base + alpha * steering_vector
|
| 23 |
"""
|
| 24 |
-
|
|
|
|
| 25 |
with torch.no_grad():
|
| 26 |
rtg_emb = self.model.embed_return(base_rtg)
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
-
def
|
| 37 |
"""
|
| 38 |
-
|
| 39 |
-
Vector = Mean(High Reward Residual) - Mean(Low Reward Residual)
|
| 40 |
"""
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import torch
|
| 2 |
import torch.nn as nn
|
| 3 |
+
from typing import Dict, List, Optional
|
| 4 |
+
class SteeringLibrary:
|
| 5 |
+
"""
|
| 6 |
+
A persistent library of pre-calculated steering vectors (CAA).
|
| 7 |
+
Includes vectors for exploration, safety, and goal-directedness.
|
| 8 |
+
"""
|
| 9 |
+
def __init__(self, d_model: int):
|
| 10 |
+
self.d_model = d_model
|
| 11 |
+
self.vectors: Dict[str, torch.Tensor] = {}
|
| 12 |
+
|
| 13 |
+
def add_vector(self, name: str, vector: torch.Tensor):
|
| 14 |
+
if vector.shape[-1] != self.d_model:
|
| 15 |
+
raise ValueError(f"Vector dimension {vector.shape[-1]} does not match d_model {self.d_model}")
|
| 16 |
+
self.vectors[name] = vector
|
| 17 |
+
|
| 18 |
+
def get_vector(self, name: str) -> torch.Tensor:
|
| 19 |
+
if name not in self.vectors:
|
| 20 |
+
raise KeyError(f"Vector '{name}' not found in library.")
|
| 21 |
+
return self.vectors[name]
|
| 22 |
+
|
| 23 |
+
def list_vectors(self) -> List[str]:
|
| 24 |
+
return list(self.vectors.keys())
|
| 25 |
|
| 26 |
class RTGSteerer:
|
| 27 |
"""
|
| 28 |
+
Enables 'Behavioral Steering' by manipulating Reward-to-Go (RTG) tokens or internal activations.
|
| 29 |
+
Supports Contrastive Activation Addition (CAA).
|
| 30 |
"""
|
| 31 |
+
def __init__(self, model, library: Optional[SteeringLibrary] = None):
|
| 32 |
self.model = model
|
| 33 |
+
self.library = library or SteeringLibrary(model.cfg.d_model)
|
| 34 |
|
| 35 |
+
def steer_rtg(
|
| 36 |
self,
|
|
|
|
|
|
|
| 37 |
base_rtg: torch.Tensor,
|
| 38 |
+
vector_name: Optional[str] = None,
|
| 39 |
+
custom_vector: Optional[torch.Tensor] = None,
|
| 40 |
alpha: float = 1.0
|
| 41 |
+
) -> torch.Tensor:
|
| 42 |
"""
|
| 43 |
Adds a steering vector to the RTG embeddings.
|
|
|
|
| 44 |
"""
|
| 45 |
+
vector = custom_vector if custom_vector is not None else self.library.get_vector(vector_name)
|
| 46 |
+
|
| 47 |
with torch.no_grad():
|
| 48 |
rtg_emb = self.model.embed_return(base_rtg)
|
| 49 |
+
return rtg_emb + alpha * vector
|
| 50 |
+
|
| 51 |
+
def generate_caa_vector(
|
| 52 |
+
self,
|
| 53 |
+
positive_activations: torch.Tensor,
|
| 54 |
+
negative_activations: torch.Tensor,
|
| 55 |
+
method: str = "mean_diff"
|
| 56 |
+
) -> torch.Tensor:
|
| 57 |
+
"""
|
| 58 |
+
Generates a steering vector using Contrastive Activation Addition.
|
| 59 |
+
'mean_diff' calculates the difference between the means of positive and negative sets.
|
| 60 |
+
"""
|
| 61 |
+
if method == "mean_diff":
|
| 62 |
+
pos_mean = positive_activations.mean(dim=0)
|
| 63 |
+
neg_mean = negative_activations.mean(dim=0)
|
| 64 |
+
return pos_mean - neg_mean
|
| 65 |
+
else:
|
| 66 |
+
raise NotImplementedError(f"Method {method} not implemented.")
|
| 67 |
|
| 68 |
+
def apply_steering_hook(self, hook_point: str, vector_name: str, alpha: float = 1.0):
|
| 69 |
"""
|
| 70 |
+
Returns a HookedTransformer compatible hook function that applies steering.
|
|
|
|
| 71 |
"""
|
| 72 |
+
vector = self.library.get_vector(vector_name)
|
| 73 |
+
|
| 74 |
+
def steering_hook(activations, hook):
|
| 75 |
+
# activations: [batch, pos, d_model]
|
| 76 |
+
return activations + alpha * vector
|
| 77 |
+
|
| 78 |
+
return steering_hook
|
src/models/__init__.py
ADDED
|
File without changes
|
src/utils/__init__.py
ADDED
|
File without changes
|
tests/test_sae_and_steering.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
import torch
|
| 3 |
+
from src.models.hooked_dt import HookedDT
|
| 4 |
+
from src.interpretability.sae_manager import SAEManager
|
| 5 |
+
from src.interpretability.steering import RTGSteerer, SteeringLibrary
|
| 6 |
+
|
| 7 |
+
@pytest.fixture
|
| 8 |
+
def model():
|
| 9 |
+
return HookedDT.from_config(state_dim=10, action_dim=5, n_layers=1, n_heads=2, d_model=32)
|
| 10 |
+
|
| 11 |
+
@pytest.fixture
|
| 12 |
+
def sae_manager(model):
|
| 13 |
+
return SAEManager(model)
|
| 14 |
+
|
| 15 |
+
def test_sae_lifecycle(sae_manager):
|
| 16 |
+
hook_point = "blocks.0.hook_resid_post"
|
| 17 |
+
d_model = 32
|
| 18 |
+
|
| 19 |
+
# 1. Setup SAE
|
| 20 |
+
sae = sae_manager.setup_sae(hook_point, d_model, expansion_factor=2)
|
| 21 |
+
assert hook_point in sae_manager.saes
|
| 22 |
+
assert sae.cfg.d_sae == 64
|
| 23 |
+
|
| 24 |
+
# 2. Mock activations
|
| 25 |
+
activations = torch.randn(100, d_model)
|
| 26 |
+
|
| 27 |
+
# 3. Test reconstruction shape
|
| 28 |
+
reconstructed = sae_manager.reconstruct(hook_point, activations)
|
| 29 |
+
assert reconstructed.shape == activations.shape
|
| 30 |
+
|
| 31 |
+
# 4. Test feature activations shape
|
| 32 |
+
latents = sae_manager.get_feature_activations(hook_point, activations)
|
| 33 |
+
assert latents.shape == (100, 64)
|
| 34 |
+
|
| 35 |
+
# 5. Test anomaly score
|
| 36 |
+
scores = sae_manager.compute_anomaly_score(hook_point, activations)
|
| 37 |
+
assert scores.shape == (100,)
|
| 38 |
+
assert (scores >= 0).all()
|
| 39 |
+
|
| 40 |
+
def test_steering_library():
|
| 41 |
+
lib = SteeringLibrary(d_model=32)
|
| 42 |
+
vec = torch.randn(32)
|
| 43 |
+
lib.add_vector("test_vec", vec)
|
| 44 |
+
|
| 45 |
+
assert "test_vec" in lib.list_vectors()
|
| 46 |
+
assert torch.equal(lib.get_vector("test_vec"), vec)
|
| 47 |
+
|
| 48 |
+
with pytest.raises(ValueError):
|
| 49 |
+
lib.add_vector("wrong_dim", torch.randn(16))
|
| 50 |
+
|
| 51 |
+
def test_caa_generation(model):
|
| 52 |
+
steerer = RTGSteerer(model)
|
| 53 |
+
pos_acts = torch.randn(10, 32) + 1.0
|
| 54 |
+
neg_acts = torch.randn(10, 32) - 1.0
|
| 55 |
+
|
| 56 |
+
vector = steerer.generate_caa_vector(pos_acts, neg_acts)
|
| 57 |
+
assert vector.shape == (32,)
|
| 58 |
+
# Mean difference should be around 2.0 for each dimension
|
| 59 |
+
assert vector.mean() > 0.0
|
| 60 |
+
|
| 61 |
+
def test_steering_hook(model):
|
| 62 |
+
lib = SteeringLibrary(d_model=32)
|
| 63 |
+
vec = torch.ones(32)
|
| 64 |
+
lib.add_vector("boost", vec)
|
| 65 |
+
|
| 66 |
+
steerer = RTGSteerer(model, library=lib)
|
| 67 |
+
hook = steerer.apply_steering_hook("blocks.0.hook_resid_post", "boost", alpha=2.0)
|
| 68 |
+
|
| 69 |
+
input_acts = torch.zeros(1, 5, 32)
|
| 70 |
+
output_acts = hook(input_acts, None)
|
| 71 |
+
|
| 72 |
+
assert torch.allclose(output_acts, torch.ones(1, 5, 32) * 2.0)
|
| 73 |
+
|
| 74 |
+
if __name__ == "__main__":
|
| 75 |
+
pytest.main([__file__])
|