File size: 1,551 Bytes
e2614dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4aa19e7
e2614dc
 
11dbbc6
e2614dc
 
 
 
11dbbc6
731ae64
e2614dc
 
 
 
731ae64
 
e2614dc
11dbbc6
731ae64
e2614dc
731ae64
e2614dc
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import torch
from jaxtyping import Float
from typing import Dict, List
import matplotlib.pyplot as plt
import seaborn as sns

class LogitAttributionEngine:
    """
    Calculates the Direct Logit Attribution (DLA) of transformer components.

    DLA here is the dot product between each attention head's output at a
    single (batch, position) and one row of ``model.predict_action[0].weight``
    (the row belonging to the target logit).
    """

    def __init__(self, model):
        # NOTE(review): `model` must expose `cfg.n_layers`, `cfg.n_heads` and an
        # indexable `predict_action` whose first element has a 2-D `.weight`
        # (presumably an nn.Linear with rows indexed by output logit) — confirm.
        self.model = model

    def calculate_dla(
        self,
        cache,
        target_logit_index: int,
        token_index: int = -2,
        batch_index: int = 0,
    ) -> "Float[torch.Tensor, 'layer head']":
        """
        Compute per-head DLA scores for a single target logit.

        For every layer ``l`` and head ``h`` the score is::

            cache[f"blocks.{l}.attn.hook_result"][batch_index, token_index, h]
                @ model.predict_action[0].weight[target_logit_index]

        Args:
            cache: Mapping from hook names to activations. It must contain
                ``blocks.{layer}.attn.hook_result`` for every layer; a missing
                key raises ``KeyError``. Assumed shape is
                [batch, pos, head, d_model], with per-head outputs already
                projected back to the residual stream (i.e. W_O applied) —
                TODO confirm against the cache producer.
            target_logit_index: Row of ``predict_action[0].weight`` to attribute to.
            token_index: Sequence position to read (default -2, second to last).
            batch_index: Batch element to read. Defaults to 0, which matches
                the previously hard-coded behaviour.

        Returns:
            Tensor of shape [n_layers, n_heads], allocated with torch's default
            dtype/device (float32 on CPU unless those defaults were changed).

        Note:
            No final LayerNorm scaling and no output bias are applied. The
            scores are raw linear contributions along the weight row only.
        """
        n_layers = self.model.cfg.n_layers
        n_heads = self.model.cfg.n_heads

        # Direction in d_model space whose dot product yields the target logit.
        target_direction = self.model.predict_action[0].weight[target_logit_index]

        dla_results = torch.zeros((n_layers, n_heads))
        for layer in range(n_layers):
            head_outputs = cache[f"blocks.{layer}.attn.hook_result"]
            # Select one batch element and one position -> [head, d_model].
            selected = head_outputs[batch_index, token_index]
            # [head, d_model] @ [d_model] -> [head]; copied into this layer's row.
            dla_results[layer] = selected @ target_direction

        return dla_results

    def plot_dla(self, dla_results: torch.Tensor, title="Direct Logit Attribution") -> None:
        """
        Display a layer-by-head heatmap of DLA scores.

        Cells are annotated to two decimals. The diverging colormap is centred
        at 0, so positive and negative contributions are visually separated.

        Args:
            dla_results: 2-D tensor laid out as [layer, head], as returned by
                ``calculate_dla``. It is detached and moved to CPU before plotting.
            title: Figure title.
        """
        plt.figure(figsize=(10, 6))
        sns.heatmap(dla_results.detach().cpu().numpy(), annot=True, fmt=".2f", cmap="RdBu_r", center=0)
        plt.xlabel("Head")
        plt.ylabel("Layer")
        plt.title(title)
        plt.show()