File size: 2,437 Bytes
8577352
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import torch
from typing import List, Dict, Optional
import requests

class NLAExplainer:
    """
    Natural Language Autoencoder (NLA) Explainer.

    Assigns natural language labels to SAE features based on their top
    activation examples and caches the result per feature ID. The LLM
    request itself is still a placeholder: this class performs no network
    call.

    Attributes:
        api_key: Credential for the LLM backend. When falsy, mock labels are
            generated instead of going through the prompt/LLM path.
        model_name: Name of the LLM to use. Stored on the instance; nothing
            in this class reads it yet.
        feature_labels: Cache mapping feature ID to its most recent label.
    """

    # Context used when the caller does not say where the activations came from.
    DEFAULT_CONTEXT = "MiniGrid environment agent state"

    def __init__(self, api_key: Optional[str] = None, model_name: str = "gpt-4-turbo"):
        self.api_key = api_key
        self.model_name = model_name
        self.feature_labels: Dict[int, str] = {}

    def generate_label(
        self,
        feature_id: int,
        top_activations: List[Dict],
        context_description: str = DEFAULT_CONTEXT,
    ) -> str:
        """
        Generate, cache and return a natural language label for one feature.

        Args:
            feature_id: ID of the SAE feature to label.
            top_activations: Activation records for the feature. Each record
                must provide ``'state'`` and ``'value'`` keys when an API key
                is configured; they are ignored in mock mode.
            context_description: Short description of what the activations
                were recorded on.

        Returns:
            The label, which is also stored in ``feature_labels``.

        Raises:
            KeyError: If an API key is set and a record lacks ``'state'`` or
                ``'value'``.
        """
        if not self.api_key:
            # Mock labeling for demonstration if no API key is provided.
            label = f"Mock Feature {feature_id}: Activates on {context_description} pattern"
        else:
            prompt = self._build_prompt(feature_id, top_activations, context_description)
            label = self._call_llm_api(prompt, feature_id)

        self.feature_labels[feature_id] = label
        return label

    def _build_prompt(self, feature_id: int, top_activations: List[Dict], context: str) -> str:
        """Construct the prompt text for the LLM explainer."""
        # float() lets scalar-like values that only implement __float__
        # (e.g. one-element tensors/arrays) be rendered; applying ':.4f' to
        # such objects directly raises TypeError.
        examples = "\n".join(
            f"- State: {a['state']}, Activation: {float(a['value']):.4f}"
            for a in top_activations
        )
        return (
            f"I have a Sparse Autoencoder feature (ID: {feature_id}) trained on a Decision Transformer. "
            f"The context is: {context}.\n"
            f"Here are the top activations for this feature:\n{examples}\n"
            "What is the most likely semantic meaning of this feature? Provide a concise label."
        )

    def _call_llm_api(self, prompt: str, feature_id: int) -> str:
        """
        Placeholder for the real LLM request (e.g. OpenAI, Anthropic, or custom).

        No request is made and ``prompt`` is not sent anywhere yet; a fixed
        label derived from ``feature_id`` is returned. This is the single
        place to wire in an actual backend.
        """
        return f"Auto-labeled Feature {feature_id}"

    def get_label(self, feature_id: int) -> str:
        """Return the cached label for ``feature_id``, or an 'Unlabeled' fallback."""
        return self.feature_labels.get(feature_id, f"Unlabeled Feature {feature_id}")

    def bulk_label(
        self,
        feature_ids: List[int],
        activation_data: Dict[int, List[Dict]],
        context_description: str = DEFAULT_CONTEXT,
    ) -> Dict[int, str]:
        """
        Label multiple features in sequence.

        Feature IDs that have no entry in ``activation_data`` are skipped and
        keep whatever label (if any) was cached before.

        Args:
            feature_ids: IDs of the features to label, in processing order.
            activation_data: Mapping from feature ID to its activation records.
            context_description: Context forwarded to ``generate_label`` for
                every feature.

        Returns:
            Mapping from each feature ID labeled by this call to its new label.
        """
        labeled: Dict[int, str] = {}
        for fid in feature_ids:
            if fid not in activation_data:
                continue
            labeled[fid] = self.generate_label(
                fid, activation_data[fid], context_description
            )
        return labeled