"""CLA Metadata Layer — Cross-Layer KV Cache Sharing hints for vLLM.

Based on:
- CLA (NeurIPS 2024): 2x KV cache reduction by sharing KVs between
  adjacent layer groups with negligible accuracy loss.
- NAACL 2025 systematic study: pairing queries of ALL layers with KVs of
  UPPER layers outperforms bottom-layer sharing at aggressive compression.
- LCKV (ACL 2024): Layer-Condensed KV, queries of all layers share KVs of
  only the top layer.

V4.0 CHANGES: New module for inference-time CLA hint injection.
"""
from dataclasses import dataclass
from typing import Optional

# Non-thinking roles (no chain-of-thought, can benefit from CLA)
NON_THOUGHT_ROLES = frozenset({"retriever", "summarizer", "formatter", "reviewer", "classifier"})


@dataclass
class CLAGroupConfig:
    """Configuration for CLA layer grouping strategy."""
    group_size: int = 2          # layers per group (2 = 2x reduction)
    sharing_direction: str = "upper"  # "upper" | "lower" per NAACL 2025
    thinking_mode_bypass: bool = True  # never apply CLA in thinking mode
    min_layer: int = 0           # skip bottom N layers (attention sinks)
    max_layer: int = 64          # skip above this layer index


@dataclass
class CLAHint:
    """Metadata hint for vLLM attention backend to share KV across layers."""
    agent_id: str
    model_id: str
    layer_groups: list[tuple[int, int]]  # (start_layer, shared_kv_layer)
    estimated_vram_reduction_pct: float  # fraction of KV cache saved (0.0–0.5 for group_size=2)
    is_thinking_mode: bool       # if True, hint is IGNORED by backend
    group_config: CLAGroupConfig


class CLAMetadataLayer:
    """
    Computes CLA metadata hints for agents based on their role and mode.
    
    Usage:
        cla = CLAMetadataLayer(CLAGroupConfig(group_size=2))
        hint = cla.emit_hint("agent1", "Qwen3.6-35B-A22B", is_thinking_mode=False, agent_role="retriever")
    """
    
    def __init__(self, config: Optional[CLAGroupConfig] = None):
        # Construct the default lazily so instances never share one mutable CLAGroupConfig.
        self._config = config if config is not None else CLAGroupConfig()
    
    def compute_layer_groups(
        self,
        model_layer_count: int,
        agent_role: str,
    ) -> list[tuple[int, int]]:
        """
        Compute layer sharing groups per NAACL 2025 'upper-layer' strategy.
        
        For group_size=2 and 64 layers:
            [(0,1), (2,3), (4,5), ..., (62,63)]
            → layer 0 queries use KV of layer 1, etc.
        Skip min_layer bottom layers to protect attention sinks.
        
        Args:
            model_layer_count: Total number of transformer layers in model
            agent_role: Agent role (e.g., "retriever", "summarizer") determines
                       whether this agent is in thinking or non-thinking mode
        
        Returns:
            List of (start_layer, shared_kv_layer) tuples
        """
        # Check if role is thinking or non-thinking
        is_non_thinking = agent_role in NON_THOUGHT_ROLES
        
        # Don't compute groups for thinking-mode agents (they bypass CLA)
        if not is_non_thinking:
            return []
        
        groups = []
        cfg = self._config
        # Start from min_layer, go up to max_layer, step by group_size
        for start in range(cfg.min_layer, min(cfg.max_layer, model_layer_count), cfg.group_size):
            end = min(start + cfg.group_size - 1, model_layer_count - 1)
            if cfg.sharing_direction == "upper":
                # NAACL 2025: queries of layer i share KV of layer i+1 (upper layer)
                shared_kv_layer = end
            else:
                # Alternative: share KV of lower layer
                shared_kv_layer = start
            groups.append((start, shared_kv_layer))
        
        return groups
    
    def emit_hint(
        self,
        agent_id: str,
        model_id: str,
        is_thinking_mode: bool,
        model_layer_count: int = 64,
        agent_role: str = "default",
    ) -> CLAHint:
        """
        Emit a CLAHint for a given agent.
        
        If is_thinking_mode=True and thinking_mode_bypass is True,
        returns empty layer_groups and 0.0 vram_reduction.
        
        Args:
            agent_id: Unique agent identifier
            model_id: Model name (e.g., "Qwen3.6-35B-A22B")
            is_thinking_mode: True if agent uses chain-of-thought reasoning
            model_layer_count: Number of transformer layers
            agent_role: Agent role for CLA eligibility determination
        
        Returns:
            CLAHint with layer_groups and estimated VRAM reduction
        """
        # Bypass if thinking mode and config says to bypass
        if is_thinking_mode and self._config.thinking_mode_bypass:
            return CLAHint(
                agent_id=agent_id,
                model_id=model_id,
                layer_groups=[],
                estimated_vram_reduction_pct=0.0,
                is_thinking_mode=True,
                group_config=self._config,
            )
        
        layer_groups = self.compute_layer_groups(model_layer_count, agent_role)
        vram_reduction = self.estimated_vram_reduction(layer_groups)
        
        return CLAHint(
            agent_id=agent_id,
            model_id=model_id,
            layer_groups=layer_groups,
            estimated_vram_reduction_pct=vram_reduction,
            is_thinking_mode=is_thinking_mode,
            group_config=self._config,
        )
    
    def estimated_vram_reduction(self, layer_groups: list[tuple[int, int]]) -> float:
        """
        Estimate VRAM reduction factor from layer groups.
        
        group_size=2 → 50% of layers share KV → ~0.5 of per-layer KV savings.
        Rough upper bound: actual savings depend on attention head count and
        on layers skipped via min_layer/max_layer, which this formula ignores.
        
        Args:
            layer_groups: Output of compute_layer_groups()
        
        Returns:
            Float 0.0–0.5 representing VRAM fraction saved
        """
        if not layer_groups:
            return 0.0
        
        # Each group shares 1 layer's KV across group_size layers
        # Fraction saved = (group_size - 1) / group_size
        # For group_size=2: (2-1)/2 = 0.5 (50% savings)
        cfg = self._config
        return (cfg.group_size - 1) / cfg.group_size
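

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the public API): exercises emit_hint
# for one non-thinking and one thinking-mode agent and prints the resulting
# layer groups. The agent IDs and the "planner" role below are placeholders,
# not roles defined elsewhere in this module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    cla = CLAMetadataLayer(CLAGroupConfig(group_size=2, min_layer=2))

    # Non-thinking retriever: receives real layer groups and a VRAM estimate.
    retriever_hint = cla.emit_hint(
        agent_id="agent-retriever",
        model_id="Qwen3.6-35B-A22B",
        is_thinking_mode=False,
        model_layer_count=64,
        agent_role="retriever",
    )
    print("retriever groups:", retriever_hint.layer_groups[:4], "...")
    print("retriever est. reduction:", retriever_hint.estimated_vram_reduction_pct)

    # Thinking-mode agent: bypassed, so the hint carries no layer groups.
    planner_hint = cla.emit_hint(
        agent_id="agent-planner",
        model_id="Qwen3.6-35B-A22B",
        is_thinking_mode=True,
        agent_role="planner",
    )
    print("planner groups:", planner_hint.layer_groups)  # []
    print("planner est. reduction:", planner_hint.estimated_vram_reduction_pct)  # 0.0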