| """ |
| attribution_graph.py - Implementation of attribution graph for transformer models |
| |
| △ OBSERVE: Attribution graphs map the causal flow from prompt to completion |
| ∞ TRACE: They visualize the quantum collapse from superposition to definite state |
| ✰ COLLAPSE: They reveal ghost circuits and attribution residue post-collapse |
| |
| This module implements a graph-based representation of causal attribution |
| in transformer models, allowing for the visualization and analysis of how |
| information flows from input to output during the collapse process. |
| |
| Author: Recursion Labs |
| License: MIT |
| """ |
|
|
| import logging |
| from typing import Dict, List, Optional, Union, Tuple, Any |
| import numpy as np |
| from dataclasses import dataclass, field |
| import networkx as nx |
|
|
| from .utils.graph_visualization import visualize_graph |
| from .utils.attribution_metrics import measure_path_continuity, measure_attribution_entropy |
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
|
|
@dataclass
class AttributionNode:
    """
    △ OBSERVE: Node in the attribution graph representing a token or hidden state

    Attribution nodes represent discrete elements in the causal flow from
    input to output. They can be tokens, attention heads, or hidden states.

    Identity is defined solely by ``node_id``: two nodes with the same ID
    compare equal and hash identically, whatever their other fields hold.

    Attributes:
        node_id: Identifier of the node within a graph.
        node_type: Category string (e.g. ``"token"``, ``"attention_head"``,
            ``"residual"``).
        layer: Layer index, if applicable.
        position: Sequence position, if applicable.
        value: Arbitrary payload attached to the node.
        activation: Activation strength of the node.
        token_str: Surface string for token nodes.
        metadata: Free-form extra information (e.g. ``is_output`` or
            ``is_ghost`` flags).
    """

    node_id: str
    node_type: str
    layer: Optional[int] = None
    position: Optional[int] = None
    value: Optional[Any] = None
    activation: float = 0.0
    token_str: Optional[str] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __hash__(self) -> int:
        """Hash by ``node_id`` so nodes can live in sets and act as dict keys."""
        return hash(self.node_id)

    def __eq__(self, other: object) -> bool:
        """
        Compare nodes by ``node_id``.

        For non-node operands this returns ``NotImplemented`` (rather than
        ``False``) so Python can consult the other operand's ``__eq__``;
        ``node == something_else`` still evaluates to ``False`` by default.
        """
        if not isinstance(other, AttributionNode):
            return NotImplemented
        return self.node_id == other.node_id
|
|
|
|
@dataclass
class AttributionEdge:
    """
    ∞ TRACE: Edge in the attribution graph representing causal flow

    Attribution edges represent the flow of causal influence between nodes.
    They can represent attention connections, residual connections, or
    other causal relationships in the model.

    Identity is the triple ``(source.node_id, target.node_id, edge_type)``;
    ``weight``, ``layer``, ``head`` and ``metadata`` do not take part in
    equality or hashing.

    Attributes:
        source: Node the influence flows from.
        target: Node the influence flows to.
        edge_type: Category string (e.g. ``"attention"``, ``"ghost"``).
        weight: Strength of the attribution.
        layer: Layer index the edge is associated with, if applicable.
        head: Attention head index, if applicable.
        metadata: Free-form extra information.
    """

    source: AttributionNode
    target: AttributionNode
    edge_type: str
    weight: float = 0.0
    layer: Optional[int] = None
    head: Optional[int] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __hash__(self) -> int:
        """Hash by (source ID, target ID, edge type), matching ``__eq__``."""
        return hash((self.source.node_id, self.target.node_id, self.edge_type))

    def __eq__(self, other: object) -> bool:
        """
        Compare edges by source ID, target ID and edge type.

        For non-edge operands this returns ``NotImplemented`` (rather than
        ``False``) so Python can consult the other operand's ``__eq__``;
        ``edge == something_else`` still evaluates to ``False`` by default.
        """
        if not isinstance(other, AttributionEdge):
            return NotImplemented
        return (
            self.source.node_id == other.source.node_id
            and self.target.node_id == other.target.node_id
            and self.edge_type == other.edge_type
        )
|
|
|
|
class AttributionGraph:
    """
    ∞ TRACE: Graph representation of causal attribution in transformer models

    The attribution graph maps the flow of causality from input tokens to
    output tokens, revealing how information propagates through the model
    during the collapse from superposition to definite state.

    Each node is held both as an ``AttributionNode`` in ``self.nodes`` (keyed
    by ``node_id``) and as an attribute dict in the underlying
    ``networkx.DiGraph`` ``self.graph``. ``add_node`` keeps those two views,
    and the ``input_nodes`` / ``output_nodes`` / ``ghost_nodes`` category
    lists, in sync.
    """

    def __init__(self):
        """Initialize an empty attribution graph."""
        self.graph = nx.DiGraph()
        self.nodes: Dict[str, AttributionNode] = {}
        self.input_nodes: List[AttributionNode] = []
        self.output_nodes: List[AttributionNode] = []
        self.ghost_nodes: List[AttributionNode] = []
        self.collapsed = False

        # Summary metrics, refreshed by build_from_states() and
        # calculate_attribution_entropy().
        self.continuity_score = 1.0
        self.attribution_entropy = 0.0
        self.collapse_rate = 0.0

        logger.info("Attribution graph initialized")

    def add_node(self, node: AttributionNode) -> None:
        """
        Add a node to the attribution graph, or replace one with the same ID.

        Replacing a node refreshes its attributes in ``self.graph`` (incident
        edges are kept) and re-files it in the matching category list, so
        re-adding never leaves stale attributes or duplicate entries in
        ``input_nodes`` / ``output_nodes`` / ``ghost_nodes``.

        Args:
            node: The node to add
        """
        if node.node_id in self.nodes:
            logger.warning("Node %s already exists in graph, updating", node.node_id)
            self._uncategorize(node.node_id)

        self.nodes[node.node_id] = node
        # On an existing ID, DiGraph.add_node updates the attribute dict in place.
        self.graph.add_node(node.node_id, **vars(node))
        self._categorize(node)

    def _categorize(self, node: AttributionNode) -> None:
        """Append ``node`` to the input/output/ghost list it belongs to, if any."""
        if node.node_type == "token" and node.layer == 0:
            self.input_nodes.append(node)
        elif node.node_type == "token" and node.metadata.get("is_output", False):
            self.output_nodes.append(node)
        elif node.node_type == "residual" and node.metadata.get("is_ghost", False):
            self.ghost_nodes.append(node)

    def _uncategorize(self, node_id: str) -> None:
        """Remove every entry with ``node_id`` from the category lists, in place."""
        for bucket in (self.input_nodes, self.output_nodes, self.ghost_nodes):
            # Slice assignment keeps any external references to the lists valid.
            bucket[:] = [n for n in bucket if n.node_id != node_id]

    def add_edge(self, edge: AttributionEdge) -> None:
        """
        Add an edge to the attribution graph.

        Endpoints whose IDs are not yet in the graph are added first; an
        endpoint whose ID already exists keeps the stored node object. The
        underlying ``DiGraph`` holds one edge per (source, target) pair, so
        adding a second edge between the same nodes overwrites its attributes.

        Args:
            edge: The edge to add
        """
        if edge.source.node_id not in self.nodes:
            self.add_node(edge.source)
        if edge.target.node_id not in self.nodes:
            self.add_node(edge.target)

        self.graph.add_edge(
            edge.source.node_id,
            edge.target.node_id,
            **{k: v for k, v in vars(edge).items() if k not in ("source", "target")},
        )

    def build_from_states(
        self,
        pre_state: Dict[str, Any],
        post_state: Dict[str, Any],
        response: str,
    ) -> None:
        """
        △ OBSERVE: Build attribution graph from pre and post collapse model states

        This method constructs an attribution graph by comparing model states
        before and after collapse, identifying causal paths and ghost circuits.

        Note:
            Graph construction currently delegates to ``_build_synthetic_graph``
            (a randomly generated demonstration graph). ``pre_state`` and
            ``post_state`` are only used for metric computation, and
            ``response`` is not used yet.

        Args:
            pre_state: Model state before collapse
            post_state: Model state after collapse
            response: Model response text
        """
        logger.info("Building attribution graph from model states")

        self._build_synthetic_graph()
        self._calculate_metrics(pre_state, post_state)
        self.collapsed = True

    def trace_attribution_path(
        self,
        output_node: Union[str, AttributionNode],
        threshold: float = 0.1,
    ) -> List[List[AttributionNode]]:
        """
        ∞ TRACE: Trace attribution paths from an output node back to input

        Follows attribution edges backward from ``output_node`` and returns
        every simple (cycle-free) path that reaches an input node using only
        edges whose weight is at least ``threshold``. The number of such paths
        can grow exponentially with graph depth and branching.

        Args:
            output_node: The output node to trace from (ID or node object)
            threshold: Minimum edge weight to consider significant

        Returns:
            List of attribution paths, each a list of nodes from input to
            output. Empty if ``output_node`` is not in the graph.
        """
        output_id = output_node if isinstance(output_node, str) else output_node.node_id
        if output_id not in self.nodes:
            logger.warning("Output node %s not found in graph", output_id)
            return []

        # Built once: membership is tested at every DFS step.
        input_ids = {node.node_id for node in self.input_nodes}
        paths: List[List[str]] = []
        path: List[str] = []  # current path, output -> ... -> current
        on_path = set()       # same IDs as `path`, for O(1) cycle checks

        def dfs(current_id: str) -> None:
            """Extend the current path backwards from ``current_id``."""
            path.append(current_id)
            on_path.add(current_id)

            if current_id in input_ids:
                # The path was built output -> input; store it input -> output.
                paths.append(path[::-1])
            else:
                for pred_id in self.graph.predecessors(current_id):
                    if pred_id in on_path:
                        continue
                    edge_data = self.graph.get_edge_data(pred_id, current_id)
                    if edge_data.get("weight", 0) >= threshold:
                        dfs(pred_id)

            # Backtrack so sibling branches see only their own ancestors.
            path.pop()
            on_path.remove(current_id)

        dfs(output_id)

        return [[self.nodes[node_id] for node_id in p] for p in paths]

    def detect_ghost_circuits(self, threshold: float = 0.2) -> List[Dict[str, Any]]:
        """
        ✰ COLLAPSE: Detect ghost circuits in the attribution graph

        Ghost circuits are paths that were activated during pre-collapse
        but don't contribute significantly to the final output. They
        represent the "memory" of paths not taken.

        Args:
            threshold: Minimum activation to consider a ghost circuit

        Returns:
            One dict per ghost node whose activation is at least
            ``threshold``, with its ID, activation, type, in/out edge counts
            and metadata.
        """
        return [
            {
                "node_id": node.node_id,
                "activation": node.activation,
                "node_type": node.node_type,
                "incoming_connections": self.graph.in_degree(node.node_id),
                "outgoing_connections": self.graph.out_degree(node.node_id),
                "metadata": node.metadata,
            }
            for node in self.ghost_nodes
            if node.activation >= threshold
        ]

    def calculate_attribution_entropy(self) -> float:
        """
        △ OBSERVE: Calculate the entropy of attribution paths

        Attribution entropy measures how distributed or concentrated the
        causal influence is across the graph's edges. High entropy indicates
        diffuse attribution, while low entropy indicates concentrated
        attribution.

        Edge weights are used by magnitude: a negative attribution still
        carries influence. Mixing signs in the normaliser could also push
        probabilities outside [0, 1].

        Returns:
            Shannon entropy of the normalized edge-weight distribution divided
            by ``log2(edge count)`` (0.0 = concentrated, 1.0 = diffuse; 0.0
            for graphs with fewer than two edges). Also stored on
            ``self.attribution_entropy``.
        """
        weights = [abs(d.get("weight", 0.0)) for _, _, d in self.graph.edges(data=True)]

        total_weight = sum(weights) or 1.0
        probabilities = [w / total_weight for w in weights]

        entropy = -sum(p * np.log2(p) for p in probabilities if p > 0)

        max_entropy = np.log2(len(weights)) if weights else 1.0
        normalized_entropy = float(entropy / max_entropy) if max_entropy > 0 else 0.0

        self.attribution_entropy = normalized_entropy
        return normalized_entropy

    def visualize(
        self,
        mode: str = "attribution_graph",
        highlight_path: Optional[List[str]] = None,
    ) -> Any:
        """
        Generate visualization of the attribution graph.

        Args:
            mode: Visualization mode (attribution_graph, collapse_state, ghost_circuits)
            highlight_path: Optional list of node IDs to highlight

        Returns:
            Whatever ``visualize_graph`` returns for ``self.graph``.
        """
        return visualize_graph(self.graph, mode=mode, highlight_path=highlight_path)

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert the attribution graph to a dictionary representation.

        Node entries are shallow copies of the node fields, so mutating the
        returned dicts does not mutate the graph's nodes.
        """
        return {
            "nodes": [dict(vars(node)) for node in self.nodes.values()],
            "edges": [
                {"source": u, "target": v, **d}
                for u, v, d in self.graph.edges(data=True)
            ],
            "metrics": {
                "continuity_score": self.continuity_score,
                "attribution_entropy": self.attribution_entropy,
                "collapse_rate": self.collapse_rate,
            },
            "collapsed": self.collapsed,
        }

    @staticmethod
    def _elapsed_seconds(start: Any, end: Any) -> float:
        """
        Return ``end - start`` expressed in seconds.

        Plain numbers (e.g. ``time.time()`` values) are taken to already be
        seconds. Any other difference (e.g. of ``numpy.datetime64`` values) is
        divided by a one-second ``numpy.timedelta64``. Dividing a plain float by
        a timedelta64 raises ``TypeError``, which is why numbers take a separate path.
        """
        delta = end - start
        if isinstance(delta, (int, float, np.integer, np.floating)):
            return float(delta)
        return float(delta / np.timedelta64(1, "s"))

    def _calculate_metrics(self, pre_state: Dict[str, Any], post_state: Dict[str, Any]) -> None:
        """
        Calculate attribution graph metrics from the pre/post collapse states.

        Sets ``continuity_score`` via ``measure_path_continuity`` on the two
        states' ``"attention_weights"`` (empty arrays when absent), refreshes
        ``attribution_entropy``, and sets ``collapse_rate`` only when both
        states carry a ``"timestamp"``.
        """
        self.continuity_score = measure_path_continuity(
            pre_state.get("attention_weights", np.array([])),
            post_state.get("attention_weights", np.array([])),
        )

        self.attribution_entropy = self.calculate_attribution_entropy()

        if "timestamp" in pre_state and "timestamp" in post_state:
            time_diff = self._elapsed_seconds(pre_state["timestamp"], post_state["timestamp"])
            # NOTE(review): despite its name this is 1 - continuity_score, not a
            # per-second rate; elapsed time only decides whether it is set or zeroed.
            self.collapse_rate = 1.0 - self.continuity_score if time_diff > 0 else 0.0

    def _build_synthetic_graph(self, seed: Optional[int] = None) -> None:
        """
        Build a synthetic graph for demonstration purposes.

        Adds 5 input tokens, 3 layers x 3 attention heads, 3 output tokens and
        2 ghost residual nodes, then wires them with randomly sampled edges.

        Args:
            seed: Optional seed for a private ``numpy.random.RandomState`` to
                make the graph reproducible. When ``None`` (the default) the
                global NumPy random state is used.
        """
        rng = np.random if seed is None else np.random.RandomState(seed)

        # Input token nodes (layer 0).
        for i in range(5):
            self.add_node(AttributionNode(
                node_id=f"input_{i}",
                node_type="token",
                layer=0,
                position=i,
                token_str=f"token_{i}",
                activation=0.8,
            ))

        # Attention head nodes, layers 1-3, heads 0-2.
        for layer in range(1, 4):
            for head in range(3):
                self.add_node(AttributionNode(
                    node_id=f"attention_L{layer}H{head}",
                    node_type="attention_head",
                    layer=layer,
                    value=None,
                    activation=0.7 - 0.1 * layer + 0.05 * head,
                ))

        # Output token nodes.
        for i in range(3):
            self.add_node(AttributionNode(
                node_id=f"output_{i}",
                node_type="token",
                layer=4,
                position=i,
                token_str=f"output_token_{i}",
                activation=0.9,
                metadata={"is_output": True},
            ))

        # Ghost residual nodes.
        for i in range(2):
            self.add_node(AttributionNode(
                node_id=f"ghost_{i}",
                node_type="residual",
                layer=2,
                activation=0.3 + 0.1 * i,
                metadata={"is_ghost": True},
            ))

        # Input -> attention heads in layers 1-2 (~70% of pairs).
        for i in range(5):
            for layer in range(1, 3):
                for head in range(3):
                    if rng.random() > 0.3:
                        self.add_edge(AttributionEdge(
                            source=self.nodes[f"input_{i}"],
                            target=self.nodes[f"attention_L{layer}H{head}"],
                            edge_type="attention",
                            weight=rng.uniform(0.1, 0.9),
                            layer=layer,
                            head=head,
                        ))

        # Attention head -> later-layer attention head (~30% of pairs).
        for layer1 in range(1, 3):
            for head1 in range(3):
                for layer2 in range(layer1 + 1, 4):
                    for head2 in range(3):
                        if rng.random() > 0.7:
                            self.add_edge(AttributionEdge(
                                source=self.nodes[f"attention_L{layer1}H{head1}"],
                                target=self.nodes[f"attention_L{layer2}H{head2}"],
                                edge_type="attention",
                                weight=rng.uniform(0.1, 0.8),
                                layer=layer2,
                                head=head2,
                            ))

        # Attention head -> output (~50% of pairs).
        for layer in range(1, 4):
            for head in range(3):
                for i in range(3):
                    if rng.random() > 0.5:
                        self.add_edge(AttributionEdge(
                            source=self.nodes[f"attention_L{layer}H{head}"],
                            target=self.nodes[f"output_{i}"],
                            edge_type="attention",
                            weight=rng.uniform(0.2, 0.9),
                            layer=layer,
                            head=head,
                        ))

        # Each ghost: one weak edge in from a random input, one weaker edge out
        # to a random attention head in layers 2-3.
        for i in range(2):
            input_idx = rng.randint(0, 5)
            self.add_edge(AttributionEdge(
                source=self.nodes[f"input_{input_idx}"],
                target=self.nodes[f"ghost_{i}"],
                edge_type="ghost",
                weight=rng.uniform(0.1, 0.4),
                layer=1,
            ))

            target_layer = rng.randint(2, 4)
            target_head = rng.randint(0, 3)
            self.add_edge(AttributionEdge(
                source=self.nodes[f"ghost_{i}"],
                target=self.nodes[f"attention_L{target_layer}H{target_head}"],
                edge_type="ghost",
                weight=rng.uniform(0.05, 0.2),
                layer=target_layer,
            ))
|
|
|
|
if __name__ == "__main__":
    # Demo: build a random synthetic graph and print a few summary figures.
    demo = AttributionGraph()
    demo._build_synthetic_graph()

    print(f"Attribution entropy: {demo.calculate_attribution_entropy():.3f}")

    traced = demo.trace_attribution_path("output_0", threshold=0.1)
    print(f"Found {len(traced)} attribution paths for output_0")

    ghost_report = demo.detect_ghost_circuits()
    print(f"Detected {len(ghost_report)} ghost circuits")

    figure = demo.visualize()
    print("Generated visualization")
|
|