Spaces:
Running
Running
refactor: implement centralized configuration, upgrade SAE training to multi-layer TopK, and optimize dashboard attribution UX
b7ddfc6 | import sys | |
| from pathlib import Path | |
| import torch | |
| from sae_lens import TopKSAEConfig, TopKSAE | |
# Make the project root importable so the absolute ``src.*`` imports below
# resolve regardless of the directory the script is launched from. Prepend
# (rather than append) so that an unrelated top-level ``src`` package found
# elsewhere on sys.path cannot shadow this project's ``src`` package.
root_path = str(Path(__file__).resolve().parent.parent)
if root_path not in sys.path:
    sys.path.insert(0, root_path)
| import random | |
| import numpy as np | |
| from src.models.hooked_dt import HookedDT | |
| from src.interpretability.sae_manager import SAEManager | |
| from src.config import cfg | |
def set_seed(seed: int = 42):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch's CPU RNG.
    When CUDA is available, it also seeds all CUDA devices. It then requests
    deterministic cuDNN kernels and disables cuDNN benchmark mode.

    Args:
        seed: Value handed to every seeding call. Defaults to 42.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Prefer repeatable results over cuDNN's benchmark-based algorithm selection.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
def _build_model(state_dim, action_dim, checkpoint_path, device):
    """Build a HookedDT from ``cfg.model``, move it to ``device``, and load weights if present."""
    model = HookedDT.from_config(
        state_dim=state_dim,
        action_dim=action_dim,
        n_layers=cfg.model.n_layers,
        n_heads=cfg.model.n_heads,
        d_model=cfg.model.d_model,
    )
    model.to(device)
    if Path(checkpoint_path).exists():
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        print(f"Loaded DT weights from {checkpoint_path}")
    else:
        print(f"Warning: {checkpoint_path} not found. Training SAE on random weights.")
    return model


def _collect_activations(model, trajectories, hook_points, action_dim, device):
    """Run ``model`` once over ``trajectories`` and gather cached activations per hook point.

    Returns a dict mapping each hook point to a CPU tensor. Each tensor is the
    per-trajectory cache entries (batch dim squeezed) concatenated along dim 0.
    """
    chunks = {hook_point: [] for hook_point in hook_points}
    model.eval()
    with torch.no_grad():
        for traj in trajectories:
            states = torch.from_numpy(traj["observations"]).float().to(device).unsqueeze(0)
            actions = torch.from_numpy(traj["actions"]).long().to(device)
            actions_one_hot = (
                torch.nn.functional.one_hot(actions, num_classes=action_dim).float().unsqueeze(0)
            )
            # NOTE(review): per-step rewards (not returns-to-go) are fed as the
            # return-conditioning input; confirm this matches scripts/train_dt.py.
            returns = torch.from_numpy(traj["rewards"]).float().to(device).unsqueeze(0).unsqueeze(-1)
            _, cache = model(states, actions_one_hot, returns, return_cache=True)
            # One forward pass populates every hook point, so harvest them all here
            # instead of re-running the model once per layer.
            for hook_point in hook_points:
                chunks[hook_point].append(cache[hook_point].squeeze(0).cpu())
    return {hook_point: torch.cat(parts, dim=0) for hook_point, parts in chunks.items()}


def train_sae(
    traj_path="data/trajectories.pt",
    checkpoint_path="models/mini_dt.pt",
    sae_dir="artifacts/saes",
):
    """Train one TopK SAE on ``blocks.{layer}.hook_resid_post`` for every DT layer.

    Steps:
      0. Seed all RNGs from ``cfg.train.seed``.
      1. Load the trajectories saved by ``scripts/train_dt.py``. The first
         trajectory's observations (dim 1) give the state dimension.
      2. Build a ``HookedDT`` from ``cfg.model`` and load ``checkpoint_path`` if it
         exists. Otherwise the SAEs are trained on randomly initialised weights.
      3. Run the DT once over the first ``cfg.sae.num_episodes`` trajectories and
         collect activations for every layer.
      4. For each layer, set up and train a TopK SAE through ``SAEManager``, then
         save all SAEs.

    Args:
        traj_path: File written with ``torch.save`` containing the trajectories.
        checkpoint_path: Optional DT ``state_dict`` file.
        sae_dir: Directory passed to ``SAEManager`` for saving the SAEs.

    Returns:
        None. Prints an error and returns early when the trajectory file is
        missing or no trajectories are available to extract activations from.
    """
    # 0. Set seed for reproducibility
    set_seed(cfg.train.seed)

    # 1. Load trajectories
    if not Path(traj_path).exists():
        print(f"Error: {traj_path} not found. Please run scripts/train_dt.py first.")
        return
    # NOTE(review): weights_only=False unpickles arbitrary objects. Only load
    # trajectory files produced by this project, never untrusted input.
    trajectories = torch.load(traj_path, weights_only=False)
    print(f"Loaded {len(trajectories)} trajectories.")
    if not trajectories:
        print(f"Error: {traj_path} contains no trajectories.")
        return

    num_trajs_to_use = min(len(trajectories), cfg.sae.num_episodes)
    if num_trajs_to_use <= 0:
        print(f"Error: cfg.sae.num_episodes={cfg.sae.num_episodes} selects no trajectories.")
        return

    # 2. Initialize model
    state_dim = trajectories[0]["observations"].shape[1]
    action_dim = cfg.model.action_dim
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = _build_model(state_dim, action_dim, checkpoint_path, device)
    manager = SAEManager(model, sae_dir=sae_dir)

    # 3. Extract activations for ALL layers in a single pass. All layers are held
    # in CPU memory at once; each entry is released after its SAE is trained.
    hook_points = [f"blocks.{layer}.hook_resid_post" for layer in range(model.cfg.n_layers)]
    print(
        f"Extracting activations for {len(hook_points)} layers "
        f"from {num_trajs_to_use} trajectories..."
    )
    layer_activations = _collect_activations(
        model, trajectories[:num_trajs_to_use], hook_points, action_dim, device
    )

    # 4. Train one SAE per layer
    for layer, hook_point in enumerate(hook_points):
        activations = layer_activations.pop(hook_point)
        print(f"\n--- Processing Layer {layer} ({hook_point}) ---")
        print(f"Collected {activations.shape[0]} activation vectors.")
        print("Starting TopK SAE training...")
        manager.setup_sae(
            hook_point=hook_point,
            d_model=cfg.model.d_model,
            architecture="topk",
            k=cfg.sae.k,
        )
        manager.train_on_trajectories(
            hook_point=hook_point,
            activations=activations,
            epochs=cfg.sae.epochs,
            batch_size=cfg.sae.batch_size,
        )

    # Save all SAEs once training is complete for all layers
    manager.save_all_saes()
    print(f"\nSAE Training Complete for all {model.cfg.n_layers} layers. Results saved to {sae_dir}/")
if __name__ == "__main__":
    # Script entry point: run the full multi-layer SAE training pipeline.
    train_sae()