File size: 3,907 Bytes
b7ddfc6
 
e2614dc
b7ddfc6
 
 
 
 
 
 
 
 
e2614dc
b7ddfc6
 
 
 
 
 
 
 
 
 
 
e2614dc
 
b7ddfc6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e2614dc
b7ddfc6
 
 
 
 
 
 
 
 
e2614dc
b7ddfc6
 
e2614dc
b7ddfc6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e2614dc
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import sys
from pathlib import Path
import torch
from sae_lens import TopKSAEConfig, TopKSAE

# Add project root to path for absolute imports
root_path = str(Path(__file__).resolve().parent.parent)
if root_path not in sys.path:
    sys.path.append(root_path)

import random
import numpy as np
from src.models.hooked_dt import HookedDT
from src.interpretability.sae_manager import SAEManager
from src.config import cfg

def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def train_sae():
    # 0. Set seed for reproducibility
    set_seed(cfg.train.seed)

    # 1. Load Trajectories to get dimensions
    traj_path = "data/trajectories.pt"
    if not Path(traj_path).exists():
        print(f"Error: {traj_path} not found. Please run scripts/train_dt.py first.")
        return
    
    trajectories = torch.load(traj_path, weights_only=False)
    print(f"Loaded {len(trajectories)} trajectories.")

    # 2. Initialize Model
    state_dim = trajectories[0]["observations"].shape[1]
    action_dim = cfg.model.action_dim
    device = "cuda" if torch.cuda.is_available() else "cpu"
    
    model = HookedDT.from_config(
        state_dim=state_dim, 
        action_dim=action_dim,
        n_layers=cfg.model.n_layers,
        n_heads=cfg.model.n_heads,
        d_model=cfg.model.d_model
    )
    model.to(device)
    
    # Check for trained DT checkpoint
    checkpoint_path = "models/mini_dt.pt"
    if Path(checkpoint_path).exists():
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        print(f"Loaded DT weights from {checkpoint_path}")
    else:
        print(f"Warning: {checkpoint_path} not found. Training SAE on random weights.")

    # 3. & 4. Train SAEs for ALL layers
    manager = SAEManager(model, sae_dir="artifacts/saes")
    
    for layer in range(model.cfg.n_layers):
        hook_point = f"blocks.{layer}.hook_resid_post"
        all_activations = []
        
        print(f"\n--- Processing Layer {layer} ({hook_point}) ---")
        
        # Extract Activations
        model.eval()
        print(f"Extracting activations...")

        # Number of trajectories from config
        num_trajs_to_use = min(len(trajectories), cfg.sae.num_episodes)
        
        with torch.no_grad():
            for traj in trajectories[:num_trajs_to_use]:
                states = torch.from_numpy(traj["observations"]).float().to(device).unsqueeze(0)
                actions = torch.from_numpy(traj["actions"]).long().to(device)
                actions_one_hot = torch.nn.functional.one_hot(actions, num_classes=action_dim).float().unsqueeze(0)
                returns = torch.from_numpy(traj["rewards"]).float().to(device).unsqueeze(0).unsqueeze(-1)

                _, cache = model(states, actions_one_hot, returns, return_cache=True)
                all_activations.append(cache[hook_point].squeeze(0).cpu())
        
        activations = torch.cat(all_activations, dim=0)
        print(f"Collected {activations.shape[0]} activation vectors.")

        # Setup and Train
        print(f"Starting TopK SAE training...")
        manager.setup_sae(
            hook_point=hook_point,
            d_model=cfg.model.d_model,
            architecture="topk",
            k=cfg.sae.k
        )
        
        manager.train_on_trajectories(
            hook_point=hook_point,
            activations=activations,
            epochs=cfg.sae.epochs,
            batch_size=cfg.sae.batch_size
        )

    # Save all SAEs once training is complete for all layers
    manager.save_all_saes()

    print(f"\nSAE Training Complete for all {model.cfg.n_layers} layers. Results saved to artifacts/saes/")

if __name__ == "__main__":
    train_sae()