File size: 3,960 Bytes
33a0021
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import os
import json
import requests
from typing import Dict, List, Any, Optional

class NeuronpediaExporter:
    """
    Handles exporting discovered circuits, SAE feature activations, and natural language
    labels to the Neuronpedia platform for public sharing and peer review.
    """
    def __init__(self, api_key: Optional[str] = None):
        self.api_key = api_key
        self.base_url = "https://www.neuronpedia.org/api"

    def format_circuit_payload(self, model_id: str, circuit_manifest: Dict[str, Any]) -> Dict[str, Any]:
        """
        Formats a circuit manifest (active heads, ablated paths, performance scores)
        into a standardized Neuronpedia-compatible schema.
        """
        payload = {
            "model_id": model_id,
            "source": "DT-Circuits",
            "schema_version": "1.0.0",
            "circuit": {
                "active_heads": circuit_manifest.get("active_heads", []),
                "pruned_count": circuit_manifest.get("pruned_count", 0),
                "initial_performance": circuit_manifest.get("initial_perf", 0.0),
                "final_performance": circuit_manifest.get("final_perf", 0.0),
                "ablated_paths": list(circuit_manifest.get("ablated_paths", [])),
                "ablated_nodes": list(circuit_manifest.get("ablated_nodes", []))
            },
            "metadata": {
                "state_dim": circuit_manifest.get("state_dim"),
                "action_dim": circuit_manifest.get("action_dim"),
                "n_layers": circuit_manifest.get("n_layers"),
                "n_heads": circuit_manifest.get("n_heads"),
            }
        }
        return payload

    def export_circuit(
        self, 
        model_id: str, 
        circuit_manifest: Dict[str, Any], 
        local_path: str = "artifacts/neuronpedia_export.json"
    ) -> Dict[str, Any]:
        """
        Exports the discovered circuit. If an API key is present, uploads it
        directly to Neuronpedia; otherwise, serializes it to a local JSON file.
        """
        payload = self.format_circuit_payload(model_id, circuit_manifest)
        
        # Ensure target directory exists
        os.makedirs(os.path.dirname(local_path), exist_ok=True)
        
        # Save locally as backup or primary artifact
        with open(local_path, "w") as f:
            json.dump(payload, f, indent=4)

        if not self.api_key:
            return {
                "status": "success_local",
                "message": f"Circuit exported locally to {local_path}. Provide a Neuronpedia API Key for online sharing.",
                "path": local_path,
                "payload": payload
            }

        headers = {
            "Content-Type": "application/json",
            "X-Api-Key": self.api_key
        }
        
        try:
            response = requests.post(
                f"{self.base_url}/circuits/upload", 
                json=payload, 
                headers=headers,
                timeout=10
            )
            if response.status_code == 200 or response.status_code == 201:
                return {
                    "status": "success_api",
                    "message": "Circuit successfully uploaded to Neuronpedia!",
                    "url": response.json().get("url", "https://www.neuronpedia.org"),
                    "payload": payload
                }
            else:
                return {
                    "status": "error_api",
                    "message": f"Neuronpedia API rejected upload: {response.status_code} - {response.text}",
                    "path": local_path,
                    "payload": payload
                }
        except Exception as e:
            return {
                "status": "error_exception",
                "message": f"Network exception during upload: {str(e)}. Saved to {local_path}.",
                "path": local_path,
                "payload": payload
            }