"""DT-Explorer: Streamlit dashboard for mechanistic interpretability of a Decision Transformer."""
import json
import os
import sys
import time
from pathlib import Path

import streamlit as st
import torch

# Add project root to path for absolute imports
root_path = str(Path(__file__).resolve().parent.parent.parent)
if root_path not in sys.path:
    sys.path.append(root_path)

import numpy as np
import matplotlib.pyplot as plt
import gymnasium as gym
from minigrid.wrappers import FlatObsWrapper

from src.models.hooked_dt import HookedDT
from src.interpretability.attribution import LogitAttributionEngine
from src.interpretability.patching import ActivationPatcher
from src.interpretability.sae_manager import SAEManager

st.set_page_config(page_title="DT-Explorer", layout="wide")
st.title("DT-Explorer: Mechanistic Interpretability for DT")

# Sidebar for loading model and data
st.sidebar.header("Data & Model")


def _list_artifacts(directory, fallback):
    """Return the sorted ``*.pt`` paths in *directory*, or ``[fallback]`` if there are none.

    Offering only discovered files in a dropdown (rather than a free-text
    box) prevents path traversal through user input.
    """
    base = Path(directory)
    found = sorted(str(p) for p in base.glob("*.pt")) if base.exists() else []
    return found or [fallback]


model_path = st.sidebar.selectbox(
    "Select Model Path", _list_artifacts("models", "models/mini_dt.pt")
)
data_path = st.sidebar.selectbox(
    "Select Trajectory Path", _list_artifacts("data", "data/trajectories.pt")
)


def is_safe_path(base_dir, path):
    """Return True if *path* resolves to a location inside *base_dir*.

    Defense-in-depth on top of the dropdowns: both paths are fully resolved
    (``..`` segments and symlinks) before their components are compared.
    """
    base_abs = Path(base_dir).resolve()
    path_abs = Path(path).resolve()
    return path_abs.parts[: len(base_abs.parts)] == base_abs.parts


@st.cache_data
def _load_trajectories(path):
    """Deserialize the trajectory file at *path* (memoized per path)."""
    # SECURITY: weights_only=False unpickles arbitrary objects (needed because
    # the trajectories contain numpy arrays). This is only acceptable because
    # *path* is restricted to the local data/ directory; never load untrusted
    # files through here.
    return torch.load(path, map_location="cpu", weights_only=False)


def get_data(path):
    """Validate *path* and return the trajectories stored there, or None if missing.

    Validation and the existence check run outside the cache. A missing file
    is therefore not memoized as ``None``, and the next rerun picks the file
    up once the training script has produced it. An unsafe path stops the
    script via ``st.stop()``.
    """
    if not is_safe_path("data", path):
        st.sidebar.error("Access Denied: Invalid trajectory path.")
        st.stop()
    if not os.path.exists(path):
        st.sidebar.warning(f"Data not found at {path}. Please run training script.")
        return None
    return _load_trajectories(path)


def _build_model(state_dim):
    """Construct a freshly initialised HookedDT in eval mode (7 discrete actions)."""
    model = HookedDT.from_config(state_dim=state_dim, action_dim=7)
    # Eval mode for every model handed to the UI, so analyses are not
    # perturbed by train-only behaviour.
    model.eval()
    return model


@st.cache_resource
def _random_model(state_dim):
    """Return the shared randomly-initialised demo model for *state_dim*."""
    return _build_model(state_dim)


@st.cache_resource
def _load_model(path, state_dim):
    """Build a model and load the state dict stored at *path*.

    Raises whatever ``torch.load`` / ``load_state_dict`` raise. Exceptions are
    not cached, so a repaired checkpoint is retried on the next rerun.
    """
    model = _build_model(state_dim)
    # A plain state dict needs no arbitrary unpickling.
    model.load_state_dict(torch.load(path, map_location="cpu", weights_only=True))
    return model


def get_model(path, state_dim):
    """Return the model stored at *path*, falling back to a random-init demo model.

    Falls back when the file is missing or fails to load; in the failure
    case a fresh model is returned rather than a partially loaded one. An
    unsafe path stops the script via ``st.stop()``.
    """
    if not is_safe_path("models", path):
        st.sidebar.error("Access Denied: Invalid model path.")
        st.stop()
    if not os.path.exists(path):
        st.sidebar.warning(f"Model not found at {path}. Using random init for demo.")
        return _random_model(state_dim)
    try:
        return _load_model(path, state_dim)
    except Exception as e:  # UI boundary: report and keep the app usable
        st.sidebar.error(f"Error loading model: {e}")
        return _random_model(state_dim)


# 1. Load Data First
trajectories = get_data(data_path)
if trajectories is None or len(trajectories) == 0:
    st.error("No real data available. Please run `python scripts/train_dt.py` first.")
    st.stop()

# 2. Determine State Dim
state_dim = trajectories[0]["observations"].shape[1]
# 3. Load Model with Correct Dim
model = get_model(model_path, state_dim)

# Number of discrete MiniGrid actions; must match the action_dim the model
# was built with.
ACTION_DIM = 7
# Step budget for the live simulation.
SIM_MAX_STEPS = 30

# Select a trajectory and token for analysis
traj_idx = st.sidebar.number_input("Select Trajectory", 0, len(trajectories) - 1, 0)
traj = trajectories[traj_idx]


def _trajectory_tensors(trajectory):
    """Convert one stored trajectory into batched ``(states, actions, returns)`` tensors.

    A batch dimension of 1 is prepended to each. Actions are one-hot encoded
    with ``ACTION_DIM`` classes, and returns get a trailing singleton feature
    dim. Presumably ``actions`` and ``rewards`` are 1-D per-timestep arrays;
    TODO confirm against the trajectory harvester.
    """
    states = torch.from_numpy(trajectory["observations"]).float().unsqueeze(0)
    actions = torch.nn.functional.one_hot(
        torch.from_numpy(trajectory["actions"]).long(), num_classes=ACTION_DIM
    ).float().unsqueeze(0)
    # NOTE(review): per-step rewards are fed into the returns-to-go slot here,
    # while the live simulation below feeds a decreasing return-to-go.
    # Confirm this matches how scripts/train_dt.py builds its inputs.
    returns = torch.from_numpy(trajectory["rewards"]).float().unsqueeze(0).unsqueeze(-1)
    return states, actions, returns


def _circuit_nodes(n_layers, n_heads):
    """List component ids layer by layer: every head ``L{l}H{h}``, then ``L{l}MLP``."""
    nodes = []
    for layer in range(n_layers):
        nodes.extend(f"L{layer}H{h}" for h in range(n_heads))
        nodes.append(f"L{layer}MLP")
    return nodes


def _circuit_edges(n_layers, n_heads):
    """List ``"src -> dest"`` paths in the residual circuit.

    The list covers each head -> its own layer's MLP, plus every head/MLP ->
    every head/MLP in a later layer.
    """
    edges = []
    for l1 in range(n_layers):
        # Within layer attention to MLP
        edges.extend(f"L{l1}H{h} -> L{l1}MLP" for h in range(n_heads))
        # Across layers
        for l2 in range(l1 + 1, n_layers):
            for h1 in range(n_heads):
                edges.extend(f"L{l1}H{h1} -> L{l2}H{h2}" for h2 in range(n_heads))
                edges.append(f"L{l1}H{h1} -> L{l2}MLP")
            edges.extend(f"L{l1}MLP -> L{l2}H{h2}" for h2 in range(n_heads))
            edges.append(f"L{l1}MLP -> L{l2}MLP")
    return edges


def _cytoscape_html(elements_json, height=400):
    """Return a standalone HTML page that renders *elements_json* with Cytoscape.js.

    Nodes are placed at their supplied ``position`` (preset layout). Elements
    whose ``data.ablated`` is truthy are drawn red; ablated edges are also
    dashed. Literal CSS/JS braces are doubled because this is an f-string.
    """
    return f"""<!DOCTYPE html>
<html>
<head>
<script src="https://unpkg.com/cytoscape@3/dist/cytoscape.min.js"></script>
<style>
  html, body {{ margin: 0; padding: 0; }}
  #cy {{ width: 100%; height: {height}px; }}
</style>
</head>
<body>
<div id="cy"></div>
<script>
  cytoscape({{
    container: document.getElementById("cy"),
    elements: {elements_json},
    layout: {{ name: "preset" }},
    style: [
      {{ selector: "node", style: {{ "label": "data(label)", "background-color": "#4a90d9", "font-size": 10 }} }},
      {{ selector: "node[type = 'mlp']", style: {{ "shape": "round-rectangle", "background-color": "#7b61ff" }} }},
      {{ selector: "node[?ablated]", style: {{ "background-color": "#e53935" }} }},
      {{ selector: "edge", style: {{ "width": 1, "line-color": "#bbbbbb", "curve-style": "bezier",
                                   "target-arrow-shape": "triangle", "target-arrow-color": "#bbbbbb" }} }},
      {{ selector: "edge[?ablated]", style: {{ "width": 2, "line-color": "#e53935",
                                             "target-arrow-color": "#e53935", "line-style": "dashed" }} }}
    ]
  }});
</script>
</body>
</html>
"""


states, actions, returns = _trajectory_tensors(traj)

# One gradient-free forward pass, shared by the DLA, patching and SAE tabs.
# All three tabs need the same clean predictions and activation cache.
with torch.no_grad():
    preds, cache = model(states, actions, returns, return_cache=True)
target_action = preds[0, -1].argmax().item()
engine = LogitAttributionEngine(model)
# Use token index -2 to target the state token which predicts the action
dla_results = engine.calculate_dla(cache, target_logit_index=target_action, token_index=-2)

tab1, tab2, tab3, tab4 = st.tabs([
    "Circuit Mapping (DLA)",
    "Causal Intervention (Patching)",
    "SAE Latents",
    "Brain Surgeon & Circuit Explorer",
])

with tab1:
    st.header("Direct Logit Attribution (DLA)")
    st.write("Visualizing which heads contribute most to the predicted action.")
    # Recomputed on every rerun, so changing the trajectory updates the plot.
    fig, ax = plt.subplots()
    im = ax.imshow(dla_results.detach().cpu().numpy(), cmap="RdBu_r", aspect="auto")
    fig.colorbar(im, ax=ax)
    ax.set_xlabel("Head")
    ax.set_ylabel("Layer")
    st.pyplot(fig)
    plt.close(fig)  # a new figure is created each rerun; release this one
    st.write(f"Analyzing Attribution for Action: {target_action} (at State token)")

with tab2:
    st.header("Activation Patching")
    st.write("Quantifying causal importance by patching corrupted activations.")

    # Show each layer/head's DLA score in the dropdowns as a selection hint.
    layer_options = [
        f"Layer {i} (Avg DLA: {dla_results[i].mean():.4f})" for i in range(model.cfg.n_layers)
    ]
    layer_idx = st.selectbox(
        "Select Layer", range(model.cfg.n_layers), format_func=lambda x: layer_options[x]
    )
    head_options = [
        f"Head {j} (DLA: {dla_results[layer_idx, j]:.4f})" for j in range(model.cfg.n_heads)
    ]
    head_idx = st.selectbox(
        "Select Head", range(model.cfg.n_heads), format_func=lambda x: head_options[x]
    )

    if st.button("Calculate Probability Drop"):
        patcher = ActivationPatcher(model)
        # Simple corruption: zero out the final state (the one read at token -2)
        corrupted_states = states.clone()
        corrupted_states[0, -1, :] = 0.0
        clean_logits = preds
        with torch.no_grad():
            _, corrupted_cache = model(corrupted_states, actions, returns, return_cache=True)
        # Patch at token index -2 (State token)
        patched_logits = patcher.patch_head(
            {"states": states, "actions": actions, "returns_to_go": returns},
            corrupted_cache,
            layer_idx,
            head_idx,
            target_token_index=-2,
        )
        drop = patcher.calculate_probability_drop(
            torch.softmax(clean_logits, dim=-1),
            torch.softmax(patched_logits, dim=-1),
            target_action,
        )
        st.metric("Logit Prob Drop", f"{drop:.4f}")
        if drop > 0.01:
            st.success(f"Head {layer_idx}.{head_idx} has causal impact ({drop:.4f}) on this decision.")
        else:
            st.info("Low causal impact observed for this head.")

with tab3:
    st.header("High-Fidelity Latent Discovery")
    st.write("Exploring monosemantic features via Sparse Autoencoders (TopK SAEs).")
    sae_manager = SAEManager(model)
    hook_points = [f"blocks.{i}.hook_resid_post" for i in range(model.cfg.n_layers)]
    selected_hook = st.selectbox("Select Hook Point", hook_points)
    try:
        sae = sae_manager.load_sae(selected_hook)
    except FileNotFoundError:
        st.warning(f"No trained SAE found for {selected_hook} in `artifacts/saes/`.")
        st.info("Please run `python scripts/train_sae.py` to generate latent features.")
    else:
        st.success(f"Loaded SAE for {selected_hook}")
        # Visualize latents for the current state (token -2 of the shared cache)
        activations = cache[selected_hook][:, -2, :]
        with torch.no_grad():
            latents = sae.encode(activations)
        # Never request more latents than the SAE has.
        top_k = min(10, latents.shape[-1])
        top_values, top_indices = torch.topk(latents[0], k=top_k)
        st.subheader(f"Top-{top_k} Active Latents")
        cols = st.columns(5)
        for rank, (latent_id, value) in enumerate(zip(top_indices.tolist(), top_values.tolist())):
            with cols[rank % 5]:
                st.metric(f"Latent #{latent_id}", f"{value:.4f}")
        reconstruction_error = sae_manager.compute_anomaly_score(selected_hook, activations)
        st.metric("Reconstruction Error (L2 Norm)", f"{reconstruction_error.item():.4f}")

with tab4:
    st.header("Brain Surgeon & Circuit Explorer")
    st.write("Perform real-time node and path ablations to visualize and audit the agent's internal reasoning pathways.")
    from src.interpretability.circuit_surgeon import CircuitSurgeon
    from src.interpretability.neuronpedia import NeuronpediaExporter

    # Initialize CircuitSurgeon on the active model.
    # NOTE(review): a new surgeon is created on every rerun while `model` is a
    # shared cached resource; if CircuitSurgeon installs hooks on the model,
    # confirm they do not leak into other tabs or later reruns.
    surgeon = CircuitSurgeon(model)
    n_layers = model.cfg.n_layers
    n_heads = model.cfg.n_heads

    all_nodes = _circuit_nodes(n_layers, n_heads)
    all_edges = _circuit_edges(n_layers, n_heads)

    col1, col2 = st.columns([1, 2])
    with col1:
        st.subheader("Surgical Controls")
        ablated_nodes_selected = st.multiselect(
            "Ablate Nodes",
            options=all_nodes,
            help="Zero out all activations exiting these specific components.",
        )
        ablated_edges_selected = st.multiselect(
            "Ablate Communication Paths (Edges)",
            options=all_edges,
            help="Sever the communication channel between two layers or components.",
        )
        # Register currently selected ablations to CircuitSurgeon
        for node in ablated_nodes_selected:
            surgeon.add_node_ablation(node)
        for edge in ablated_edges_selected:
            edge_src, edge_dest = edge.split(" -> ")
            surgeon.add_edge_ablation(edge_src, edge_dest)
        # Target reward-to-go slider
        target_rtg = st.slider("Goal Reward-to-Go", 0.1, 1.5, 0.9, 0.05)
        run_simulation = st.button("Run Live MiniGrid Simulation")

    # Sets for O(1) membership tests over the (potentially large) edge list.
    ablated_node_set = set(ablated_nodes_selected)
    ablated_edge_set = set(ablated_edges_selected)

    with col2:
        st.subheader("Interactive Circuit Blueprint")
        st.write("Visualized via Cytoscape.js. Severed components are highlighted in vibrant red/dashed styling.")
        cy_nodes = []
        # Position layers horizontally; heads stacked vertically, MLP at the bottom
        for layer in range(n_layers):
            x_pos = 100 + layer * 250
            for h in range(n_heads):
                node_id = f"L{layer}H{h}"
                cy_nodes.append({
                    "data": {"id": node_id, "label": node_id, "type": "head",
                             "ablated": node_id in ablated_node_set},
                    "position": {"x": x_pos, "y": 50 + h * 90},
                })
            mlp_id = f"L{layer}MLP"
            cy_nodes.append({
                "data": {"id": mlp_id, "label": mlp_id, "type": "mlp",
                         "ablated": mlp_id in ablated_node_set},
                "position": {"x": x_pos, "y": 50 + n_heads * 90},
            })
        cy_edges = []
        for edge in all_edges:
            edge_src, edge_dest = edge.split(" -> ")
            # An edge is drawn as severed if it was ablated directly or if
            # either endpoint node was ablated.
            severed = (
                edge in ablated_edge_set
                or edge_src in ablated_node_set
                or edge_dest in ablated_node_set
            )
            cy_edges.append({
                "data": {"id": f"{edge_src}_{edge_dest}", "source": edge_src,
                         "target": edge_dest, "ablated": severed}
            })
        cy_elements_json = json.dumps(cy_nodes + cy_edges)
        cytoscape_html = _cytoscape_html(cy_elements_json)
        st.iframe(cytoscape_html, height=420)

    # 5. Live Simulation execution block
    if run_simulation:
        st.subheader("Live Agent Behavioral Audit")
        status_box = st.empty()
        img_box = st.empty()
        env = None
        try:
            # Recreate exact MiniGrid env setup from harvester
            env = FlatObsWrapper(gym.make("MiniGrid-Empty-8x8-v0", render_mode="rgb_array"))
            obs, _ = env.reset(seed=42)
            # Histories are index-aligned per timestep t: (state_t, action_t,
            # return-to-go_t). The current step's action is unknown until the
            # model predicts it, so it starts as an all-zero placeholder.
            states_history = [obs]
            actions_history = [np.zeros(ACTION_DIM)]
            rtg_history = [target_rtg]
            max_len = model.max_length
            total_reward = 0.0
            steps = 0
            while steps < SIM_MAX_STEPS:
                # Format the most recent context window into batched tensors
                states_t = torch.tensor(np.array(states_history[-max_len:]), dtype=torch.float32).unsqueeze(0)
                actions_t = torch.tensor(np.array(actions_history[-max_len:]), dtype=torch.float32).unsqueeze(0)
                returns_t = torch.tensor(np.array(rtg_history[-max_len:]), dtype=torch.float32).unsqueeze(0).unsqueeze(-1)
                # Execute DT with ablated circuit surgeon forward
                with torch.no_grad():
                    sim_preds = surgeon.compute_ablated_forward(states_t, actions_t, returns_t)
                act = sim_preds[0, -1].argmax().item()

                # Replace the placeholder so the action stays paired with the
                # state it was taken in (the previous code shifted it one step).
                act_one_hot = np.zeros(ACTION_DIM)
                act_one_hot[act] = 1.0
                actions_history[-1] = act_one_hot

                next_obs, reward, done, truncated, _ = env.step(act)
                total_reward += reward
                steps += 1

                # Render current grid step
                frame = env.render()
                img_box.image(frame, caption=f"Step {steps} | Action {act}", width=320)
                status_box.info(f"Stepping Agent... Current Step: {steps}/{SIM_MAX_STEPS} | Cumulative Reward: {total_reward:.4f}")

                # Open the next timestep: new state, placeholder action, reduced RTG
                states_history.append(next_obs)
                actions_history.append(np.zeros(ACTION_DIM))
                rtg_history.append(rtg_history[-1] - reward)
                time.sleep(0.12)
                if done or truncated:
                    break
        except Exception as e:  # UI boundary: report instead of crashing the app
            st.error(f"Failed to run environment simulation: {str(e)}")
        else:
            if total_reward > 0:
                st.success(f"Execution complete. Agent successfully reached the goal in {steps} steps! Cumulative Reward: {total_reward:.4f}")
            else:
                st.warning("Agent failed to reach the goal under this ablated circuit/communication configuration.")
        finally:
            # Always release the environment, including on failure.
            if env is not None:
                env.close()

    # 6. Neuronpedia Export Section
    st.markdown("---")
    st.subheader("Neuronpedia Export Hub")
    st.write("Publish discovered circuits, active heads, and ablated configurations to public peer-review.")
    np_col1, np_col2 = st.columns(2)
    with np_col1:
        np_key = st.text_input(
            "Neuronpedia Access Key (Optional)",
            type="password",
            help="If provided, uploads directly. Otherwise, saves circuit payload in artifacts/.",
        )
    with np_col2:
        export_btn = st.button("Publish Discovered Circuit Blueprint")
    if export_btn:
        exporter = NeuronpediaExporter(api_key=np_key or None)
        manifest = {
            # NOTE(review): despite the key name this lists every non-ablated
            # node, MLP nodes included.
            "active_heads": [n for n in all_nodes if n not in ablated_node_set],
            "pruned_count": len(ablated_nodes_selected),
            # NOTE(review): placeholder values, not measured performance;
            # final_perf is 0.0 whenever any node is ablated.
            "initial_perf": 1.0,
            "final_perf": 0.0 if len(ablated_nodes_selected) > 0 else 1.0,
            "ablated_paths": list(ablated_edges_selected),
            "ablated_nodes": list(ablated_nodes_selected),
            "state_dim": state_dim,
            "action_dim": ACTION_DIM,
            "n_layers": n_layers,
            "n_heads": n_heads,
        }
        res = exporter.export_circuit(model_id="mini_dt", circuit_manifest=manifest)
        if "local" in res["status"]:
            st.success(res["message"])
            st.json(res["payload"])
        elif "success" in res["status"]:
            st.success(res["message"])
            st.markdown(f"[View Live Uploaded Circuit]({res['url']})")
        else:
            st.error(res["message"])