File size: 1,202 Bytes
e2614dc
 
 
 
 
731ae64
e2614dc
 
 
 
 
 
731ae64
e2614dc
 
 
 
 
 
 
731ae64
 
e2614dc
 
 
731ae64
e2614dc
 
 
 
 
 
731ae64
e2614dc
731ae64
e2614dc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from typing import List, Optional, Tuple

import torch

class InductionScanner:
    """
    Identifies induction heads: heads that, at a token which occurred earlier in
    the sequence, attend to the token that *followed* that earlier occurrence.
    """
    def __init__(self, model):
        # Only ``model.cfg.n_layers`` and ``model.cfg.n_heads`` are read here.
        self.model = model

    def scan(
        self,
        cache,
        sequence: torch.Tensor,
        *,
        threshold: float = 0.5,
    ) -> List[Tuple[int, int]]:
        """
        Scan every attention head for induction behavior.

        Args:
            cache: Activations indexable by hook name; must provide
                ``blocks.{layer}.attn.hook_pattern`` for every layer, laid out
                as [batch, head, query_pos, key_pos].
            sequence: Token ids that produced ``cache``, shaped [seq] or
                [batch, seq]. Only batch element 0 is used, matching the
                attention patterns that are scored.
            threshold: A head is reported when its induction score is strictly
                greater than this value.

        Returns:
            (layer, head) pairs, in layer-then-head order, whose score exceeds
            ``threshold``.

        Raises:
            ValueError: If ``sequence`` is not 1-D or 2-D, or its length does
                not match the attention pattern's query/key length.
        """
        n_layers = self.model.cfg.n_layers
        n_heads = self.model.cfg.n_heads

        # Induction targets depend only on the tokens, so build them once.
        targets = self._induction_targets(sequence)

        induction_heads = []
        for layer in range(n_layers):
            # [batch, head, query_pos, key_pos]
            attn_pattern = cache[f"blocks.{layer}.attn.hook_pattern"]

            for head in range(n_heads):
                score = self._calculate_induction_score(
                    attn_pattern[0, head], targets
                )
                if score > threshold:
                    induction_heads.append((layer, head))

        return induction_heads

    @staticmethod
    def _induction_targets(sequence: torch.Tensor) -> torch.Tensor:
        """
        Build the boolean [query_pos, key_pos] mask of induction targets.

        ``mask[q, k]`` is True when ``tokens[k - 1] == tokens[q]`` and
        ``k - 1 < q``. In words, ``k`` is the position right after an earlier
        occurrence of the token at ``q``. Every target satisfies ``k <= q``, so
        the mask is consistent with causal attention.
        """
        if sequence.dim() == 2:
            tokens = sequence[0]
        elif sequence.dim() == 1:
            tokens = sequence
        else:
            raise ValueError(
                "sequence must be [seq] or [batch, seq], "
                f"got shape {tuple(sequence.shape)}"
            )

        seq_len = tokens.shape[0]
        positions = torch.arange(seq_len, device=tokens.device)
        same_token = tokens[:, None] == tokens[None, :]      # [q, j]
        strictly_earlier = positions[None, :] < positions[:, None]  # j < q
        prev_occurrence = same_token & strictly_earlier

        # Shift one step right: the target is the token *after* the earlier
        # occurrence, i.e. key k = j + 1.
        targets = torch.zeros_like(prev_occurrence)
        targets[:, 1:] = prev_occurrence[:, :-1]
        return targets

    def _calculate_induction_score(
        self,
        pattern: torch.Tensor,
        targets: Optional[torch.Tensor] = None,
    ) -> float:
        """
        Score how strongly a single head attends to induction targets.

        Args:
            pattern: One head's attention, [query_pos, key_pos].
            targets: Boolean mask from ``_induction_targets`` with the same
                shape as ``pattern``. If omitted, the legacy heuristic is used:
                the mean of the offset -1 diagonal. That heuristic measures
                previous-token attention, not induction, and is kept only for
                backward compatibility.

        Returns:
            Mean attention mass on induction targets, taken over the query
            positions that have at least one target. Returns 0.0 when no query
            has a target (e.g. no token repeats).

        Raises:
            ValueError: If ``pattern`` and ``targets`` differ in shape.
        """
        if targets is None:
            return torch.diagonal(pattern, offset=-1).mean().item()

        if pattern.shape != targets.shape:
            raise ValueError(
                f"attention pattern shape {tuple(pattern.shape)} does not match "
                f"sequence-derived target shape {tuple(targets.shape)}"
            )

        has_target = targets.any(dim=-1)
        if not bool(has_target.any()):
            # No repeated tokens: there is no induction signal to measure.
            return 0.0

        mask = targets.to(device=pattern.device, dtype=pattern.dtype)
        target_mass = (pattern * mask).sum(dim=-1)  # per query position
        return target_mass[has_target.to(pattern.device)].mean().item()