Spaces:
Running
Running
File size: 1,551 Bytes
e2614dc 4aa19e7 e2614dc 11dbbc6 e2614dc 11dbbc6 731ae64 e2614dc 731ae64 e2614dc 11dbbc6 731ae64 e2614dc 731ae64 e2614dc | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 | import torch
from jaxtyping import Float
from typing import Dict, List
import matplotlib.pyplot as plt
import seaborn as sns
class LogitAttributionEngine:
    """
    Calculates the Direct Logit Attribution (DLA) of transformer components.

    The wrapped model must expose ``cfg.n_layers``, ``cfg.n_heads`` and a
    ``predict_action`` sequence whose first element has a ``weight`` matrix;
    each attention head's cached output is projected onto one row of that
    matrix to score its direct contribution to a single output logit.
    """

    def __init__(self, model):
        # Model whose config and ``predict_action[0].weight`` are read by calculate_dla.
        self.model = model

    def calculate_dla(
        self,
        cache,
        target_logit_index: int,
        token_index: int = -2,
        batch_index: int = 0,
    ) -> "Float[torch.Tensor, 'layer head']":
        """
        Calculate the DLA of every attention head for one target logit.

        For each layer, the per-head outputs cached under
        ``blocks.{layer}.attn.hook_result`` are read at
        ``[batch_index, token_index]`` and dotted with row
        ``target_logit_index`` of ``model.predict_action[0].weight``
        (i.e. ``head_output @ W_U[target_logit]``). In TransformerLens,
        ``hook_result`` already includes the per-head W_O projection.

        Args:
            cache: Mapping from hook name to activation tensor (e.g. a
                TransformerLens ActivationCache). Each ``hook_result`` entry is
                expected to be indexed as [batch, pos, head, d_model].
            target_logit_index: Row of ``predict_action[0].weight`` to
                attribute to.
            token_index: Sequence position to read. The default of -2 is kept
                from the original implementation — presumably the position whose
                output predicts the target; confirm against the caller.
            batch_index: Batch element to read. Defaults to 0, the value that
                was previously hard-coded.

        Returns:
            Tensor of shape (n_layers, n_heads), allocated with torch's default
            device/dtype, holding one DLA score per head.

        Raises:
            KeyError: If a layer's ``hook_result`` activation is missing from
                ``cache``.
        """
        n_layers = self.model.cfg.n_layers
        n_heads = self.model.cfg.n_heads
        # Residual-stream direction that writes to the target logit.
        # NOTE(review): no final-LayerNorm scaling and no bias term are applied,
        # so these are raw, un-normalised attributions — confirm this is intended.
        W_U = self.model.predict_action[0].weight[target_logit_index]
        dla_results = torch.zeros((n_layers, n_heads))
        for layer in range(n_layers):
            hook_name = f"blocks.{layer}.attn.hook_result"
            try:
                head_outputs = cache[hook_name]
            except KeyError as err:
                raise KeyError(
                    f"{hook_name!r} not found in cache; per-head attention outputs "
                    "must be cached (in TransformerLens, set cfg.use_attn_result=True)."
                ) from err
            # Per-head outputs for the chosen batch element and position: [head, d_model].
            token_output = head_outputs[batch_index, token_index]
            dla_results[layer] = torch.matmul(token_output, W_U)
        return dla_results

    def plot_dla(self, dla_results: torch.Tensor, title="Direct Logit Attribution"):
        """
        Display a heatmap of DLA scores, such as the output of calculate_dla.

        Args:
            dla_results: 2-D tensor of scores; rows are labelled as layers and
                columns as heads. It is detached and moved to CPU before plotting.
            title: Figure title.
        """
        plt.figure(figsize=(10, 6))
        # Diverging colormap centred on 0 so positive and negative contributions
        # are shown symmetrically.
        sns.heatmap(dla_results.detach().cpu().numpy(), annot=True, fmt=".2f", cmap="RdBu_r", center=0)
        plt.xlabel("Head")
        plt.ylabel("Layer")
        plt.title(title)
        plt.show()
|