Spaces:
Running
Running
| import torch | |
| from jaxtyping import Float | |
| from typing import Dict, List | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
class LogitAttributionEngine:
    """Compute and plot Direct Logit Attribution (DLA) of attention heads.

    For every (layer, head) pair the attribution is the dot product between
    that head's cached output at one token position and the row of the
    action-prediction weight matrix belonging to a chosen target logit.
    """

    def __init__(self, model):
        """Keep a reference to the model whose heads are attributed.

        Args:
            model: Object exposing ``cfg.n_layers``, ``cfg.n_heads`` and an
                indexable ``predict_action`` whose first element has a
                ``weight`` matrix (one row per output logit).
        """
        self.model = model

    def calculate_dla(
        self,
        cache,
        target_logit_index: int,
        token_index: int = -2,
        batch_index: int = 0,
    ) -> "Float[torch.Tensor, 'layer head']":
        """Return the per-head attribution to a single target logit.

        Computes ``head_result[batch_index, token_index] @ W_U[target]`` for
        every layer, where ``head_result`` is the cached
        ``blocks.{layer}.attn.hook_result`` activation.

        NOTE(review): this assumes ``hook_result`` is already expressed in
        residual-stream space (i.e. after the output projection), laid out as
        ``[batch, pos, head, d_model]`` -- confirm against the hooked model.
        Any bias or layers after ``predict_action[0]`` are not included.

        Args:
            cache: Mapping from hook name to activation tensor.
            target_logit_index: Row of ``predict_action[0].weight`` to
                attribute to.
            token_index: Sequence position whose head outputs are used.
                Defaults to ``-2`` (the second-to-last position).
            batch_index: Batch element to analyse. Defaults to ``0``.

        Returns:
            A ``(n_layers, n_heads)`` tensor of attributions, allocated with
            ``torch.zeros`` (default dtype and device).

        Raises:
            KeyError: If the cache holds no per-head result for some layer.
        """
        n_layers = self.model.cfg.n_layers
        n_heads = self.model.cfg.n_heads

        # Unembedding direction of the target action logit.
        unembed_row = self.model.predict_action[0].weight[target_logit_index]

        dla_results = torch.zeros((n_layers, n_heads))
        for layer in range(n_layers):
            hook_name = f"blocks.{layer}.attn.hook_result"
            try:
                head_outputs = cache[hook_name]
            except KeyError as err:
                # Per-head results are often not recorded by default
                # (presumably needs an opt-in on the model config -- confirm).
                raise KeyError(
                    f"{hook_name!r} is missing from the cache; make sure "
                    "per-head attention results were recorded when the "
                    "cache was built."
                ) from err

            # [head, d_model] slice for the selected batch element and token.
            token_output = head_outputs[batch_index, token_index]
            dla_results[layer] = torch.matmul(token_output, unembed_row)

        return dla_results

    def plot_dla(self, dla_results: torch.Tensor, title="Direct Logit Attribution"):
        """Show a layer-by-head heatmap of attribution values.

        Args:
            dla_results: 2-D tensor of attributions (rows are plotted as
                layers, columns as heads).
            title: Figure title.
        """
        plt.figure(figsize=(10, 6))
        # Diverging colour map centred on zero so the sign of each
        # attribution is readable at a glance.
        sns.heatmap(
            dla_results.detach().cpu().numpy(),
            annot=True,
            fmt=".2f",
            cmap="RdBu_r",
            center=0,
        )
        plt.xlabel("Head")
        plt.ylabel("Layer")
        plt.title(title)
        plt.show()