"""CLA Metadata Layer — Cross-Layer KV Cache Sharing hints for vLLM.

Based on:
- CLA (NeurIPS 2024): 2x KV cache reduction by sharing KVs between
  adjacent layer groups with negligible accuracy loss.
- NAACL 2025 systematic study: pairing queries of ALL layers with KVs of
  UPPER layers outperforms bottom-layer sharing at aggressive compression.
- LCKV (ACL 2024): Layer-Condensed KV, where queries of all layers share
  the KVs of only the top layer.

V4.0 CHANGES: New module for inference-time CLA hint injection.
"""

from dataclasses import dataclass
from typing import Optional

# Non-thinking roles (no chain-of-thought, can benefit from CLA)
NON_THOUGHT_ROLES = frozenset({"retriever", "summarizer", "formatter", "reviewer", "classifier"})


@dataclass
class CLAGroupConfig:
    """Configuration for CLA layer grouping strategy."""

    group_size: int = 2                # layers per group (2 = 2x reduction)
    sharing_direction: str = "upper"   # "upper" | "lower", per NAACL 2025
    thinking_mode_bypass: bool = True  # never apply CLA in thinking mode
    min_layer: int = 0                 # skip the bottom N layers (attention sinks)
    max_layer: int = 64                # grouping stops at this layer index (exclusive)


@dataclass
class CLAHint:
    """Metadata hint for vLLM attention backend to share KV across layers."""

    agent_id: str
    model_id: str
    layer_groups: list[tuple[int, int]]  # (start_layer, shared_kv_layer)
    estimated_vram_reduction_pct: float  # fraction of KV cache saved (0.5 for group_size=2; 0.0 if bypassed)
    is_thinking_mode: bool               # if True, hint is IGNORED by backend
    group_config: CLAGroupConfig


class CLAMetadataLayer:
    """
    Computes CLA metadata hints for agents based on their role and mode.

    Usage:
        cla = CLAMetadataLayer(CLAGroupConfig(group_size=2))
        hint = cla.emit_hint(
            "agent1", "Qwen3.6-35B-A22B",
            is_thinking_mode=False, agent_role="retriever",
        )
    """

    def __init__(self, config: Optional[CLAGroupConfig] = None):
        # Avoid a single shared mutable default instance across
        # CLAMetadataLayer objects (classic default-argument pitfall).
        self._config = config if config is not None else CLAGroupConfig()

    def compute_layer_groups(
        self,
        model_layer_count: int,
        agent_role: str,
    ) -> list[tuple[int, int]]:
        """
        Compute layer sharing groups per the NAACL 2025 "upper-layer" strategy.

        For group_size=2 and 64 layers:
            [(0, 1), (2, 3), (4, 5), ..., (62, 63)]
            → layer 0's queries use the KV of layer 1, etc.

        The bottom min_layer layers are skipped to protect attention sinks.

        Args:
            model_layer_count: Total number of transformer layers in the model
            agent_role: Agent role (e.g., "retriever", "summarizer"); determines
                whether this agent is in thinking or non-thinking mode

        Returns:
            List of (start_layer, shared_kv_layer) tuples
        """
        # Only non-thinking roles are eligible; thinking-mode agents bypass CLA.
        is_non_thinking = agent_role in NON_THOUGHT_ROLES
        if not is_non_thinking:
            return []

        groups = []
        cfg = self._config
        # Walk from min_layer up to max_layer (capped at the model depth),
        # one group per group_size layers.
        for start in range(cfg.min_layer, min(cfg.max_layer, model_layer_count), cfg.group_size):
            end = min(start + cfg.group_size - 1, model_layer_count - 1)
            if cfg.sharing_direction == "upper":
                # NAACL 2025: all queries in the group share the KV of its top layer.
                shared_kv_layer = end
            else:
                # Alternative: share the KV of the group's bottom layer.
                shared_kv_layer = start
            groups.append((start, shared_kv_layer))
        return groups
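
    # Illustrative example (hypothetical config values, not from the cited
    # papers): protecting the first two attention-sink layers of an 8-layer
    # model,
    #
    #     cla = CLAMetadataLayer(CLAGroupConfig(group_size=2, min_layer=2))
    #     cla.compute_layer_groups(8, "retriever")  # -> [(2, 3), (4, 5), (6, 7)]
    #
    # layers 0 and 1 keep private KV caches; each remaining pair reads the
    # upper layer's KV.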

    def emit_hint(
        self,
        agent_id: str,
        model_id: str,
        is_thinking_mode: bool,
        model_layer_count: int = 64,
        agent_role: str = "default",
    ) -> CLAHint:
        """
        Emit a CLAHint for a given agent.

        If is_thinking_mode=True and thinking_mode_bypass is True, returns
        empty layer_groups and a 0.0 VRAM reduction.

        Args:
            agent_id: Unique agent identifier
            model_id: Model name (e.g., "Qwen3.6-35B-A22B")
            is_thinking_mode: True if agent uses chain-of-thought reasoning
            model_layer_count: Number of transformer layers
            agent_role: Agent role for CLA eligibility determination

        Returns:
            CLAHint with layer_groups and estimated VRAM reduction
        """
        # Bypass if thinking mode and config says to bypass
        if is_thinking_mode and self._config.thinking_mode_bypass:
            return CLAHint(
                agent_id=agent_id,
                model_id=model_id,
                layer_groups=[],
                estimated_vram_reduction_pct=0.0,
                is_thinking_mode=True,
                group_config=self._config,
            )

        layer_groups = self.compute_layer_groups(model_layer_count, agent_role)
        vram_reduction = self.estimated_vram_reduction(layer_groups)
        return CLAHint(
            agent_id=agent_id,
            model_id=model_id,
            layer_groups=layer_groups,
            estimated_vram_reduction_pct=vram_reduction,
            is_thinking_mode=is_thinking_mode,
            group_config=self._config,
        )

    def estimated_vram_reduction(self, layer_groups: list[tuple[int, int]]) -> float:
        """
        Estimate the VRAM reduction factor from layer groups.

        group_size=2 → 50% of layers share KV → ~0.5 * KV_per_layer savings.
        This is a rough estimate: it ignores layers excluded by min_layer /
        max_layer, and actual savings depend on the attention head count.

        Args:
            layer_groups: Output of compute_layer_groups()

        Returns:
            Fraction of KV cache saved, (group_size - 1) / group_size;
            0.5 for group_size=2, or 0.0 if no layers are grouped
        """
        if not layer_groups:
            return 0.0
        # Each group shares one layer's KV across group_size layers, so the
        # fraction saved = (group_size - 1) / group_size.
        # For group_size=2: (2 - 1) / 2 = 0.5 (50% savings).
        cfg = self._config
        return (cfg.group_size - 1) / cfg.group_size
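

# Minimal usage sketch (illustrative only: the agent IDs, "planner" role,
# and layer count below are placeholders, not values prescribed by this
# module):
if __name__ == "__main__":
    cla = CLAMetadataLayer(CLAGroupConfig(group_size=2))

    # Non-thinking retriever: eligible for CLA, receives real layer groups.
    hint = cla.emit_hint(
        "agent-retriever-01", "Qwen3.6-35B-A22B",
        is_thinking_mode=False, model_layer_count=64, agent_role="retriever",
    )
    print(len(hint.layer_groups), hint.estimated_vram_reduction_pct)  # 32 0.5

    # Thinking-mode agent: bypassed, the hint carries no groups.
    bypass = cla.emit_hint(
        "agent-planner-01", "Qwen3.6-35B-A22B",
        is_thinking_mode=True, model_layer_count=64, agent_role="planner",
    )
    print(bypass.layer_groups, bypass.estimated_vram_reduction_pct)  # [] 0.0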