| """CLA Metadata Layer — Cross-Layer KV Cache Sharing hints for vLLM. | |
| Based on: | |
| - CLA (NeurIPS 2024): 2x KV cache reduction by sharing KVs between | |
| adjacent layer groups with negligible accuracy loss. | |
| - NAACL 2025 systematic study: pairing queries of ALL layers with KVs of | |
| UPPER layers outperforms bottom-layer sharing at aggressive compression. | |
| - LCKV (ACL 2024): Layer-Condensed KV, queries of all layers share KVs of | |
| only the top layer. | |
| V4.0 CHANGES: New module for inference-time CLA hint injection. | |
| """ | |
| from dataclasses import dataclass | |
| from typing import Optional | |
| # Non-thinking roles (no chain-of-thought, can benefit from CLA) | |
| NON_THOUGHT_ROLES = frozenset({"retriever", "summarizer", "formatter", "reviewer", "classifier"}) | |
@dataclass
class CLAGroupConfig:
    """Configuration for the CLA layer-grouping strategy."""

    group_size: int = 2                # layers per group (2 = 2x reduction)
    sharing_direction: str = "upper"   # "upper" | "lower" per NAACL 2025
    thinking_mode_bypass: bool = True  # never apply CLA in thinking mode
    min_layer: int = 0                 # skip the bottom N layers (attention sinks)
    max_layer: int = 64                # skip layers above this index

@dataclass
class CLAHint:
    """Metadata hint for the vLLM attention backend to share KV across layers."""

    agent_id: str
    model_id: str
    layer_groups: list[tuple[int, int]]   # (start_layer, shared_kv_layer)
    estimated_vram_reduction_pct: float   # 0.0–0.5 for group_size=2
    is_thinking_mode: bool                # if True, the hint is IGNORED by the backend
    group_config: CLAGroupConfig

class CLAMetadataLayer:
    """
    Computes CLA metadata hints for agents based on their role and mode.

    Usage:
        cla = CLAMetadataLayer(CLAGroupConfig(group_size=2))
        hint = cla.emit_hint("agent1", "Qwen3.6-35B-A22B", is_thinking_mode=False, agent_role="retriever")
    """

    def __init__(self, config: Optional[CLAGroupConfig] = None):
        # Avoid a shared default instance across CLAMetadataLayer objects.
        self._config = config if config is not None else CLAGroupConfig()

    def compute_layer_groups(
        self,
        model_layer_count: int,
        agent_role: str,
    ) -> list[tuple[int, int]]:
        """
        Compute layer-sharing groups per the NAACL 2025 "upper-layer" strategy.

        For group_size=2 and 64 layers:
            [(0, 1), (2, 3), (4, 5), ..., (62, 63)]
            → layer 0 queries use the KV of layer 1, etc.

        The bottom min_layer layers are skipped to protect attention sinks.

        Args:
            model_layer_count: Total number of transformer layers in the model.
            agent_role: Agent role (e.g., "retriever", "summarizer"); roles outside
                NON_THOUGHT_ROLES are treated as thinking-mode and receive no groups.

        Returns:
            List of (start_layer, shared_kv_layer) tuples.
        """
        # Check whether the role is non-thinking (CLA-eligible)
        is_non_thinking = agent_role in NON_THOUGHT_ROLES
        # Don't compute groups for thinking-mode agents (they bypass CLA)
        if not is_non_thinking:
            return []

        groups = []
        cfg = self._config
        # Start from min_layer, go up to max_layer, step by group_size
        for start in range(cfg.min_layer, min(cfg.max_layer, model_layer_count), cfg.group_size):
            end = min(start + cfg.group_size - 1, model_layer_count - 1)
            if cfg.sharing_direction == "upper":
                # NAACL 2025: queries of layer i share the KV of layer i+1 (upper layer)
                shared_kv_layer = end
            else:
                # Alternative: share the KV of the lower layer
                shared_kv_layer = start
            groups.append((start, shared_kv_layer))
        return groups

    def emit_hint(
        self,
        agent_id: str,
        model_id: str,
        is_thinking_mode: bool,
        model_layer_count: int = 64,
        agent_role: str = "default",
    ) -> CLAHint:
        """
        Emit a CLAHint for a given agent.

        If is_thinking_mode=True and thinking_mode_bypass is True, returns
        empty layer_groups and an estimated VRAM reduction of 0.0.

        Args:
            agent_id: Unique agent identifier.
            model_id: Model name (e.g., "Qwen3.6-35B-A22B").
            is_thinking_mode: True if the agent uses chain-of-thought reasoning.
            model_layer_count: Number of transformer layers.
            agent_role: Agent role used to determine CLA eligibility.

        Returns:
            CLAHint with layer_groups and the estimated VRAM reduction.
        """
        # Bypass if in thinking mode and the config says to bypass
        if is_thinking_mode and self._config.thinking_mode_bypass:
            return CLAHint(
                agent_id=agent_id,
                model_id=model_id,
                layer_groups=[],
                estimated_vram_reduction_pct=0.0,
                is_thinking_mode=True,
                group_config=self._config,
            )

        layer_groups = self.compute_layer_groups(model_layer_count, agent_role)
        vram_reduction = self.estimated_vram_reduction(layer_groups)
        return CLAHint(
            agent_id=agent_id,
            model_id=model_id,
            layer_groups=layer_groups,
            estimated_vram_reduction_pct=vram_reduction,
            is_thinking_mode=is_thinking_mode,
            group_config=self._config,
        )

    def estimated_vram_reduction(self, layer_groups: list[tuple[int, int]]) -> float:
        """
        Estimate the VRAM reduction factor from layer groups.

        group_size=2 → 50% of layers share KV → ~0.5 * per-layer KV savings.
        This is a rough estimate: it ignores layers skipped via min_layer/max_layer,
        and actual savings depend on the attention head count.

        Args:
            layer_groups: Output of compute_layer_groups().

        Returns:
            Fraction of KV-cache VRAM saved (0.0–0.5 for group_size=2).
        """
        if not layer_groups:
            return 0.0
        # Each group shares one layer's KV across group_size layers.
        # Fraction saved = (group_size - 1) / group_size
        # For group_size=2: (2 - 1) / 2 = 0.5 (50% savings)
        cfg = self._config
        return (cfg.group_size - 1) / cfg.group_size

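
# Illustrative usage sketch, not part of the module's public surface: the agent ids,
# the "planner" role, and min_layer=2 below are hypothetical values chosen for the demo.
if __name__ == "__main__":
    cla = CLAMetadataLayer(CLAGroupConfig(group_size=2, min_layer=2))

    # Non-thinking retriever agent on a 64-layer model: CLA groups are emitted.
    hint = cla.emit_hint(
        agent_id="agent-retriever-01",
        model_id="Qwen3.6-35B-A22B",
        is_thinking_mode=False,
        model_layer_count=64,
        agent_role="retriever",
    )
    print(hint.layer_groups[:3])              # [(2, 3), (4, 5), (6, 7)]
    print(hint.estimated_vram_reduction_pct)  # 0.5 for group_size=2

    # Thinking-mode agent: the hint is emitted with empty groups (CLA bypassed).
    bypass = cla.emit_hint(
        agent_id="agent-planner-01",
        model_id="Qwen3.6-35B-A22B",
        is_thinking_mode=True,
        agent_role="planner",
    )
    print(bypass.layer_groups, bypass.estimated_vram_reduction_pct)  # [] 0.0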