"""CLA Metadata Layer — Cross-Layer KV Cache Sharing hints for vLLM.
Based on:
- CLA (NeurIPS 2024): 2x KV cache reduction by sharing KVs between
adjacent layer groups with negligible accuracy loss.
- NAACL 2025 systematic study: pairing queries of ALL layers with KVs of
UPPER layers outperforms bottom-layer sharing at aggressive compression.
- LCKV (ACL 2024): Layer-Condensed KV, queries of all layers share KVs of
only the top layer.
V4.0 CHANGES: New module for inference-time CLA hint injection.
"""
from dataclasses import dataclass
from typing import Optional
# Non-thinking roles (no chain-of-thought, can benefit from CLA)
NON_THOUGHT_ROLES = frozenset({"retriever", "summarizer", "formatter", "reviewer", "classifier"})
@dataclass
class CLAGroupConfig:
"""Configuration for CLA layer grouping strategy."""
group_size: int = 2 # layers per group (2 = 2x reduction)
sharing_direction: str = "upper" # "upper" | "lower" per NAACL 2025
thinking_mode_bypass: bool = True # never apply CLA in thinking mode
min_layer: int = 0 # skip bottom N layers (attention sinks)
    max_layer: int = 64  # exclusive upper bound: no sharing group starts at or above this index
@dataclass
class CLAHint:
"""Metadata hint for vLLM attention backend to share KV across layers."""
agent_id: str
model_id: str
layer_groups: list[tuple[int, int]] # (start_layer, shared_kv_layer)
    estimated_vram_reduction_pct: float  # fraction saved (0.0–0.5 for group_size=2); a ratio, not a percent
is_thinking_mode: bool # if True, hint is IGNORED by backend
group_config: CLAGroupConfig
class CLAMetadataLayer:
"""
Computes CLA metadata hints for agents based on their role and mode.
Usage:
cla = CLAMetadataLayer(CLAGroupConfig(group_size=2))
hint = cla.emit_hint("agent1", "Qwen3.6-35B-A22B", is_thinking_mode=False, agent_role="retriever")
"""
    def __init__(self, config: Optional[CLAGroupConfig] = None):
        # Avoid a shared default instance: a dataclass default in the signature
        # is evaluated once at definition time and shared by every caller.
        self._config = config if config is not None else CLAGroupConfig()
def compute_layer_groups(
self,
model_layer_count: int,
agent_role: str,
) -> list[tuple[int, int]]:
"""
Compute layer sharing groups per NAACL 2025 'upper-layer' strategy.
For group_size=2 and 64 layers:
[(0,1), (2,3), (4,5), ..., (62,63)]
→ layer 0 queries use KV of layer 1, etc.
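        Example (doctest-style sketch, assuming the default CLAGroupConfig):
            >>> CLAMetadataLayer().compute_layer_groups(8, "retriever")
            [(0, 1), (2, 3), (4, 5), (6, 7)]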
        The bottom min_layer layers are skipped to protect attention sinks.
        Args:
            model_layer_count: Total number of transformer layers in the model
            agent_role: Agent role (e.g., "retriever", "summarizer"); only
                roles in NON_THOUGHT_ROLES are treated as non-thinking and
                receive sharing groups
Returns:
List of (start_layer, shared_kv_layer) tuples
"""
# Check if role is thinking or non-thinking
is_non_thinking = agent_role in NON_THOUGHT_ROLES
# Don't compute groups for thinking-mode agents (they bypass CLA)
if not is_non_thinking:
return []
groups = []
cfg = self._config
# Start from min_layer, go up to max_layer, step by group_size
for start in range(cfg.min_layer, min(cfg.max_layer, model_layer_count), cfg.group_size):
            end = min(start + cfg.group_size - 1, model_layer_count - 1)
            if end == start:
                # Degenerate single-layer group at the top of the range:
                # sharing a layer's KV with itself is a no-op, so skip it.
                continue
            if cfg.sharing_direction == "upper":
                # NAACL 2025: queries in the group share the KV of its top layer
                # (for group_size=2, layer i uses the KV of layer i+1)
                shared_kv_layer = end
            else:
                # Alternative: share the KV of the group's bottom layer
                shared_kv_layer = start
            groups.append((start, shared_kv_layer))
return groups
def emit_hint(
self,
agent_id: str,
model_id: str,
is_thinking_mode: bool,
model_layer_count: int = 64,
agent_role: str = "default",
) -> CLAHint:
"""
Emit a CLAHint for a given agent.
If is_thinking_mode=True and thinking_mode_bypass is True,
returns empty layer_groups and 0.0 vram_reduction.
Args:
agent_id: Unique agent identifier
model_id: Model name (e.g., "Qwen3.6-35B-A22B")
is_thinking_mode: True if agent uses chain-of-thought reasoning
model_layer_count: Number of transformer layers
agent_role: Agent role for CLA eligibility determination
Returns:
CLAHint with layer_groups and estimated VRAM reduction
"""
# Bypass if thinking mode and config says to bypass
if is_thinking_mode and self._config.thinking_mode_bypass:
return CLAHint(
agent_id=agent_id,
model_id=model_id,
layer_groups=[],
estimated_vram_reduction_pct=0.0,
is_thinking_mode=True,
group_config=self._config,
)
layer_groups = self.compute_layer_groups(model_layer_count, agent_role)
vram_reduction = self.estimated_vram_reduction(layer_groups)
return CLAHint(
agent_id=agent_id,
model_id=model_id,
layer_groups=layer_groups,
estimated_vram_reduction_pct=vram_reduction,
is_thinking_mode=is_thinking_mode,
group_config=self._config,
)
    def estimated_vram_reduction(self, layer_groups: list[tuple[int, int]]) -> float:
        """
        Estimate the fraction of KV-cache VRAM saved for the given layer groups.
        group_size=2 → each pair of layers stores one layer's KV → ~0.5 savings.
        This is an upper bound: it assumes every layer falls inside a group, so
        actual savings are lower when min_layer/max_layer skip layers, and they
        also depend on the model's attention head layout.
        Args:
            layer_groups: Output of compute_layer_groups()
        Returns:
            Float in [0.0, 1.0): (group_size - 1) / group_size, or 0.0 if
            layer_groups is empty
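        Example (doctest-style sketch; one pair with the default group_size=2):
            >>> CLAMetadataLayer().estimated_vram_reduction([(0, 1)])
            0.5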
"""
if not layer_groups:
return 0.0
# Each group shares 1 layer's KV across group_size layers
# Fraction saved = (group_size - 1) / group_size
# For group_size=2: (2-1)/2 = 0.5 (50% savings)
cfg = self._config
return (cfg.group_size - 1) / cfg.group_size
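
# Minimal usage sketch (illustrative, not part of the module API). It exercises
# both paths: a non-thinking "retriever" receives layer groups, while a
# thinking-mode agent is bypassed. The model name mirrors the class docstring,
# and "planner" is an assumed thinking role, not one drawn from this module.
if __name__ == "__main__":
    cla = CLAMetadataLayer(CLAGroupConfig(group_size=2, max_layer=64))

    hint = cla.emit_hint(
        "agent1", "Qwen3.6-35B-A22B",
        is_thinking_mode=False, model_layer_count=64, agent_role="retriever",
    )
    print(len(hint.layer_groups), hint.estimated_vram_reduction_pct)  # 32 0.5

    bypass = cla.emit_hint(
        "agent2", "Qwen3.6-35B-A22B",
        is_thinking_mode=True, model_layer_count=64, agent_role="planner",
    )
    print(bypass.layer_groups, bypass.estimated_vram_reduction_pct)  # [] 0.0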