sadhumitha-s committed on
Commit 0346604 · 1 Parent(s): 4d7c719

feat: implement SAE manager for latent decomposition and steering library for contrastive activation addition

README.md CHANGED
@@ -1,44 +1,89 @@
- # DT-Circuits

- A research-grade platform for the mechanistic interpretability of Decision Transformers.

- ## Architecture
- - **Data**: PPO Trajectory Harvester for high-quality teacher data.
- - **Model**: `HookedDT` - A custom Decision Transformer wrapped in `TransformerLens` for full activation visibility.
- - **Interpretability**: Tools for Direct Logit Attribution (DLA), Activation Patching, and Induction Head detection.
- - **Dashboard**: Streamlit-based UI for real-time causal interventions.

- ## Quick Start

- ### 1. Install Dependencies
  ```bash
  pip install -r requirements.txt
  ```

- ### 2. Collect Data & Train Mini-DT
- ```bash
- python scripts/train_dt.py
- ```

- ### 3. Run Interpretation Dashboard
- ```bash
- streamlit run src/dashboard/app.py
- ```

  ## Testing
- Run the test suite to ensure system integrity:
  ```bash
- pytest tests/
  ```

- ## Components
- - `src/data/harvester.py`: Collects trajectories from MiniGrid.
- - `src/models/hooked_dt.py`: Hookable transformer implementation.
- - `src/interpretability/`:
-   - `attribution.py`: Direct Logit Attribution logic.
-   - `patching.py`: Activation patching interface.
-   - `induction_scan.py`: Automated circuit discovery.
-
- ## License

- MIT
+ # DT-Circuits: Mechanistic Interpretability for Decision Transformers

+ DT-Circuits is a research-grade framework designed for the rigorous mechanistic interpretability of Decision Transformers (DT). By leveraging the TransformerLens paradigm, this platform enables researchers to map internal neural circuits, decompose activations using Sparse Autoencoders, and perform causal interventions on agent decision-making.

+ The primary objective is to move beyond behavioral observation and saliency maps toward a quantitative understanding of how Reward-to-Go, State, and Action tokens are processed within the residual stream.

+ ## Core Capabilities

+ ### 1. Circuit Foundation
+ - **Hooked-DT Architecture**: A custom Decision Transformer implementation wrapped in TransformerLens, providing full access to internal activations, weights, and the residual stream.
+ - **Direct Logit Attribution (DLA)**: Quantitative mapping of individual attention heads and MLP layers to the final action logits.
+ - **Induction Head Discovery**: Automated scanning tools to identify heads responsible for temporal pattern recognition and "memory" in RL tasks.
+
+ ### 2. Causal Interventions
+ - **Activation Patching**: Surgical replacement of activations between "clean" and "corrupted" runs to identify bottleneck features and causal paths.
+ - **Contrastive Activation Addition (CAA)**: Generation of steering vectors by calculating the mean difference between positive and negative activation sets.
+ - **Steering Library**: A persistent library of pre-calculated vectors (e.g., success_vector, exploration_vector) that can be injected at inference time to manipulate agent behavior without retraining.
+
+ ### 3. Deep Discovery & Safety
+ - **Sparse Autoencoder (SAE) Integration**: Tools to train and deploy SAEs on the residual stream, decomposing polysemantic neurons into monosemantic latents.
+ - **Mechanistic Anomaly Detection**: Utilizing SAE reconstruction error as a high-fidelity proxy for detecting out-of-distribution (OOD) states.
+
+ ## Technical Architecture
+
+ The platform is divided into four primary layers:
+ - **Data Layer**: PPO Trajectory Harvester for generating high-quality expert demonstrations in Gymnasium environments (e.g., MiniGrid).
+ - **Model Layer**: The HookedDT implementation which maintains compatibility with standard DT architectures while adding hook-based visibility.
+ - **Interpretability Layer**: A suite of modules for attribution, patching, SAE management, and steering.
+ - **Visualization Layer**: A Streamlit-based dashboard for real-time activation monitoring and interactive steering.
+
+ ## Getting Started
+
+ ### Prerequisites
+ - Python 3.9+
+ - PyTorch 2.x
+ - TransformerLens
+ - SAE-Lens
+
+ ### Installation
  ```bash
  pip install -r requirements.txt
  ```

+ ### Basic Workflow
+ 1. **Generate Trajectories**:
+    Use the harvester to collect teacher data for model training or SAE feature extraction.
+    ```bash
+    python scripts/train_dt.py
+    ```

+ 2. **Run Interpretability Dashboard**:
+    Launch the interactive UI to perform real-time patching and steering interventions.
+    ```bash
+    streamlit run src/dashboard/app.py
+    ```

  ## Testing
+
  ```bash
+ PYTHONPATH=. pytest tests/
  ```

+ ## Project Structure

+ ```text
+ DT-Circuits/
+ ├── scripts/                      # Training and harvesting entry points
+ │   ├── train_dt.py               # Decision Transformer training pipeline
+ │   └── train_sae.py              # Sparse Autoencoder (SAE) training script
+ ├── src/
+ │   ├── dashboard/
+ │   │   └── app.py                # Streamlit-based visualization UI
+ │   ├── data/
+ │   │   └── harvester.py          # PPO-based expert trajectory harvester
+ │   ├── interpretability/
+ │   │   ├── attribution.py        # Direct Logit Attribution (DLA)
+ │   │   ├── induction_scan.py     # Induction head detection logic
+ │   │   ├── patching.py           # Causal activation patching tools
+ │   │   ├── sae_manager.py        # SAE deployment and anomaly detection
+ │   │   └── steering.py           # Steering vector generation and injection
+ │   ├── models/
+ │   │   └── hooked_dt.py          # TransformerLens-wrapped Decision Transformer
+ │   └── utils/
+ ├── tests/                        # Unit and integration test suite
+ │   ├── test_components.py
+ │   └── test_sae_and_steering.py
+ ├── config.yaml                   # Experiment and environment configuration
+ └── requirements.txt              # Environment dependencies
+ ```
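The README's description of activation patching (replacing activations from a "corrupted" run with those from a "clean" run) can be illustrated with a dependency-free toy. This is only a sketch under stated assumptions: `layer1`, `layer2`, `run`, and `recovered_fraction` are hypothetical names invented here and are not part of `src/interpretability/patching.py`, whose actual interface is not shown in this commit.

```python
def layer1(x):
    # Hypothetical first layer with two hidden units.
    return [x[0] + x[1], x[0] - x[1]]


def layer2(h):
    # Hypothetical readout producing a single logit.
    return 2.0 * h[0] + 0.5 * h[1]


def run(x, patch=None):
    """Run the toy model; `patch` maps a hidden index to a replacement value."""
    h = layer1(x)
    if patch:
        h = [patch.get(i, v) for i, v in enumerate(h)]
    return layer2(h), h


CLEAN_INPUT = [1.0, 0.0]
CORRUPT_INPUT = [0.0, 1.0]

clean_logit, clean_h = run(CLEAN_INPUT)
corrupt_logit, _ = run(CORRUPT_INPUT)


def recovered_fraction(unit):
    """Fraction of the clean-vs-corrupt logit gap restored by patching one unit."""
    patched_logit, _ = run(CORRUPT_INPUT, patch={unit: clean_h[unit]})
    return (patched_logit - corrupt_logit) / (clean_logit - corrupt_logit)


for unit in range(len(clean_h)):
    print(f"unit {unit}: recovered fraction = {recovered_fraction(unit):.2f}")
```

In this toy, the first hidden unit takes the same value in both runs, so patching it changes nothing, while patching the second unit restores the full logit gap. That is the sense in which patching localizes a causal path.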
src/__init__.py ADDED
File without changes
src/dashboard/__init__.py ADDED
File without changes
src/interpretability/__init__.py ADDED
File without changes
src/interpretability/sae_manager.py ADDED
@@ -0,0 +1,155 @@
+ import torch
+ import torch.nn as nn
+ import os
+ from typing import Dict, List, Optional, Tuple, Union
+ from sae_lens import StandardSAE, StandardSAEConfig
+ from jaxtyping import Float
+
+ class SAEManager:
+     """
+     Research-grade manager for Sparse Autoencoders (SAEs) integrated with Decision Transformers.
+     Handles training, decomposition into monosemantic latents, and mechanistic anomaly detection.
+     """
+     def __init__(self, model: nn.Module, sae_dir: str = "artifacts/saes"):
+         self.model = model
+         self.sae_dir = sae_dir
+         self.saes: Dict[str, StandardSAE] = {}
+         os.makedirs(sae_dir, exist_ok=True)
+
+     def setup_sae(
+         self,
+         hook_point: str,
+         d_model: int,
+         expansion_factor: int = 8,
+     ) -> StandardSAE:
+         """
+         Initializes an SAE for a specific hook point in the transformer.
+         """
+         cfg = StandardSAEConfig(
+             d_in=d_model,
+             d_sae=d_model * expansion_factor,
+             device=str(next(self.model.parameters()).device)
+         )
+         sae = StandardSAE(cfg)
+         self.saes[hook_point] = sae
+         return sae
+
+     def train_on_trajectories(
+         self,
+         hook_point: str,
+         activations: Float[torch.Tensor, "n_samples d_model"],
+         l1_coefficient: float = 0.0001,
+         batch_size: int = 1024,
+         epochs: int = 10,
+     ):
+         """
+         Trains the SAE on collected trajectory activations.
+         """
+         if hook_point not in self.saes:
+             self.setup_sae(hook_point, activations.shape[-1])
+
+         sae = self.saes[hook_point]
+         optimizer = torch.optim.Adam(sae.parameters(), lr=0.0004)
+
+         sae.train()
+         n_samples = activations.shape[0]
+
+         for epoch in range(epochs):
+             permutation = torch.randperm(n_samples)
+             epoch_loss = 0
+             for i in range(0, n_samples, batch_size):
+                 indices = permutation[i:i+batch_size]
+                 batch_acts = activations[indices].to(sae.device)
+
+                 optimizer.zero_grad()
+
+                 # Manual forward pass for training
+                 feature_acts = sae.encode(batch_acts)
+                 sae_out = sae.decode(feature_acts)
+
+                 mse_loss = torch.nn.functional.mse_loss(sae_out, batch_acts)
+                 l1_loss = l1_coefficient * feature_acts.abs().sum()
+                 loss = mse_loss + l1_loss
+
+                 loss.backward()
+                 optimizer.step()
+                 epoch_loss += loss.item()
+
+             print(f"Epoch {epoch+1}/{epochs} - Loss: {epoch_loss / ((n_samples + batch_size - 1) // batch_size):.4f}")  # average over the actual number of batches
+
+     def get_feature_activations(
+         self,
+         hook_point: str,
+         activations: Float[torch.Tensor, "... d_model"]
+     ) -> Float[torch.Tensor, "... d_sae"]:
+         """
+         Decomposes activations into monosemantic features.
+         """
+         if hook_point not in self.saes:
+             raise ValueError(f"SAE for {hook_point} not found. Train or load it first.")
+
+         sae = self.saes[hook_point]
+         sae.eval()
+         with torch.no_grad():
+             feature_acts = sae.encode(activations.to(sae.device))
+             return feature_acts
+
+     def reconstruct(
+         self,
+         hook_point: str,
+         activations: Float[torch.Tensor, "... d_model"]
+     ) -> Float[torch.Tensor, "... d_model"]:
+         """
+         Reconstructs the original activations using the SAE.
+         """
+         if hook_point not in self.saes:
+             raise ValueError(f"SAE for {hook_point} not found.")
+
+         sae = self.saes[hook_point]
+         sae.eval()
+         with torch.no_grad():
+             feature_acts = sae.encode(activations.to(sae.device))
+             sae_out = sae.decode(feature_acts)
+             return sae_out
+
+     def compute_anomaly_score(
+         self,
+         hook_point: str,
+         activations: Float[torch.Tensor, "... d_model"]
+     ) -> Float[torch.Tensor, "..."]:
+         """
+         Calculates reconstruction error as a proxy for mechanistic anomaly detection.
+         Formula: ||x - x_hat|| / ||x||
+         """
+         if hook_point not in self.saes:
+             raise ValueError(f"SAE for {hook_point} not found.")
+
+         sae = self.saes[hook_point]
+         sae.eval()
+         with torch.no_grad():
+             x = activations.to(sae.device)
+             feature_acts = sae.encode(x)
+             x_hat = sae.decode(feature_acts)
+
+             error = torch.norm(x - x_hat, dim=-1) / (torch.norm(x, dim=-1) + 1e-8)
+             return error
+
+     def save_all_saes(self):
+         for hook, sae in self.saes.items():
+             path = os.path.join(self.sae_dir, f"{hook.replace('.', '_')}_sae.pt")
+             torch.save({
+                 'state_dict': sae.state_dict(),
+                 'cfg': sae.cfg
+             }, path)
+             print(f"Saved SAE for {hook} to {path}")
+
+     def load_sae(self, hook_point: str):
+         path = os.path.join(self.sae_dir, f"{hook_point.replace('.', '_')}_sae.pt")
+         if not os.path.exists(path):
+             raise FileNotFoundError(f"No saved SAE found at {path}")
+
+         checkpoint = torch.load(path, map_location=str(next(self.model.parameters()).device))
+         sae = StandardSAE(checkpoint['cfg'])
+         sae.load_state_dict(checkpoint['state_dict'])
+         self.saes[hook_point] = sae
+         return sae
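The anomaly score documented in `compute_anomaly_score` is the relative reconstruction error `||x - x_hat|| / ||x||`. As a minimal, dependency-free sketch of that arithmetic for a single vector (the function name `relative_reconstruction_error` is hypothetical and is not part of this commit; the `eps` default mirrors the `1e-8` used in the module):

```python
import math


def relative_reconstruction_error(x, x_hat, eps=1e-8):
    """Return ||x - x_hat||_2 / (||x||_2 + eps) for two equal-length vectors."""
    if len(x) != len(x_hat):
        raise ValueError("x and x_hat must have the same length")
    residual_norm = math.sqrt(sum((a - b) ** 2 for a, b in zip(x, x_hat)))
    input_norm = math.sqrt(sum(a * a for a in x))
    return residual_norm / (input_norm + eps)


x = [3.0, 4.0]
# A perfect reconstruction gives a score of exactly 0.
print(relative_reconstruction_error(x, [3.0, 4.0]))
# Reconstructing with all zeros gives a score just below 1.
print(relative_reconstruction_error(x, [0.0, 0.0]))
```

Because the score is normalized by the input norm, it is comparable across activations of different magnitude. Any threshold for flagging a state as out-of-distribution would have to be calibrated on in-distribution data, and this commit does not specify one.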
src/interpretability/steering.py CHANGED
@@ -1,43 +1,78 @@
  import torch
  import torch.nn as nn
- from typing import Optional

  class RTGSteerer:
      """
-     Enables 'Behavioral Steering' by manipulating Reward-to-Go (RTG) tokens.
      """
-     def __init__(self, model):
          self.model = model

-     def steer(
          self,
-         states: torch.Tensor,
-         actions: torch.Tensor,
          base_rtg: torch.Tensor,
-         steering_vector: torch.Tensor,
          alpha: float = 1.0
-     ):
          """
          Adds a steering vector to the RTG embeddings.
-         RTG_new = RTG_base + alpha * steering_vector
          """
-         # Embed base RTG
          with torch.no_grad():
              rtg_emb = self.model.embed_return(base_rtg)
-
-         # Apply steering
-         steered_rtg_emb = rtg_emb + alpha * steering_vector
-
-         # Hook the model to use the steered RTG
-         # This requires a slightly more complex hook in HookedDT
-         # For now, we return the steered embedding to be used in a custom forward pass
-         return steered_rtg_emb

-     def find_success_vector(self, high_reward_cache, low_reward_cache):
          """
-         Identifies the 'Success Vector' by comparing high vs low reward activations.
-         Vector = Mean(High Reward Residual) - Mean(Low Reward Residual)
          """
-         high_res = high_reward_cache["blocks.0.hook_resid_post"].mean(dim=(0, 1))
-         low_res = low_reward_cache["blocks.0.hook_resid_post"].mean(dim=(0, 1))
-         return high_res - low_res
  import torch
  import torch.nn as nn
+ from typing import Dict, List, Optional
+ class SteeringLibrary:
+     """
+     A persistent library of pre-calculated steering vectors (CAA).
+     Includes vectors for exploration, safety, and goal-directedness.
+     """
+     def __init__(self, d_model: int):
+         self.d_model = d_model
+         self.vectors: Dict[str, torch.Tensor] = {}
+
+     def add_vector(self, name: str, vector: torch.Tensor):
+         if vector.shape[-1] != self.d_model:
+             raise ValueError(f"Vector dimension {vector.shape[-1]} does not match d_model {self.d_model}")
+         self.vectors[name] = vector
+
+     def get_vector(self, name: str) -> torch.Tensor:
+         if name not in self.vectors:
+             raise KeyError(f"Vector '{name}' not found in library.")
+         return self.vectors[name]
+
+     def list_vectors(self) -> List[str]:
+         return list(self.vectors.keys())

  class RTGSteerer:
      """
+     Enables 'Behavioral Steering' by manipulating Reward-to-Go (RTG) tokens or internal activations.
+     Supports Contrastive Activation Addition (CAA).
      """
+     def __init__(self, model, library: Optional[SteeringLibrary] = None):
          self.model = model
+         self.library = library or SteeringLibrary(model.cfg.d_model)

+     def steer_rtg(
          self,
          base_rtg: torch.Tensor,
+         vector_name: Optional[str] = None,
+         custom_vector: Optional[torch.Tensor] = None,
          alpha: float = 1.0
+     ) -> torch.Tensor:
          """
          Adds a steering vector to the RTG embeddings.
          """
+         vector = custom_vector if custom_vector is not None else self.library.get_vector(vector_name)
+
          with torch.no_grad():
              rtg_emb = self.model.embed_return(base_rtg)
+         return rtg_emb + alpha * vector
+
+     def generate_caa_vector(
+         self,
+         positive_activations: torch.Tensor,
+         negative_activations: torch.Tensor,
+         method: str = "mean_diff"
+     ) -> torch.Tensor:
+         """
+         Generates a steering vector using Contrastive Activation Addition.
+         'mean_diff' calculates the difference between the means of positive and negative sets.
+         """
+         if method == "mean_diff":
+             pos_mean = positive_activations.mean(dim=0)
+             neg_mean = negative_activations.mean(dim=0)
+             return pos_mean - neg_mean
+         else:
+             raise NotImplementedError(f"Method {method} not implemented.")

+     def apply_steering_hook(self, hook_point: str, vector_name: str, alpha: float = 1.0):
          """
+         Returns a HookedTransformer-compatible hook function that applies steering.
          """
+         vector = self.library.get_vector(vector_name)
+
+         def steering_hook(activations, hook):
+             # activations: [batch, pos, d_model]; move the vector to the activations' device before adding
+             return activations + alpha * vector.to(activations.device)
+
+         return steering_hook
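The `mean_diff` method in `generate_caa_vector` and the addition performed by `steering_hook` reduce to simple arithmetic. A dependency-free sketch of both steps on plain Python lists follows; the helper names `mean_vector`, `caa_vector`, and `steer` are hypothetical illustrations and are not functions in this commit.

```python
def mean_vector(rows):
    """Column-wise mean of a list of equal-length vectors."""
    n = len(rows)
    return [sum(col) / n for col in zip(*rows)]


def caa_vector(positive, negative):
    """Steering vector = mean(positive activations) - mean(negative activations)."""
    pos_mean = mean_vector(positive)
    neg_mean = mean_vector(negative)
    return [p - q for p, q in zip(pos_mean, neg_mean)]


def steer(activation, vector, alpha=1.0):
    """Add the scaled steering vector to a single activation vector."""
    return [a + alpha * v for a, v in zip(activation, vector)]


positive = [[1.0, 2.0], [3.0, 4.0]]   # mean is [2.0, 3.0]
negative = [[0.0, 0.0], [1.0, 2.0]]   # mean is [0.5, 1.0]

vec = caa_vector(positive, negative)
steered = steer([0.0, 0.0], vec, alpha=2.0)
print(vec, steered)
```

In the module itself the same addition is broadcast over the `[batch, pos, d_model]` activation tensor, so every position receives the same offset. The `alpha` parameter controls how strongly the behavior is pushed along the contrastive direction.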
src/models/__init__.py ADDED
File without changes
src/utils/__init__.py ADDED
File without changes
tests/test_sae_and_steering.py ADDED
@@ -0,0 +1,75 @@
+ import pytest
+ import torch
+ from src.models.hooked_dt import HookedDT
+ from src.interpretability.sae_manager import SAEManager
+ from src.interpretability.steering import RTGSteerer, SteeringLibrary
+
+ @pytest.fixture
+ def model():
+     return HookedDT.from_config(state_dim=10, action_dim=5, n_layers=1, n_heads=2, d_model=32)
+
+ @pytest.fixture
+ def sae_manager(model):
+     return SAEManager(model)
+
+ def test_sae_lifecycle(sae_manager):
+     hook_point = "blocks.0.hook_resid_post"
+     d_model = 32
+
+     # 1. Setup SAE
+     sae = sae_manager.setup_sae(hook_point, d_model, expansion_factor=2)
+     assert hook_point in sae_manager.saes
+     assert sae.cfg.d_sae == 64
+
+     # 2. Mock activations
+     activations = torch.randn(100, d_model)
+
+     # 3. Test reconstruction shape
+     reconstructed = sae_manager.reconstruct(hook_point, activations)
+     assert reconstructed.shape == activations.shape
+
+     # 4. Test feature activations shape
+     latents = sae_manager.get_feature_activations(hook_point, activations)
+     assert latents.shape == (100, 64)
+
+     # 5. Test anomaly score
+     scores = sae_manager.compute_anomaly_score(hook_point, activations)
+     assert scores.shape == (100,)
+     assert (scores >= 0).all()
+
+ def test_steering_library():
+     lib = SteeringLibrary(d_model=32)
+     vec = torch.randn(32)
+     lib.add_vector("test_vec", vec)
+
+     assert "test_vec" in lib.list_vectors()
+     assert torch.equal(lib.get_vector("test_vec"), vec)
+
+     with pytest.raises(ValueError):
+         lib.add_vector("wrong_dim", torch.randn(16))
+
+ def test_caa_generation(model):
+     steerer = RTGSteerer(model)
+     pos_acts = torch.randn(10, 32) + 1.0
+     neg_acts = torch.randn(10, 32) - 1.0
+
+     vector = steerer.generate_caa_vector(pos_acts, neg_acts)
+     assert vector.shape == (32,)
+     # Mean difference should be around 2.0 for each dimension
+     assert vector.mean() > 0.0
+
+ def test_steering_hook(model):
+     lib = SteeringLibrary(d_model=32)
+     vec = torch.ones(32)
+     lib.add_vector("boost", vec)
+
+     steerer = RTGSteerer(model, library=lib)
+     hook = steerer.apply_steering_hook("blocks.0.hook_resid_post", "boost", alpha=2.0)
+
+     input_acts = torch.zeros(1, 5, 32)
+     output_acts = hook(input_acts, None)
+
+     assert torch.allclose(output_acts, torch.ones(1, 5, 32) * 2.0)
+
+ if __name__ == "__main__":
+     pytest.main([__file__])