Spaces:
Running
Running
| import torch | |
| from typing import List, Dict, Optional | |
| import requests | |
| class NLAExplainer: | |
| """ | |
| Natural Language Autoencoder (NLA) Explainer. | |
| Uses an LLM to auto-label SAE features based on activation patterns. | |
| """ | |
| def __init__(self, api_key: Optional[str] = None, model_name: str = "gpt-4-turbo"): | |
| self.api_key = api_key | |
| self.model_name = model_name | |
| self.feature_labels: Dict[int, str] = {} | |
| def generate_label( | |
| self, | |
| feature_id: int, | |
| top_activations: List[Dict], | |
| context_description: str = "MiniGrid environment agent state" | |
| ) -> str: | |
| """ | |
| Generates a natural language label for a specific SAE feature. | |
| In a real scenario, this would call an LLM API. | |
| """ | |
| if not self.api_key: | |
| # Mock labeling for demonstration if no API key is provided | |
| label = f"Mock Feature {feature_id}: Activates on {context_description} pattern" | |
| self.feature_labels[feature_id] = label | |
| return label | |
| prompt = self._build_prompt(feature_id, top_activations, context_description) | |
| # This is a placeholder for a real API call (e.g., OpenAI, Anthropic, or custom) | |
| # label = self._call_llm_api(prompt) | |
| label = f"Auto-labeled Feature {feature_id}" | |
| self.feature_labels[feature_id] = label | |
| return label | |
| def _build_prompt(self, feature_id: int, top_activations: List[Dict], context: str) -> str: | |
| """Constructs the prompt for the LLM explainer.""" | |
| examples = "\n".join([f"- State: {a['state']}, Activation: {a['value']:.4f}" for a in top_activations]) | |
| return ( | |
| f"I have a Sparse Autoencoder feature (ID: {feature_id}) trained on a Decision Transformer. " | |
| f"The context is: {context}.\n" | |
| f"Here are the top activations for this feature:\n{examples}\n" | |
| "What is the most likely semantic meaning of this feature? Provide a concise label." | |
| ) | |
| def get_label(self, feature_id: int) -> str: | |
| return self.feature_labels.get(feature_id, f"Unlabeled Feature {feature_id}") | |
| def bulk_label(self, feature_ids: List[int], activation_data: Dict[int, List[Dict]]): | |
| """Labels multiple features in sequence.""" | |
| for fid in feature_ids: | |
| if fid in activation_data: | |
| self.generate_label(fid, activation_data[fid]) | |