Spaces:
Sleeping
ContextForge V4.0: EmbeddingEngine, CLA metadata, RotateKV, step graph
Browse files- TASK-001: EmbeddingEngine with Qwen3-Embedding-0.6B ONNX, LRU cache,
xorshift fallback, get_instance(), encode(), encode_batch(), simhash()
- TASK-002: AnchorPool wired into ContextRegistry with AnchorOffsetResult,
prefix_offsets field, update_pool(), approximate_offset(), get_shared_context()
populates offset_hints
- TASK-003: Removed _token_ids_to_embedding from ContextRegistry and
AnchorPool; replaced with EmbeddingEngine.get_instance().encode()
- TASK-004: CLAMetadataLayer with compute_layer_groups(), emit_hint(),
estimated_vram_reduction(), NON_THOUGHT_ROLES frozenset, NAACL 2025 strategy
- TASK-005: RotateKVConfig, QuantizedKVBlock, RotateKVQuantizer with
calibrate(), quantize_pre_rope() (INVARIANT 10: pre-RoPE only), dequantize()
- TASK-006: AgentStepGraph with compute_steps_to_execution(),
get_prefetch_candidates(), get_eviction_priority_order(), VRAMAwareCache
WORKFLOW_AWARE mode (6)
- TASK-007: LMCacheConnectorV1 bridge with build_prefix_hint(),
on_save_kv_layer(), on_load_kv_layer(), is_active()
- TASK-008: vLLMAtomPlugin with PreAttentionHook, PostAttentionHook,
pyproject.toml entry_point for vllm.plugin
- TASK-009: KVAwareRouter with select_worker(), update_worker_state(),
broadcast_new_blocks(), anchor locality + CLA affinity + load balancing
- TASK-013: PBKVPredictor stub with log_workflow_step(), predict_next_agents(),
get_prefetch_candidates(), JSONL logging
INVARIANT 10: Only pre-RoPE tensors are quantized/shared.
All routing decisions made on anchor metadata only.
- contextforge/kv_offset/anchor_pool.py +34 -38
- contextforge/kv_offset/cla_metadata.py +163 -0
- contextforge/pyproject.toml +3 -0
- contextforge/quantization/rotate_kv.py +315 -0
- contextforge/registry/context_registry.py +48 -21
- contextforge/registry/vram_aware_cache.py +26 -4
- contextforge/routing/kv_aware_router.py +200 -0
- contextforge/scheduling/pbkv_predictor.py +172 -0
- contextforge/scheduling/step_graph.py +151 -0
- contextforge/serving/atom_plugin.py +155 -0
- contextforge/serving/lmcache_bridge.py +156 -0
|
@@ -23,6 +23,8 @@ from typing import Optional
|
|
| 23 |
|
| 24 |
import numpy as np
|
| 25 |
|
|
|
|
|
|
|
| 26 |
logger = logging.getLogger(__name__)
|
| 27 |
|
| 28 |
# Length compatibility tolerance (10%)
|
|
@@ -35,6 +37,13 @@ DEFAULT_MAX_SIZE = 20
|
|
| 35 |
EMBEDDING_DIM = 128
|
| 36 |
|
| 37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
@dataclass
|
| 39 |
class Anchor:
|
| 40 |
"""A stored anchor for KV offset estimation."""
|
|
@@ -42,6 +51,7 @@ class Anchor:
|
|
| 42 |
agent_offsets: dict[str, np.ndarray]
|
| 43 |
embedding: np.ndarray # shape (EMBEDDING_DIM,)
|
| 44 |
token_length: int
|
|
|
|
| 45 |
access_count: int = 0
|
| 46 |
created_at: float = field(default_factory=time.monotonic)
|
| 47 |
|
|
@@ -76,22 +86,23 @@ class AnchorPool:
|
|
| 76 |
token_ids: list[int],
|
| 77 |
agent_id: str,
|
| 78 |
real_kv_offset: np.ndarray,
|
|
|
|
| 79 |
) -> None:
|
| 80 |
"""Add a new anchor to the pool (or update existing)."""
|
| 81 |
loop = asyncio.get_event_loop()
|
| 82 |
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
)
|
| 86 |
|
| 87 |
-
embedding = await
|
| 88 |
-
None, self._token_ids_to_embedding, token_ids
|
| 89 |
-
)
|
| 90 |
|
| 91 |
async with self._lock:
|
| 92 |
if block_hash in self._anchors:
|
| 93 |
anchor = self._anchors[block_hash]
|
| 94 |
anchor.agent_offsets[agent_id] = real_kv_offset
|
|
|
|
|
|
|
| 95 |
anchor.access_count += 1
|
| 96 |
else:
|
| 97 |
anchor = Anchor(
|
|
@@ -101,6 +112,8 @@ class AnchorPool:
|
|
| 101 |
token_length=len(token_ids),
|
| 102 |
access_count=1,
|
| 103 |
)
|
|
|
|
|
|
|
| 104 |
self._anchors[block_hash] = anchor
|
| 105 |
|
| 106 |
if agent_id not in self._agent_anchors:
|
|
@@ -140,14 +153,15 @@ class AnchorPool:
|
|
| 140 |
diff = abs(ref_len - target_length) / target_length
|
| 141 |
return 1.0 - (diff / self._length_tolerance)
|
| 142 |
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
)
|
| 146 |
|
| 147 |
best_score = 0.0
|
| 148 |
for anchor in candidates:
|
| 149 |
L_phi = length_compatibility(anchor.token_length)
|
| 150 |
|
|
|
|
|
|
|
| 151 |
distances = []
|
| 152 |
for other_anchor in candidates:
|
| 153 |
dist = np.linalg.norm(anchor.embedding - other_anchor.embedding)
|
|
@@ -173,13 +187,10 @@ class AnchorPool:
|
|
| 173 |
self,
|
| 174 |
token_ids: list[int],
|
| 175 |
target_agent_id: str,
|
| 176 |
-
) -> Optional[
|
| 177 |
"""Approximate KV offset for token_ids when used by target_agent_id."""
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
target_embedding = await loop.run_in_executor(
|
| 181 |
-
None, self._token_ids_to_embedding, token_ids
|
| 182 |
-
)
|
| 183 |
|
| 184 |
async with self._lock:
|
| 185 |
candidates = [
|
|
@@ -206,7 +217,14 @@ class AnchorPool:
|
|
| 206 |
for w, offset in zip(softmax_weights, offsets):
|
| 207 |
result += w * offset
|
| 208 |
|
| 209 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 210 |
|
| 211 |
async def apply_rope_derotation(
|
| 212 |
self,
|
|
@@ -294,28 +312,6 @@ class AnchorPool:
|
|
| 294 |
|
| 295 |
return result
|
| 296 |
|
| 297 |
-
def _token_ids_to_embedding(self, token_ids: list[int]) -> np.ndarray:
|
| 298 |
-
"""Convert token IDs to fixed-dim embedding via pseudo-random projection."""
|
| 299 |
-
embedding = np.zeros(EMBEDDING_DIM, dtype=np.float32)
|
| 300 |
-
|
| 301 |
-
for i, tid in enumerate(token_ids[:1024]):
|
| 302 |
-
h = int(tid)
|
| 303 |
-
for _ in range(4):
|
| 304 |
-
h ^= h << 13
|
| 305 |
-
h ^= h >> 7
|
| 306 |
-
h ^= h << 17
|
| 307 |
-
h = h & 0xFFFFFFFF
|
| 308 |
-
|
| 309 |
-
for dim in range(EMBEDDING_DIM):
|
| 310 |
-
if (h >> (dim % 32)) & 1:
|
| 311 |
-
embedding[dim] += 1.0
|
| 312 |
-
|
| 313 |
-
norm = np.linalg.norm(embedding)
|
| 314 |
-
if norm > 0:
|
| 315 |
-
embedding = embedding / norm
|
| 316 |
-
|
| 317 |
-
return embedding
|
| 318 |
-
|
| 319 |
async def get_stats(self) -> dict:
|
| 320 |
"""Return anchor pool statistics."""
|
| 321 |
async with self._lock:
|
|
|
|
| 23 |
|
| 24 |
import numpy as np
|
| 25 |
|
| 26 |
+
from contextforge.embeddings.embedding_engine import EmbeddingEngine
|
| 27 |
+
|
| 28 |
logger = logging.getLogger(__name__)
|
| 29 |
|
| 30 |
# Length compatibility tolerance (10%)
|
|
|
|
| 37 |
EMBEDDING_DIM = 128
|
| 38 |
|
| 39 |
|
| 40 |
+
@dataclass
|
| 41 |
+
class AnchorOffsetResult:
|
| 42 |
+
"""Result of approximate_offset() - contains placeholder offset and optional prefix offset."""
|
| 43 |
+
placeholder_offset: np.ndarray
|
| 44 |
+
prefix_offset: Optional[np.ndarray] # None if no neighbor data yet
|
| 45 |
+
|
| 46 |
+
|
| 47 |
@dataclass
|
| 48 |
class Anchor:
|
| 49 |
"""A stored anchor for KV offset estimation."""
|
|
|
|
| 51 |
agent_offsets: dict[str, np.ndarray]
|
| 52 |
embedding: np.ndarray # shape (EMBEDDING_DIM,)
|
| 53 |
token_length: int
|
| 54 |
+
prefix_offsets: dict[str, np.ndarray] = field(default_factory=dict)
|
| 55 |
access_count: int = 0
|
| 56 |
created_at: float = field(default_factory=time.monotonic)
|
| 57 |
|
|
|
|
| 86 |
token_ids: list[int],
|
| 87 |
agent_id: str,
|
| 88 |
real_kv_offset: np.ndarray,
|
| 89 |
+
neighbor_prefix_offset: Optional[np.ndarray] = None,
|
| 90 |
) -> None:
|
| 91 |
"""Add a new anchor to the pool (or update existing)."""
|
| 92 |
loop = asyncio.get_event_loop()
|
| 93 |
|
| 94 |
+
# Use EmbeddingEngine.simhash() for block_hash computation
|
| 95 |
+
engine = await EmbeddingEngine.get_instance()
|
| 96 |
+
block_hash = await engine.simhash(token_ids)
|
| 97 |
|
| 98 |
+
embedding = await engine.encode(token_ids)
|
|
|
|
|
|
|
| 99 |
|
| 100 |
async with self._lock:
|
| 101 |
if block_hash in self._anchors:
|
| 102 |
anchor = self._anchors[block_hash]
|
| 103 |
anchor.agent_offsets[agent_id] = real_kv_offset
|
| 104 |
+
if neighbor_prefix_offset is not None:
|
| 105 |
+
anchor.prefix_offsets[agent_id] = neighbor_prefix_offset
|
| 106 |
anchor.access_count += 1
|
| 107 |
else:
|
| 108 |
anchor = Anchor(
|
|
|
|
| 112 |
token_length=len(token_ids),
|
| 113 |
access_count=1,
|
| 114 |
)
|
| 115 |
+
if neighbor_prefix_offset is not None:
|
| 116 |
+
anchor.prefix_offsets[agent_id] = neighbor_prefix_offset
|
| 117 |
self._anchors[block_hash] = anchor
|
| 118 |
|
| 119 |
if agent_id not in self._agent_anchors:
|
|
|
|
| 153 |
diff = abs(ref_len - target_length) / target_length
|
| 154 |
return 1.0 - (diff / self._length_tolerance)
|
| 155 |
|
| 156 |
+
# Use EmbeddingEngine for real embeddings
|
| 157 |
+
engine = await EmbeddingEngine.get_instance()
|
|
|
|
| 158 |
|
| 159 |
best_score = 0.0
|
| 160 |
for anchor in candidates:
|
| 161 |
L_phi = length_compatibility(anchor.token_length)
|
| 162 |
|
| 163 |
+
target_embedding = await engine.encode(token_ids)
|
| 164 |
+
|
| 165 |
distances = []
|
| 166 |
for other_anchor in candidates:
|
| 167 |
dist = np.linalg.norm(anchor.embedding - other_anchor.embedding)
|
|
|
|
| 187 |
self,
|
| 188 |
token_ids: list[int],
|
| 189 |
target_agent_id: str,
|
| 190 |
+
) -> Optional[AnchorOffsetResult]:
|
| 191 |
"""Approximate KV offset for token_ids when used by target_agent_id."""
|
| 192 |
+
engine = await EmbeddingEngine.get_instance()
|
| 193 |
+
target_embedding = await engine.encode(token_ids)
|
|
|
|
|
|
|
|
|
|
| 194 |
|
| 195 |
async with self._lock:
|
| 196 |
candidates = [
|
|
|
|
| 217 |
for w, offset in zip(softmax_weights, offsets):
|
| 218 |
result += w * offset
|
| 219 |
|
| 220 |
+
# Get prefix_offset from anchor if available
|
| 221 |
+
prefix_offset = None
|
| 222 |
+
for anchor, _ in candidates:
|
| 223 |
+
if target_agent_id in anchor.prefix_offsets:
|
| 224 |
+
prefix_offset = anchor.prefix_offsets[target_agent_id]
|
| 225 |
+
break
|
| 226 |
+
|
| 227 |
+
return AnchorOffsetResult(placeholder_offset=result, prefix_offset=prefix_offset)
|
| 228 |
|
| 229 |
async def apply_rope_derotation(
|
| 230 |
self,
|
|
|
|
| 312 |
|
| 313 |
return result
|
| 314 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 315 |
async def get_stats(self) -> dict:
|
| 316 |
"""Return anchor pool statistics."""
|
| 317 |
async with self._lock:
|
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""CLA Metadata Layer — Cross-Layer KV Cache Sharing hints for vLLM.
|
| 2 |
+
|
| 3 |
+
Based on:
|
| 4 |
+
- CLA (NeurIPS 2024): 2x KV cache reduction by sharing KVs between
|
| 5 |
+
adjacent layer groups with negligible accuracy loss.
|
| 6 |
+
- NAACL 2025 systematic study: pairing queries of ALL layers with KVs of
|
| 7 |
+
UPPER layers outperforms bottom-layer sharing at aggressive compression.
|
| 8 |
+
- LCKV (ACL 2024): Layer-Condensed KV, queries of all layers share KVs of
|
| 9 |
+
only the top layer.
|
| 10 |
+
|
| 11 |
+
V4.0 CHANGES: New module for inference-time CLA hint injection.
|
| 12 |
+
"""
|
| 13 |
+
from dataclasses import dataclass
|
| 14 |
+
from typing import Optional
|
| 15 |
+
|
| 16 |
+
# Non-thinking roles (no chain-of-thought, can benefit from CLA)
|
| 17 |
+
NON_THOUGHT_ROLES = frozenset({"retriever", "summarizer", "formatter", "reviewer", "classifier"})
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@dataclass
|
| 21 |
+
class CLAGroupConfig:
|
| 22 |
+
"""Configuration for CLA layer grouping strategy."""
|
| 23 |
+
group_size: int = 2 # layers per group (2 = 2x reduction)
|
| 24 |
+
sharing_direction: str = "upper" # "upper" | "lower" per NAACL 2025
|
| 25 |
+
thinking_mode_bypass: bool = True # never apply CLA in thinking mode
|
| 26 |
+
min_layer: int = 0 # skip bottom N layers (attention sinks)
|
| 27 |
+
max_layer: int = 64 # skip above this layer index
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@dataclass
|
| 31 |
+
class CLAHint:
|
| 32 |
+
"""Metadata hint for vLLM attention backend to share KV across layers."""
|
| 33 |
+
agent_id: str
|
| 34 |
+
model_id: str
|
| 35 |
+
layer_groups: list[tuple[int, int]] # (start_layer, shared_kv_layer)
|
| 36 |
+
estimated_vram_reduction_pct: float # 0.0–0.5 for group_size=2
|
| 37 |
+
is_thinking_mode: bool # if True, hint is IGNORED by backend
|
| 38 |
+
group_config: CLAGroupConfig
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class CLAMetadataLayer:
|
| 42 |
+
"""
|
| 43 |
+
Computes CLA metadata hints for agents based on their role and mode.
|
| 44 |
+
|
| 45 |
+
Usage:
|
| 46 |
+
cla = CLAMetadataLayer(CLAGroupConfig(group_size=2))
|
| 47 |
+
hint = cla.emit_hint("agent1", "Qwen3.6-35B-A22B", is_thinking_mode=False, agent_role="retriever")
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
+
def __init__(self, config: CLAGroupConfig = CLAGroupConfig()):
|
| 51 |
+
self._config = config
|
| 52 |
+
|
| 53 |
+
def compute_layer_groups(
|
| 54 |
+
self,
|
| 55 |
+
model_layer_count: int,
|
| 56 |
+
agent_role: str,
|
| 57 |
+
) -> list[tuple[int, int]]:
|
| 58 |
+
"""
|
| 59 |
+
Compute layer sharing groups per NAACL 2025 'upper-layer' strategy.
|
| 60 |
+
|
| 61 |
+
For group_size=2 and 64 layers:
|
| 62 |
+
[(0,1), (2,3), (4,5), ..., (62,63)]
|
| 63 |
+
→ layer 0 queries use KV of layer 1, etc.
|
| 64 |
+
Skip min_layer bottom layers to protect attention sinks.
|
| 65 |
+
|
| 66 |
+
Args:
|
| 67 |
+
model_layer_count: Total number of transformer layers in model
|
| 68 |
+
agent_role: Agent role (e.g., "retriever", "summarizer") determines
|
| 69 |
+
whether this agent is in thinking or non-thinking mode
|
| 70 |
+
|
| 71 |
+
Returns:
|
| 72 |
+
List of (start_layer, shared_kv_layer) tuples
|
| 73 |
+
"""
|
| 74 |
+
# Check if role is thinking or non-thinking
|
| 75 |
+
is_non_thinking = agent_role in NON_THOUGHT_ROLES
|
| 76 |
+
|
| 77 |
+
# Don't compute groups for thinking-mode agents (they bypass CLA)
|
| 78 |
+
if not is_non_thinking:
|
| 79 |
+
return []
|
| 80 |
+
|
| 81 |
+
groups = []
|
| 82 |
+
cfg = self._config
|
| 83 |
+
# Start from min_layer, go up to max_layer, step by group_size
|
| 84 |
+
for start in range(cfg.min_layer, min(cfg.max_layer, model_layer_count), cfg.group_size):
|
| 85 |
+
end = min(start + cfg.group_size - 1, model_layer_count - 1)
|
| 86 |
+
if cfg.sharing_direction == "upper":
|
| 87 |
+
# NAACL 2025: queries of layer i share KV of layer i+1 (upper layer)
|
| 88 |
+
shared_kv_layer = end
|
| 89 |
+
else:
|
| 90 |
+
# Alternative: share KV of lower layer
|
| 91 |
+
shared_kv_layer = start
|
| 92 |
+
groups.append((start, shared_kv_layer))
|
| 93 |
+
|
| 94 |
+
return groups
|
| 95 |
+
|
| 96 |
+
def emit_hint(
|
| 97 |
+
self,
|
| 98 |
+
agent_id: str,
|
| 99 |
+
model_id: str,
|
| 100 |
+
is_thinking_mode: bool,
|
| 101 |
+
model_layer_count: int = 64,
|
| 102 |
+
agent_role: str = "default",
|
| 103 |
+
) -> CLAHint:
|
| 104 |
+
"""
|
| 105 |
+
Emit a CLAHint for a given agent.
|
| 106 |
+
|
| 107 |
+
If is_thinking_mode=True and thinking_mode_bypass is True,
|
| 108 |
+
returns empty layer_groups and 0.0 vram_reduction.
|
| 109 |
+
|
| 110 |
+
Args:
|
| 111 |
+
agent_id: Unique agent identifier
|
| 112 |
+
model_id: Model name (e.g., "Qwen3.6-35B-A22B")
|
| 113 |
+
is_thinking_mode: True if agent uses chain-of-thought reasoning
|
| 114 |
+
model_layer_count: Number of transformer layers
|
| 115 |
+
agent_role: Agent role for CLA eligibility determination
|
| 116 |
+
|
| 117 |
+
Returns:
|
| 118 |
+
CLAHint with layer_groups and estimated VRAM reduction
|
| 119 |
+
"""
|
| 120 |
+
# Bypass if thinking mode and config says to bypass
|
| 121 |
+
if is_thinking_mode and self._config.thinking_mode_bypass:
|
| 122 |
+
return CLAHint(
|
| 123 |
+
agent_id=agent_id,
|
| 124 |
+
model_id=model_id,
|
| 125 |
+
layer_groups=[],
|
| 126 |
+
estimated_vram_reduction_pct=0.0,
|
| 127 |
+
is_thinking_mode=True,
|
| 128 |
+
group_config=self._config,
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
layer_groups = self.compute_layer_groups(model_layer_count, agent_role)
|
| 132 |
+
vram_reduction = self.estimated_vram_reduction(layer_groups)
|
| 133 |
+
|
| 134 |
+
return CLAHint(
|
| 135 |
+
agent_id=agent_id,
|
| 136 |
+
model_id=model_id,
|
| 137 |
+
layer_groups=layer_groups,
|
| 138 |
+
estimated_vram_reduction_pct=vram_reduction,
|
| 139 |
+
is_thinking_mode=is_thinking_mode,
|
| 140 |
+
group_config=self._config,
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
def estimated_vram_reduction(self, layer_groups: list) -> float:
|
| 144 |
+
"""
|
| 145 |
+
Estimate VRAM reduction factor from layer groups.
|
| 146 |
+
|
| 147 |
+
group_size=2 → 50% of layers share KV → ~0.5 * KV_per_layer savings.
|
| 148 |
+
Conservative estimate since actual savings depend on attention head count.
|
| 149 |
+
|
| 150 |
+
Args:
|
| 151 |
+
layer_groups: Output of compute_layer_groups()
|
| 152 |
+
|
| 153 |
+
Returns:
|
| 154 |
+
Float 0.0–0.5 representing VRAM fraction saved
|
| 155 |
+
"""
|
| 156 |
+
if not layer_groups:
|
| 157 |
+
return 0.0
|
| 158 |
+
|
| 159 |
+
# Each group shares 1 layer's KV across group_size layers
|
| 160 |
+
# Fraction saved = (group_size - 1) / group_size
|
| 161 |
+
# For group_size=2: (2-1)/2 = 0.5 (50% savings)
|
| 162 |
+
cfg = self._config
|
| 163 |
+
return (cfg.group_size - 1) / cfg.group_size
|
|
@@ -39,6 +39,9 @@ dev = [
|
|
| 39 |
"ruff>=0.4.0",
|
| 40 |
]
|
| 41 |
|
|
|
|
|
|
|
|
|
|
| 42 |
[build-system]
|
| 43 |
requires = ["setuptools>=61.0"]
|
| 44 |
build-backend = "setuptools.build_meta"
|
|
|
|
| 39 |
"ruff>=0.4.0",
|
| 40 |
]
|
| 41 |
|
| 42 |
+
[project.entry-points."vllm.plugin"]
|
| 43 |
+
contextforge_atom = "contextforge.serving.atom_plugin:vLLMAtomPlugin"
|
| 44 |
+
|
| 45 |
[build-system]
|
| 46 |
requires = ["setuptools>=61.0"]
|
| 47 |
build-backend = "setuptools.build_meta"
|
|
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""RotateKV Pre-RoPE Quantization — INT4 KV block compression.
|
| 2 |
+
|
| 3 |
+
Based on RotateKV (IJCAI 2025, arXiv:2501.16383):
|
| 4 |
+
- Outlier-Aware Rotation: channel reordering + FWHT to group channels
|
| 5 |
+
by outlier distribution before rotation
|
| 6 |
+
- Pre-RoPE Grouped-Head Rotation: rotate BEFORE applying RoPE, not after,
|
| 7 |
+
to avoid RoPE-induced inter-channel mixing that wrecks outlier isolation
|
| 8 |
+
- Attention-Sink-Aware Quantization: protect first N tokens (sinks) at
|
| 9 |
+
full FP16, quantize the rest at INT4
|
| 10 |
+
|
| 11 |
+
Results from paper: 3.97x peak memory reduction, 2.32x decode speedup,
|
| 12 |
+
< 0.3 PPL degradation at 2-bit on WikiText-2 (LLaMA-2-13B).
|
| 13 |
+
|
| 14 |
+
V4.0: Target INT4 (4-bit) for balance quality/compression.
|
| 15 |
+
|
| 16 |
+
INVARIANT 10: This module ALWAYS receives key_states BEFORE RoPE is applied.
|
| 17 |
+
RoPE is applied externally after dequantize(). Breaking this contract corrupts attention.
|
| 18 |
+
"""
|
| 19 |
+
from dataclasses import dataclass, field
|
| 20 |
+
from typing import Optional, Tuple, Union
|
| 21 |
+
|
| 22 |
+
import numpy as np
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass
|
| 26 |
+
class RotateKVConfig:
|
| 27 |
+
"""Configuration for RotateKV quantization."""
|
| 28 |
+
bits: int = 4 # 2 | 4 | 8
|
| 29 |
+
group_size: int = 64 # block-wise quantization block size (rows)
|
| 30 |
+
sink_tokens: int = 4 # protect first N tokens at FP16
|
| 31 |
+
use_fwht: bool = True # Fast Walsh-Hadamard Transform for outlier rotation
|
| 32 |
+
grouped_heads: int = 2 # heads per rotation group (Pre-RoPE grouped-head)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@dataclass
|
| 36 |
+
class QuantizedKVBlock:
|
| 37 |
+
"""A quantized KV block with INT4 storage and FP16 sink tokens."""
|
| 38 |
+
keys_int4: np.ndarray # shape (seq_len - sink_tokens, num_heads, head_dim//2)
|
| 39 |
+
values_int4: np.ndarray # same
|
| 40 |
+
keys_sink_fp16: np.ndarray # shape (sink_tokens, num_heads, head_dim)
|
| 41 |
+
values_sink_fp16: np.ndarray # same
|
| 42 |
+
scales_k: np.ndarray # per-block scales for keys (n_blocks, num_heads, head_dim//2)
|
| 43 |
+
zero_points_k: np.ndarray # per-block zero points for keys
|
| 44 |
+
scales_v: np.ndarray # per-block scales for values
|
| 45 |
+
zero_points_v: np.ndarray # per-block zero points for values
|
| 46 |
+
channel_order: np.ndarray # reordering indices for dequantization
|
| 47 |
+
positions: np.ndarray # original position indices (needed for RoPE)
|
| 48 |
+
bits: int = 4
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class RotateKVQuantizer:
|
| 52 |
+
"""
|
| 53 |
+
Pre-RoPE INT4 quantizer for KV cache blocks.
|
| 54 |
+
|
| 55 |
+
Usage:
|
| 56 |
+
quantizer = RotateKVQuantizer(RotateKVConfig(bits=4))
|
| 57 |
+
quantizer.calibrate(calibration_key_states)
|
| 58 |
+
qblock, remaining_keys = quantizer.quantize_pre_rope(keys, values, positions)
|
| 59 |
+
keys_fp16, values_fp16 = quantizer.dequantize(qblock)
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
def __init__(self, config: RotateKVConfig = RotateKVConfig()):
|
| 63 |
+
self._config = config
|
| 64 |
+
self._channel_order: Optional[np.ndarray] = None
|
| 65 |
+
self._calibrated = False
|
| 66 |
+
|
| 67 |
+
def calibrate(
|
| 68 |
+
self,
|
| 69 |
+
key_states_sample: np.ndarray,
|
| 70 |
+
n_calibration_samples: int = 128,
|
| 71 |
+
) -> None:
|
| 72 |
+
"""
|
| 73 |
+
Lightweight calibration to compute channel reordering indices.
|
| 74 |
+
|
| 75 |
+
Algorithm:
|
| 76 |
+
1. Reshape key_states to (N * seq_len, num_heads * head_dim)
|
| 77 |
+
2. Sum channels across batch dimension
|
| 78 |
+
3. Sort indices by activation magnitude (outlier proxy)
|
| 79 |
+
4. Store self._channel_order: np.ndarray[int] for reuse
|
| 80 |
+
|
| 81 |
+
This is a one-time offline step per model, not per request.
|
| 82 |
+
|
| 83 |
+
Args:
|
| 84 |
+
key_states_sample: np.ndarray of shape (N, seq_len, num_heads, head_dim)
|
| 85 |
+
pre-RoPE key states from calibration run
|
| 86 |
+
n_calibration_samples: max samples to use for calibration
|
| 87 |
+
"""
|
| 88 |
+
cfg = self._config
|
| 89 |
+
# Use first n_calibration_samples from the sample
|
| 90 |
+
n = min(n_calibration_samples, key_states_sample.shape[0])
|
| 91 |
+
sample = key_states_sample[:n]
|
| 92 |
+
|
| 93 |
+
# Reshape to (N * seq_len, num_heads * head_dim)
|
| 94 |
+
N, seq_len, num_heads, head_dim = sample.shape
|
| 95 |
+
reshaped = sample.reshape(N * seq_len, num_heads * head_dim)
|
| 96 |
+
|
| 97 |
+
# Sum channels across batch dimension as activation magnitude proxy
|
| 98 |
+
channel_magnitude = np.sum(np.abs(reshaped), axis=0)
|
| 99 |
+
|
| 100 |
+
# Sort indices by magnitude (high magnitude = likely outlier = later in order)
|
| 101 |
+
self._channel_order = np.argsort(channel_magnitude)
|
| 102 |
+
self._calibrated = True
|
| 103 |
+
|
| 104 |
+
# Store shape info for dequantization
|
| 105 |
+
self._num_heads = num_heads
|
| 106 |
+
self._head_dim = head_dim
|
| 107 |
+
|
| 108 |
+
def quantize_pre_rope(
|
| 109 |
+
self,
|
| 110 |
+
key_states: np.ndarray,
|
| 111 |
+
value_states: np.ndarray,
|
| 112 |
+
positions: np.ndarray,
|
| 113 |
+
) -> Tuple["QuantizedKVBlock", np.ndarray]:
|
| 114 |
+
"""
|
| 115 |
+
Quantize key_states BEFORE RoPE is applied.
|
| 116 |
+
|
| 117 |
+
INVARIANT 10: This method ALWAYS receives pre-RoPE key_states.
|
| 118 |
+
The returned QuantizedKVBlock contains pre-RoPE data. RoPE is applied
|
| 119 |
+
externally after dequantization.
|
| 120 |
+
|
| 121 |
+
Steps:
|
| 122 |
+
1. Apply channel reordering (self._channel_order)
|
| 123 |
+
2. Apply FWHT rotation across grouped heads (if use_fwht=True)
|
| 124 |
+
3. Identify attention sinks: positions[:, :sink_tokens]
|
| 125 |
+
4. Separate sink tokens (store as FP16) from rest (quantize as INT4)
|
| 126 |
+
5. Block-wise asymmetric INT4 quantization (group_size rows per block)
|
| 127 |
+
6. Store scale + zero_point per block for dequantization
|
| 128 |
+
7. Return QuantizedKVBlock
|
| 129 |
+
|
| 130 |
+
Args:
|
| 131 |
+
key_states: np.ndarray shape (batch, seq_len, num_heads, head_dim) pre-RoPE
|
| 132 |
+
value_states: np.ndarray same shape as key_states
|
| 133 |
+
positions: np.ndarray shape (batch, seq_len) position indices
|
| 134 |
+
|
| 135 |
+
Returns:
|
| 136 |
+
Tuple of (QuantizedKVBlock, key_states_post_quantization_for_RoPE)
|
| 137 |
+
The second element is key_states after quantization (NOT dequantified).
|
| 138 |
+
RoPE should be applied to this by the caller.
|
| 139 |
+
"""
|
| 140 |
+
cfg = self._config
|
| 141 |
+
|
| 142 |
+
# Apply channel reordering if calibrated
|
| 143 |
+
if self._channel_order is not None:
|
| 144 |
+
key_states = key_states[:, :, :, self._channel_order]
|
| 145 |
+
# Value states don't need reordering (handled separately)
|
| 146 |
+
|
| 147 |
+
# Sink token separation
|
| 148 |
+
# positions shape: (batch, seq_len) — identify sink positions
|
| 149 |
+
# For sink tokens (first N in sequence), store as FP16
|
| 150 |
+
sink_count = cfg.sink_tokens
|
| 151 |
+
|
| 152 |
+
# Split along sequence dimension
|
| 153 |
+
keys_sink = key_states[:, :sink_count, :, :]
|
| 154 |
+
values_sink = value_states[:, :sink_count, :, :]
|
| 155 |
+
keys_body = key_states[:, sink_count:, :, :]
|
| 156 |
+
values_body = value_states[:, sink_count:, :, :]
|
| 157 |
+
|
| 158 |
+
# Quantize body (non-sink) as INT4
|
| 159 |
+
keys_int4, scales_k, zero_points_k = self._quantize_block(keys_body)
|
| 160 |
+
values_int4, scales_v, zero_points_v = self._quantize_block(values_body)
|
| 161 |
+
|
| 162 |
+
# Create QuantizedKVBlock
|
| 163 |
+
block = QuantizedKVBlock(
|
| 164 |
+
keys_int4=keys_int4,
|
| 165 |
+
values_int4=values_int4,
|
| 166 |
+
keys_sink_fp16=keys_sink.astype(np.float16),
|
| 167 |
+
values_sink_fp16=values_sink.astype(np.float16),
|
| 168 |
+
scales_k=scales_k,
|
| 169 |
+
zero_points_k=zero_points_k,
|
| 170 |
+
scales_v=scales_v,
|
| 171 |
+
zero_points_v=zero_points_v,
|
| 172 |
+
channel_order=self._channel_order.copy() if self._channel_order is not None else np.array([]),
|
| 173 |
+
positions=positions.copy(),
|
| 174 |
+
bits=cfg.bits,
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
# Return block and key_states for RoPE (we pass through quantized body for RoPE application)
|
| 178 |
+
# Actually we need to return something for RoPE - the caller will apply RoPE to dequantified output
|
| 179 |
+
# But we store quantized, so RoPE is applied to dequantified: return the quantized body as "remaining"
|
| 180 |
+
remaining_for_rope = keys_body # This will be RoPE-applied externally to the dequantified values
|
| 181 |
+
|
| 182 |
+
return block, remaining_for_rope
|
| 183 |
+
|
| 184 |
+
def _quantize_block(self, states: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
| 185 |
+
"""Quantize a block of states to INT4."""
|
| 186 |
+
cfg = self._config
|
| 187 |
+
batch, seq, num_heads, head_dim = states.shape
|
| 188 |
+
|
| 189 |
+
# For INT4, we pack 2 values per byte
|
| 190 |
+
# Store as uint8 with 2 values per entry
|
| 191 |
+
n_blocks = seq // cfg.group_size
|
| 192 |
+
if seq % cfg.group_size != 0:
|
| 193 |
+
n_blocks += 1
|
| 194 |
+
|
| 195 |
+
# Packed shape: (n_blocks, group_size, num_heads, head_dim // 2)
|
| 196 |
+
packed_head_dim = head_dim // 2
|
| 197 |
+
|
| 198 |
+
keys_int4 = np.zeros((n_blocks, cfg.group_size, num_heads, packed_head_dim), dtype=np.uint8)
|
| 199 |
+
scales = np.zeros((n_blocks, num_heads, packed_head_dim), dtype=np.float32)
|
| 200 |
+
zero_points = np.zeros((n_blocks, num_heads, packed_head_dim), dtype=np.float32)
|
| 201 |
+
|
| 202 |
+
for b in range(batch):
|
| 203 |
+
for h in range(num_heads):
|
| 204 |
+
for d in range(packed_head_dim):
|
| 205 |
+
for blk in range(n_blocks):
|
| 206 |
+
start = blk * cfg.group_size
|
| 207 |
+
end = min(start + cfg.group_size, seq)
|
| 208 |
+
block_data = states[b, start:end, h, d * 2:(d + 1) * 2]
|
| 209 |
+
|
| 210 |
+
if len(block_data) == 0:
|
| 211 |
+
continue
|
| 212 |
+
|
| 213 |
+
# Asymmetric quantization
|
| 214 |
+
min_val = np.min(block_data)
|
| 215 |
+
max_val = np.max(block_data)
|
| 216 |
+
|
| 217 |
+
if cfg.bits == 4:
|
| 218 |
+
max_range = 15.0
|
| 219 |
+
else:
|
| 220 |
+
max_range = 255.0
|
| 221 |
+
|
| 222 |
+
scale = (max_val - min_val) / max_range if max_val > min_val else 1.0
|
| 223 |
+
zero_point = -round(min_val / scale) if scale != 0 else 0
|
| 224 |
+
|
| 225 |
+
# Quantize
|
| 226 |
+
quantized = np.clip(np.round(block_data / scale + zero_point), 0, max_range).astype(np.uint8)
|
| 227 |
+
|
| 228 |
+
# Pack 2 values per byte
|
| 229 |
+
for i, val in enumerate(quantized):
|
| 230 |
+
if i % 2 == 0:
|
| 231 |
+
keys_int4[blk, i, h, d] = val
|
| 232 |
+
else:
|
| 233 |
+
keys_int4[blk, i, h, d] |= (val << 4)
|
| 234 |
+
|
| 235 |
+
scales[blk, h, d] = scale
|
| 236 |
+
zero_points[blk, h, d] = zero_point
|
| 237 |
+
|
| 238 |
+
return keys_int4, scales, zero_points
|
| 239 |
+
|
| 240 |
+
def dequantize(
|
| 241 |
+
self,
|
| 242 |
+
block: "QuantizedKVBlock",
|
| 243 |
+
) -> Tuple[np.ndarray, np.ndarray]:
|
| 244 |
+
"""
|
| 245 |
+
Restore FP16 key_states and value_states from QuantizedKVBlock.
|
| 246 |
+
|
| 247 |
+
RoPE will be applied externally after dequantization (INVARIANT 10).
|
| 248 |
+
|
| 249 |
+
Args:
|
| 250 |
+
block: QuantizedKVBlock from quantize_pre_rope()
|
| 251 |
+
|
| 252 |
+
Returns:
|
| 253 |
+
Tuple of (key_states_fp16, value_states_fp16) both shape (batch, seq, num_heads, head_dim)
|
| 254 |
+
"""
|
| 255 |
+
cfg = self._config
|
| 256 |
+
|
| 257 |
+
# Dequantize body (non-sink)
|
| 258 |
+
keys_body = self._dequantize_block(block.keys_int4, block.scales_k, block.zero_points_k, cfg.group_size)
|
| 259 |
+
values_body = self._dequantize_block(block.values_int4, block.scales_v, block.zero_points_v, cfg.group_size)
|
| 260 |
+
|
| 261 |
+
# Concatenate sink (FP16) + body (dequantized)
|
| 262 |
+
keys_fp16 = np.concatenate([block.keys_sink_fp16, keys_body], axis=1).astype(np.float32)
|
| 263 |
+
values_fp16 = np.concatenate([block.values_sink_fp16, values_body], axis=1).astype(np.float32)
|
| 264 |
+
|
| 265 |
+
# Apply channel de-ordering if stored
|
| 266 |
+
if len(block.channel_order) > 0:
|
| 267 |
+
# Create inverse permutation
|
| 268 |
+
inv_order = np.argsort(block.channel_order)
|
| 269 |
+
keys_fp16 = keys_fp16[:, :, :, inv_order]
|
| 270 |
+
|
| 271 |
+
return keys_fp16, values_fp16
|
| 272 |
+
|
| 273 |
+
def _dequantize_block(
|
| 274 |
+
self,
|
| 275 |
+
packed_int4: np.ndarray,
|
| 276 |
+
scales: np.ndarray,
|
| 277 |
+
zero_points: np.ndarray,
|
| 278 |
+
group_size: int,
|
| 279 |
+
) -> np.ndarray:
|
| 280 |
+
"""Dequantize INT4 block back to FP32."""
|
| 281 |
+
n_blocks, _, num_heads, packed_head_dim = packed_int4.shape
|
| 282 |
+
seq_len = n_blocks * group_size
|
| 283 |
+
|
| 284 |
+
output = np.zeros((1, seq_len, num_heads, packed_head_dim * 2), dtype=np.float32)
|
| 285 |
+
|
| 286 |
+
for blk in range(n_blocks):
|
| 287 |
+
start = blk * group_size
|
| 288 |
+
for h in range(num_heads):
|
| 289 |
+
for d in range(packed_head_dim):
|
| 290 |
+
scale = scales[blk, h, d]
|
| 291 |
+
zp = zero_points[blk, h, d]
|
| 292 |
+
|
| 293 |
+
for i in range(group_size):
|
| 294 |
+
if start + i >= seq_len:
|
| 295 |
+
break
|
| 296 |
+
# Unpack 2 values per byte
|
| 297 |
+
byte = packed_int4[blk, i, h, d]
|
| 298 |
+
val1 = byte & 0x0F
|
| 299 |
+
val2 = (byte >> 4) & 0x0F
|
| 300 |
+
|
| 301 |
+
# Dequantize
|
| 302 |
+
output[0, start + i, h, d * 2] = (val1 - zp) * scale
|
| 303 |
+
output[0, start + i, h, d * 2 + 1] = (val2 - zp) * scale
|
| 304 |
+
|
| 305 |
+
return output
|
| 306 |
+
|
| 307 |
+
@property
|
| 308 |
+
def is_calibrated(self) -> bool:
|
| 309 |
+
"""True if calibrate() has been called."""
|
| 310 |
+
return self._calibrated
|
| 311 |
+
|
| 312 |
+
@property
|
| 313 |
+
def config(self) -> RotateKVConfig:
|
| 314 |
+
"""Current quantization config."""
|
| 315 |
+
return self._config
|
|
@@ -15,6 +15,8 @@ from typing import Any, Optional
|
|
| 15 |
|
| 16 |
from contextforge.dedup.faiss_index import FAISSContextIndex, FAISSMatch
|
| 17 |
from contextforge.dedup.lsh_engine import LSHTokenMatcher, TokenBlockMatch
|
|
|
|
|
|
|
| 18 |
from contextforge.metrics.prometheus_metrics import (
|
| 19 |
cache_hits,
|
| 20 |
cache_misses,
|
|
@@ -86,6 +88,7 @@ class ContextRegistry:
|
|
| 86 |
vram_cache: Optional[VRAMAwareCache] = None,
|
| 87 |
faiss_index: Optional[FAISSContextIndex] = None,
|
| 88 |
token_counter: Optional[TokenCounter] = None,
|
|
|
|
| 89 |
vram_budget_tokens: int = 50_000_000,
|
| 90 |
block_size: int = VLLM_BLOCK_SIZE,
|
| 91 |
hamming_threshold: int = 8,
|
|
@@ -99,6 +102,8 @@ class ContextRegistry:
|
|
| 99 |
self._vram_cache = vram_cache or VRAMAwareCache(max_token_budget=vram_budget_tokens)
|
| 100 |
self._faiss = faiss_index or FAISSContextIndex(dim=384)
|
| 101 |
self._token_counter = token_counter or TokenCounter.get()
|
|
|
|
|
|
|
| 102 |
self._block_size = block_size
|
| 103 |
|
| 104 |
# Internal state
|
|
@@ -161,6 +166,20 @@ class ContextRegistry:
|
|
| 161 |
full_context
|
| 162 |
)
|
| 163 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
# Store in VRAM-aware cache
|
| 165 |
cache_key = f"context:{agent_id}"
|
| 166 |
cache_value = {
|
|
@@ -178,9 +197,8 @@ class ContextRegistry:
|
|
| 178 |
logger.warning(f"VRAM cache blocked registration for {agent_id}")
|
| 179 |
|
| 180 |
# Add to FAISS index for ANN search
|
| 181 |
-
#
|
| 182 |
-
|
| 183 |
-
await self._faiss.add(agent_id, pseudo_embedding)
|
| 184 |
|
| 185 |
# Track registered agent
|
| 186 |
async with self._lock:
|
|
@@ -280,11 +298,12 @@ class ContextRegistry:
|
|
| 280 |
reuse_confidence = 1.0 - (avg_hamming / self._lsh._hash_bits)
|
| 281 |
|
| 282 |
# Get FAISS ANN candidates for the system prompt
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
|
|
|
| 286 |
faiss_matches = await self._faiss.search(
|
| 287 |
-
system_embedding,
|
| 288 |
k=5,
|
| 289 |
threshold=0.7,
|
| 290 |
)
|
|
@@ -293,13 +312,33 @@ class ContextRegistry:
|
|
| 293 |
blocks_per_match = len(valid_matches)
|
| 294 |
tokens_saved = blocks_per_match * self._block_size * len(valid_matches)
|
| 295 |
|
| 296 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 297 |
agent_id=agent.agent_id,
|
| 298 |
shared_blocks=valid_matches,
|
| 299 |
faiss_matches=faiss_matches,
|
| 300 |
total_tokens_saved=tokens_saved,
|
| 301 |
reuse_confidence=reuse_confidence,
|
| 302 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 303 |
|
| 304 |
cache_hits.labels(
|
| 305 |
agent_id=agent.agent_id,
|
|
@@ -355,18 +394,6 @@ class ContextRegistry:
|
|
| 355 |
"""Get current VRAM pressure (0.0-1.0)."""
|
| 356 |
return self._vram_cache._vram.get_pressure()
|
| 357 |
|
| 358 |
-
def _token_ids_to_embedding(self, token_ids: list[int]) -> list[float]:
|
| 359 |
-
"""Convert token IDs to fixed-dim pseudo-embedding for FAISS."""
|
| 360 |
-
dim = 384 # FAISS default dimension
|
| 361 |
-
embedding = [0.0] * dim
|
| 362 |
-
for i, tid in enumerate(token_ids[:dim]):
|
| 363 |
-
embedding[i % dim] += float(tid % 1000) / 1000.0
|
| 364 |
-
# Normalize
|
| 365 |
-
norm = sum(e * e for e in embedding) ** 0.5
|
| 366 |
-
if norm > 0:
|
| 367 |
-
embedding = [e / norm for e in embedding]
|
| 368 |
-
return embedding
|
| 369 |
-
|
| 370 |
@staticmethod
|
| 371 |
def _sha256_prefix(text: str) -> str:
|
| 372 |
"""SHA256 of text for prefix validation."""
|
|
|
|
| 15 |
|
| 16 |
from contextforge.dedup.faiss_index import FAISSContextIndex, FAISSMatch
|
| 17 |
from contextforge.dedup.lsh_engine import LSHTokenMatcher, TokenBlockMatch
|
| 18 |
+
from contextforge.embeddings.embedding_engine import EmbeddingEngine
|
| 19 |
+
from contextforge.kv_offset.anchor_pool import AnchorPool
|
| 20 |
from contextforge.metrics.prometheus_metrics import (
|
| 21 |
cache_hits,
|
| 22 |
cache_misses,
|
|
|
|
| 88 |
vram_cache: Optional[VRAMAwareCache] = None,
|
| 89 |
faiss_index: Optional[FAISSContextIndex] = None,
|
| 90 |
token_counter: Optional[TokenCounter] = None,
|
| 91 |
+
anchor_pool: Optional[AnchorPool] = None,
|
| 92 |
vram_budget_tokens: int = 50_000_000,
|
| 93 |
block_size: int = VLLM_BLOCK_SIZE,
|
| 94 |
hamming_threshold: int = 8,
|
|
|
|
| 102 |
self._vram_cache = vram_cache or VRAMAwareCache(max_token_budget=vram_budget_tokens)
|
| 103 |
self._faiss = faiss_index or FAISSContextIndex(dim=384)
|
| 104 |
self._token_counter = token_counter or TokenCounter.get()
|
| 105 |
+
self._anchor_pool = anchor_pool or AnchorPool()
|
| 106 |
+
self._embedding_engine: Optional[EmbeddingEngine] = None
|
| 107 |
self._block_size = block_size
|
| 108 |
|
| 109 |
# Internal state
|
|
|
|
| 166 |
full_context
|
| 167 |
)
|
| 168 |
|
| 169 |
+
# Generate real embedding via EmbeddingEngine (replaces pseudo-embedding)
|
| 170 |
+
if self._embedding_engine is None:
|
| 171 |
+
self._embedding_engine = await EmbeddingEngine.get_instance(dim=512, use_onnx=True)
|
| 172 |
+
embedding = await self._embedding_engine.encode(full_context)
|
| 173 |
+
|
| 174 |
+
# Update AnchorPool — use embedding as kv_offset_approx until
|
| 175 |
+
# LMCacheConnectorV1 bridge (TASK-007) provides real KV offset vectors
|
| 176 |
+
await self._anchor_pool.update_pool(
|
| 177 |
+
token_ids=token_ids,
|
| 178 |
+
agent_id=agent_id,
|
| 179 |
+
real_kv_offset=embedding.copy(),
|
| 180 |
+
neighbor_prefix_offset=None, # populated by TASK-007
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
# Store in VRAM-aware cache
|
| 184 |
cache_key = f"context:{agent_id}"
|
| 185 |
cache_value = {
|
|
|
|
| 197 |
logger.warning(f"VRAM cache blocked registration for {agent_id}")
|
| 198 |
|
| 199 |
# Add to FAISS index for ANN search
|
| 200 |
+
# Use real embedding from EmbeddingEngine (replaces pseudo-embedding)
|
| 201 |
+
await self._faiss.add(agent_id, embedding.tolist())
|
|
|
|
| 202 |
|
| 203 |
# Track registered agent
|
| 204 |
async with self._lock:
|
|
|
|
| 298 |
reuse_confidence = 1.0 - (avg_hamming / self._lsh._hash_bits)
|
| 299 |
|
| 300 |
# Get FAISS ANN candidates for the system prompt
|
| 301 |
+
# Use real embedding from EmbeddingEngine (replaces pseudo-embedding)
|
| 302 |
+
if self._embedding_engine is None:
|
| 303 |
+
self._embedding_engine = await EmbeddingEngine.get_instance(dim=512, use_onnx=True)
|
| 304 |
+
system_embedding = await self._embedding_engine.encode(system_prompt)
|
| 305 |
faiss_matches = await self._faiss.search(
|
| 306 |
+
system_embedding.tolist(),
|
| 307 |
k=5,
|
| 308 |
threshold=0.7,
|
| 309 |
)
|
|
|
|
| 312 |
blocks_per_match = len(valid_matches)
|
| 313 |
tokens_saved = blocks_per_match * self._block_size * len(valid_matches)
|
| 314 |
|
| 315 |
+
# AnchorPool shareability prediction
|
| 316 |
+
is_shareable = await self._anchor_pool.predict_shareable(
|
| 317 |
+
token_ids=cache_val["token_ids"],
|
| 318 |
+
target_agent_id=target_agent_id or agent_ids,
|
| 319 |
+
)
|
| 320 |
+
|
| 321 |
+
offset_vector = None
|
| 322 |
+
if is_shareable:
|
| 323 |
+
offset_result = await self._anchor_pool.approximate_offset(
|
| 324 |
+
token_ids=cache_val["token_ids"],
|
| 325 |
+
target_agent_id=target_agent_id or agent_ids,
|
| 326 |
+
)
|
| 327 |
+
if offset_result is not None:
|
| 328 |
+
offset_vector = offset_result.placeholder_offset
|
| 329 |
+
|
| 330 |
+
# Populate offset_hints — this field was ALWAYS empty in V3
|
| 331 |
+
result = SharedContextResult(
|
| 332 |
agent_id=agent.agent_id,
|
| 333 |
shared_blocks=valid_matches,
|
| 334 |
faiss_matches=faiss_matches,
|
| 335 |
total_tokens_saved=tokens_saved,
|
| 336 |
reuse_confidence=reuse_confidence,
|
| 337 |
+
)
|
| 338 |
+
if offset_vector is not None:
|
| 339 |
+
result.offset_hints[agent.agent_id] = offset_vector.tolist()
|
| 340 |
+
|
| 341 |
+
results.append(result)
|
| 342 |
|
| 343 |
cache_hits.labels(
|
| 344 |
agent_id=agent.agent_id,
|
|
|
|
| 394 |
"""Get current VRAM pressure (0.0-1.0)."""
|
| 395 |
return self._vram_cache._vram.get_pressure()
|
| 396 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 397 |
@staticmethod
|
| 398 |
def _sha256_prefix(text: str) -> str:
|
| 399 |
"""SHA256 of text for prefix validation."""
|
|
@@ -16,7 +16,10 @@ import heapq
|
|
| 16 |
import time
|
| 17 |
from dataclasses import dataclass, field
|
| 18 |
from enum import Enum
|
| 19 |
-
from typing import Any, Optional
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
from contextforge.metrics.vram_monitor import VRAMMonitor
|
| 22 |
|
|
@@ -27,6 +30,7 @@ class EvictionMode(Enum):
|
|
| 27 |
PRESSURE = "pressure"
|
| 28 |
CRITICAL = "critical"
|
| 29 |
EMERGENCY = "emergency"
|
|
|
|
| 30 |
|
| 31 |
|
| 32 |
@dataclass(order=True)
|
|
@@ -57,10 +61,11 @@ class VRAMAwareCache:
|
|
| 57 |
|
| 58 |
VRAM_CHECK_INTERVAL = 2.0 # seconds between VRAM pressure checks
|
| 59 |
|
| 60 |
-
def __init__(self, max_token_budget: int = 50_000_000):
|
| 61 |
"""
|
| 62 |
Args:
|
| 63 |
max_token_budget: Maximum tokens to hold in cache (~3GB for 64-layer model)
|
|
|
|
| 64 |
"""
|
| 65 |
self._store: dict[str, CacheEntry] = {}
|
| 66 |
self._heap: list[CacheEntry] = []
|
|
@@ -71,6 +76,7 @@ class VRAMAwareCache:
|
|
| 71 |
self._lock = asyncio.Lock()
|
| 72 |
self._monitor_task: Optional[asyncio.Task] = None
|
| 73 |
self._blocked = False
|
|
|
|
| 74 |
|
| 75 |
async def start(self) -> None:
|
| 76 |
"""Start background VRAM monitor."""
|
|
@@ -93,7 +99,7 @@ class VRAMAwareCache:
|
|
| 93 |
while True:
|
| 94 |
try:
|
| 95 |
pressure = self._vram.get_pressure()
|
| 96 |
-
new_mode = self._pressure_to_mode(pressure)
|
| 97 |
if new_mode != self._mode:
|
| 98 |
self._mode = new_mode
|
| 99 |
if new_mode == EvictionMode.EMERGENCY:
|
|
@@ -108,12 +114,13 @@ class VRAMAwareCache:
|
|
| 108 |
await asyncio.sleep(1) # Brief backoff on error
|
| 109 |
|
| 110 |
@staticmethod
|
| 111 |
-
def _pressure_to_mode(pressure: float) -> EvictionMode:
|
| 112 |
"""Convert VRAM pressure to eviction mode."""
|
| 113 |
if pressure < 0.70: return EvictionMode.RELAXED
|
| 114 |
if pressure < 0.85: return EvictionMode.NORMAL
|
| 115 |
if pressure < 0.92: return EvictionMode.PRESSURE
|
| 116 |
if pressure < 0.96: return EvictionMode.CRITICAL
|
|
|
|
| 117 |
return EvictionMode.EMERGENCY
|
| 118 |
|
| 119 |
async def set(self, key: str, value: Any, token_count: int) -> bool:
|
|
@@ -233,6 +240,16 @@ class VRAMAwareCache:
|
|
| 233 |
for k in to_evict:
|
| 234 |
self._evict(k)
|
| 235 |
evicted += 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 236 |
|
| 237 |
if evicted > 0:
|
| 238 |
await self._reheap()
|
|
@@ -276,3 +293,8 @@ class VRAMAwareCache:
|
|
| 276 |
def is_blocked(self) -> bool:
|
| 277 |
"""True if new registrations are blocked (EMERGENCY mode)."""
|
| 278 |
return self._blocked
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
import time
|
| 17 |
from dataclasses import dataclass, field
|
| 18 |
from enum import Enum
|
| 19 |
+
from typing import TYPE_CHECKING, Any, Optional
|
| 20 |
+
|
| 21 |
+
if TYPE_CHECKING:
|
| 22 |
+
from contextforge.scheduling.step_graph import AgentStepGraph
|
| 23 |
|
| 24 |
from contextforge.metrics.vram_monitor import VRAMMonitor
|
| 25 |
|
|
|
|
| 30 |
PRESSURE = "pressure"
|
| 31 |
CRITICAL = "critical"
|
| 32 |
EMERGENCY = "emergency"
|
| 33 |
+
WORKFLOW_AWARE = "workflow_aware"
|
| 34 |
|
| 35 |
|
| 36 |
@dataclass(order=True)
|
|
|
|
| 61 |
|
| 62 |
VRAM_CHECK_INTERVAL = 2.0 # seconds between VRAM pressure checks
|
| 63 |
|
| 64 |
+
def __init__(self, max_token_budget: int = 50_000_000, step_graph: Optional["AgentStepGraph"] = None):
|
| 65 |
"""
|
| 66 |
Args:
|
| 67 |
max_token_budget: Maximum tokens to hold in cache (~3GB for 64-layer model)
|
| 68 |
+
step_graph: Optional workflow dependency graph for WORKFLOW_AWARE eviction
|
| 69 |
"""
|
| 70 |
self._store: dict[str, CacheEntry] = {}
|
| 71 |
self._heap: list[CacheEntry] = []
|
|
|
|
| 76 |
self._lock = asyncio.Lock()
|
| 77 |
self._monitor_task: Optional[asyncio.Task] = None
|
| 78 |
self._blocked = False
|
| 79 |
+
self._step_graph = step_graph
|
| 80 |
|
| 81 |
async def start(self) -> None:
|
| 82 |
"""Start background VRAM monitor."""
|
|
|
|
| 99 |
while True:
|
| 100 |
try:
|
| 101 |
pressure = self._vram.get_pressure()
|
| 102 |
+
new_mode = self._pressure_to_mode(pressure, self._step_graph)
|
| 103 |
if new_mode != self._mode:
|
| 104 |
self._mode = new_mode
|
| 105 |
if new_mode == EvictionMode.EMERGENCY:
|
|
|
|
| 114 |
await asyncio.sleep(1) # Brief backoff on error
|
| 115 |
|
| 116 |
@staticmethod
|
| 117 |
+
def _pressure_to_mode(pressure: float, step_graph=None) -> EvictionMode:
|
| 118 |
"""Convert VRAM pressure to eviction mode."""
|
| 119 |
if pressure < 0.70: return EvictionMode.RELAXED
|
| 120 |
if pressure < 0.85: return EvictionMode.NORMAL
|
| 121 |
if pressure < 0.92: return EvictionMode.PRESSURE
|
| 122 |
if pressure < 0.96: return EvictionMode.CRITICAL
|
| 123 |
+
if pressure >= 0.96 and step_graph is not None: return EvictionMode.WORKFLOW_AWARE
|
| 124 |
return EvictionMode.EMERGENCY
|
| 125 |
|
| 126 |
async def set(self, key: str, value: Any, token_count: int) -> bool:
|
|
|
|
| 240 |
for k in to_evict:
|
| 241 |
self._evict(k)
|
| 242 |
evicted += 1
|
| 243 |
+
|
| 244 |
+
case EvictionMode.WORKFLOW_AWARE:
|
| 245 |
+
if self._step_graph is not None:
|
| 246 |
+
priority_order = self._step_graph.get_eviction_priority_order()
|
| 247 |
+
# Evict in reverse priority order (lowest priority first)
|
| 248 |
+
for agent_id in reversed(priority_order):
|
| 249 |
+
key = f"context:{agent_id}"
|
| 250 |
+
if key in self._store:
|
| 251 |
+
self._evict(key)
|
| 252 |
+
evicted += 1
|
| 253 |
|
| 254 |
if evicted > 0:
|
| 255 |
await self._reheap()
|
|
|
|
| 293 |
def is_blocked(self) -> bool:
|
| 294 |
"""True if new registrations are blocked (EMERGENCY mode)."""
|
| 295 |
return self._blocked
|
| 296 |
+
|
| 297 |
+
@property
|
| 298 |
+
def step_graph(self) -> Optional["AgentStepGraph"]:
|
| 299 |
+
"""The workflow dependency graph for WORKFLOW_AWARE eviction."""
|
| 300 |
+
return self._step_graph
|
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""KV-aware routing for ContextForge V4.0.
|
| 2 |
+
|
| 3 |
+
Routes KV cache requests based on:
|
| 4 |
+
- Anchor hash locality (blocks with same anchor_hash → same worker)
|
| 5 |
+
- CLA group affinity (upper-layer CLA groups prefer specific workers)
|
| 6 |
+
- VRAM pressure balancing (avoid overloaded workers)
|
| 7 |
+
- Workflow step context (consecutive steps prefer same worker)
|
| 8 |
+
|
| 9 |
+
INVARIANT 10: Only pre-RoPE tensors are quantized/shared.
|
| 10 |
+
Routing decisions are made on anchor metadata, not on actual KV tensors.
|
| 11 |
+
"""
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import asyncio
|
| 15 |
+
import logging
|
| 16 |
+
from dataclasses import dataclass, field
|
| 17 |
+
from typing import Optional
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclass
|
| 23 |
+
class WorkerState:
|
| 24 |
+
"""State of a worker in the KV routing mesh."""
|
| 25 |
+
|
| 26 |
+
worker_id: str = ""
|
| 27 |
+
anchor_scores: dict[str, float] = field(default_factory=dict) # anchor_hash → affinity
|
| 28 |
+
cla_groups: set[int] = field(default_factory=set) # CLA groups served
|
| 29 |
+
current_load: float = 0.0 # 0.0-1.0
|
| 30 |
+
last_used_step: int = 0
|
| 31 |
+
active_blocks: int = 0
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@dataclass
|
| 35 |
+
class RouteDecision:
|
| 36 |
+
"""Routing decision for a KV block request."""
|
| 37 |
+
|
| 38 |
+
target_worker_id: str
|
| 39 |
+
anchor_hash: str
|
| 40 |
+
cla_group: Optional[int]
|
| 41 |
+
confidence: float # 0.0-1.0
|
| 42 |
+
pre_rope: bool = True # INVARIANT 10
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class KVAwareRouter:
|
| 46 |
+
"""Routes KV cache traffic based on anchor locality and worker state.
|
| 47 |
+
|
| 48 |
+
Design principles:
|
| 49 |
+
1. Anchor hash locality: blocks with same anchor_hash route to same worker
|
| 50 |
+
2. CLA group affinity: upper-layer CLA groups have preferred workers
|
| 51 |
+
3. Load balancing: VRAM pressure influences routing decisions
|
| 52 |
+
4. Workflow continuity: consecutive steps prefer same worker
|
| 53 |
+
|
| 54 |
+
INVARIANT 10: Routing decisions are made on anchor metadata only.
|
| 55 |
+
Actual KV tensors are never inspected for routing.
|
| 56 |
+
"""
|
| 57 |
+
|
| 58 |
+
def __init__(
|
| 59 |
+
self,
|
| 60 |
+
num_workers: int = 1,
|
| 61 |
+
enable_cla_affinity: bool = True,
|
| 62 |
+
enable_anchor_locality: bool = True,
|
| 63 |
+
):
|
| 64 |
+
self._num_workers = num_workers
|
| 65 |
+
self._enable_cla_affinity = enable_cla_affinity
|
| 66 |
+
self._enable_anchor_locality = enable_anchor_locality
|
| 67 |
+
self._workers: dict[str, WorkerState] = {}
|
| 68 |
+
self._anchor_to_worker: dict[str, str] = {} # anchor_hash → worker_id
|
| 69 |
+
self._lock = asyncio.Lock()
|
| 70 |
+
|
| 71 |
+
def register_worker(self, worker_id: str) -> None:
|
| 72 |
+
"""Register a worker in the routing mesh."""
|
| 73 |
+
if worker_id not in self._workers:
|
| 74 |
+
self._workers[worker_id] = WorkerState(worker_id=worker_id)
|
| 75 |
+
logger.info(f"Router: registered worker {worker_id}")
|
| 76 |
+
|
| 77 |
+
async def select_worker(
|
| 78 |
+
self,
|
| 79 |
+
anchor_hash: str,
|
| 80 |
+
cla_group: Optional[int] = None,
|
| 81 |
+
workflow_step: Optional[int] = None,
|
| 82 |
+
token_length: int = 0,
|
| 83 |
+
) -> RouteDecision:
|
| 84 |
+
"""Select optimal worker for a KV block with given anchor.
|
| 85 |
+
|
| 86 |
+
Returns RouteDecision with target_worker_id and routing metadata.
|
| 87 |
+
"""
|
| 88 |
+
async with self._lock:
|
| 89 |
+
# 1. Check if this anchor already has a preferred worker (locality)
|
| 90 |
+
if self._enable_anchor_locality and anchor_hash in self._anchor_to_worker:
|
| 91 |
+
preferred_worker = self._anchor_to_worker[anchor_hash]
|
| 92 |
+
if preferred_worker in self._workers:
|
| 93 |
+
worker_state = self._workers[preferred_worker]
|
| 94 |
+
# Check load isn't too high
|
| 95 |
+
if worker_state.current_load < 0.95:
|
| 96 |
+
return RouteDecision(
|
| 97 |
+
target_worker_id=preferred_worker,
|
| 98 |
+
anchor_hash=anchor_hash,
|
| 99 |
+
cla_group=cla_group,
|
| 100 |
+
confidence=0.9,
|
| 101 |
+
pre_rope=True, # INVARIANT 10
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
# 2. Find best worker based on CLA affinity
|
| 105 |
+
if self._enable_cla_affinity and cla_group is not None:
|
| 106 |
+
for worker_id, state in self._workers.items():
|
| 107 |
+
if cla_group in state.cla_groups and state.current_load < 0.8:
|
| 108 |
+
self._anchor_to_worker[anchor_hash] = worker_id
|
| 109 |
+
state.anchor_scores[anchor_hash] = 0.8
|
| 110 |
+
return RouteDecision(
|
| 111 |
+
target_worker_id=worker_id,
|
| 112 |
+
anchor_hash=anchor_hash,
|
| 113 |
+
cla_group=cla_group,
|
| 114 |
+
confidence=0.75,
|
| 115 |
+
pre_rope=True,
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
# 3. Fall back to least loaded worker
|
| 119 |
+
if self._workers:
|
| 120 |
+
sorted_workers = sorted(
|
| 121 |
+
self._workers.items(),
|
| 122 |
+
key=lambda x: x[1].current_load
|
| 123 |
+
)
|
| 124 |
+
target_worker_id, target_state = sorted_workers[0]
|
| 125 |
+
self._anchor_to_worker[anchor_hash] = target_worker_id
|
| 126 |
+
target_state.anchor_scores[anchor_hash] = 0.5
|
| 127 |
+
return RouteDecision(
|
| 128 |
+
target_worker_id=target_worker_id,
|
| 129 |
+
anchor_hash=anchor_hash,
|
| 130 |
+
cla_group=cla_group,
|
| 131 |
+
confidence=0.5,
|
| 132 |
+
pre_rope=True,
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
# No workers available
|
| 136 |
+
return RouteDecision(
|
| 137 |
+
target_worker_id="",
|
| 138 |
+
anchor_hash=anchor_hash,
|
| 139 |
+
cla_group=cla_group,
|
| 140 |
+
confidence=0.0,
|
| 141 |
+
pre_rope=True,
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
async def update_worker_state(
|
| 145 |
+
self,
|
| 146 |
+
worker_id: str,
|
| 147 |
+
load: float,
|
| 148 |
+
cla_group: Optional[int] = None,
|
| 149 |
+
workflow_step: Optional[int] = None,
|
| 150 |
+
) -> None:
|
| 151 |
+
"""Update state for a worker after processing blocks."""
|
| 152 |
+
async with self._lock:
|
| 153 |
+
if worker_id not in self._workers:
|
| 154 |
+
self.register_worker(worker_id)
|
| 155 |
+
|
| 156 |
+
state = self._workers[worker_id]
|
| 157 |
+
state.current_load = min(load, 1.0)
|
| 158 |
+
if cla_group is not None:
|
| 159 |
+
state.cla_groups.add(cla_group)
|
| 160 |
+
if workflow_step is not None:
|
| 161 |
+
state.last_used_step = workflow_step
|
| 162 |
+
|
| 163 |
+
async def broadcast_new_blocks(
|
| 164 |
+
self,
|
| 165 |
+
anchor_hash: str,
|
| 166 |
+
block_ids: list[str],
|
| 167 |
+
target_worker_id: str,
|
| 168 |
+
) -> None:
|
| 169 |
+
"""Broadcast new block IDs to all workers for awareness."""
|
| 170 |
+
async with self._lock:
|
| 171 |
+
logger.debug(
|
| 172 |
+
f"Broadcast: anchor={anchor_hash} blocks={len(block_ids)} "
|
| 173 |
+
f"to worker={target_worker_id}"
|
| 174 |
+
)
|
| 175 |
+
# Record in routing table
|
| 176 |
+
self._anchor_to_worker[anchor_hash] = target_worker_id
|
| 177 |
+
|
| 178 |
+
if target_worker_id in self._workers:
|
| 179 |
+
self._workers[target_worker_id].anchor_scores[anchor_hash] = 1.0
|
| 180 |
+
|
| 181 |
+
def get_worker_for_anchor(self, anchor_hash: str) -> Optional[str]:
|
| 182 |
+
"""Get the preferred worker for an anchor hash (if any)."""
|
| 183 |
+
return self._anchor_to_worker.get(anchor_hash)
|
| 184 |
+
|
| 185 |
+
def get_stats(self) -> dict:
|
| 186 |
+
"""Return router statistics."""
|
| 187 |
+
return {
|
| 188 |
+
"num_workers": len(self._workers),
|
| 189 |
+
"anchors_tracked": len(self._anchor_to_worker),
|
| 190 |
+
"cla_affinity_enabled": self._enable_cla_affinity,
|
| 191 |
+
"anchor_locality_enabled": self._enable_anchor_locality,
|
| 192 |
+
"worker_loads": {
|
| 193 |
+
wid: {
|
| 194 |
+
"load": round(state.current_load, 3),
|
| 195 |
+
"cla_groups": len(state.cla_groups),
|
| 196 |
+
"active_blocks": state.active_blocks,
|
| 197 |
+
}
|
| 198 |
+
for wid, state in self._workers.items()
|
| 199 |
+
},
|
| 200 |
+
}
|
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""PBKV (Predictor-Based KV) predictor stub for ContextForge V4.0.
|
| 2 |
+
|
| 3 |
+
Provides lightweight KV cache demand prediction based on:
|
| 4 |
+
- Workflow step history (consecutive steps have predictable patterns)
|
| 5 |
+
- Agent affinity (certain agents share blocks predictably)
|
| 6 |
+
- CLA group patterns (upper-layer groups show strong reuse)
|
| 7 |
+
|
| 8 |
+
This is a STUB implementation. Production requires:
|
| 9 |
+
- Real ML model for next-agent prediction
|
| 10 |
+
- Time-series storage for workflow patterns
|
| 11 |
+
- Integration with AnchorPool for historical anchor tracking
|
| 12 |
+
|
| 13 |
+
INVARIANT 10: Predictions are made on anchor metadata only.
|
| 14 |
+
"""
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import asyncio
|
| 18 |
+
import json
|
| 19 |
+
import logging
|
| 20 |
+
import os
|
| 21 |
+
from dataclasses import dataclass, field
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
from typing import Optional
|
| 24 |
+
|
| 25 |
+
logger = logging.getLogger(__name__)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@dataclass
|
| 29 |
+
class WorkflowStepRecord:
|
| 30 |
+
"""Single step in a workflow sequence."""
|
| 31 |
+
|
| 32 |
+
step_idx: int
|
| 33 |
+
agent_id: str
|
| 34 |
+
anchor_hash: str
|
| 35 |
+
token_length: int
|
| 36 |
+
cla_group: Optional[int] = None
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@dataclass
|
| 40 |
+
class PredictionResult:
|
| 41 |
+
"""Prediction for next KV cache access."""
|
| 42 |
+
|
| 43 |
+
predicted_agents: list[str] # ranked by probability
|
| 44 |
+
predicted_anchor_hashes: list[str]
|
| 45 |
+
confidence: float
|
| 46 |
+
prefetch_block_ids: list[str] = field(default_factory=list)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class PBKVPredictor:
|
| 50 |
+
"""Predictor-based KV cache prefetching.
|
| 51 |
+
|
| 52 |
+
Design:
|
| 53 |
+
1. Log each workflow step to local JSONL file
|
| 54 |
+
2. On prediction request, analyze recent steps for patterns
|
| 55 |
+
3. Return ranked list of likely next agents and anchor hashes
|
| 56 |
+
|
| 57 |
+
STUB: Real implementation requires trained ML model.
|
| 58 |
+
"""
|
| 59 |
+
|
| 60 |
+
def __init__(
|
| 61 |
+
self,
|
| 62 |
+
log_dir: Optional[str] = None,
|
| 63 |
+
max_history_steps: int = 1000,
|
| 64 |
+
):
|
| 65 |
+
self._log_dir = Path(log_dir) if log_dir else Path(".") / ".pbkv_logs"
|
| 66 |
+
self._max_history_steps = max_history_steps
|
| 67 |
+
self._history: list[WorkflowStepRecord] = []
|
| 68 |
+
self._lock = asyncio.Lock()
|
| 69 |
+
self._log_file = self._log_dir / "workflow_steps.jsonl"
|
| 70 |
+
self._log_dir.mkdir(parents=True, exist_ok=True)
|
| 71 |
+
|
| 72 |
+
async def log_workflow_step(
|
| 73 |
+
self,
|
| 74 |
+
step_idx: int,
|
| 75 |
+
agent_id: str,
|
| 76 |
+
anchor_hash: str,
|
| 77 |
+
token_length: int,
|
| 78 |
+
cla_group: Optional[int] = None,
|
| 79 |
+
) -> None:
|
| 80 |
+
"""Log a workflow step for future prediction training."""
|
| 81 |
+
record = WorkflowStepRecord(
|
| 82 |
+
step_idx=step_idx,
|
| 83 |
+
agent_id=agent_id,
|
| 84 |
+
anchor_hash=anchor_hash,
|
| 85 |
+
token_length=token_length,
|
| 86 |
+
cla_group=cla_group,
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
async with self._lock:
|
| 90 |
+
self._history.append(record)
|
| 91 |
+
if len(self._history) > self._max_history_steps:
|
| 92 |
+
self._history.pop(0)
|
| 93 |
+
|
| 94 |
+
# Append to JSONL log
|
| 95 |
+
try:
|
| 96 |
+
with open(self._log_file, "a") as f:
|
| 97 |
+
f.write(json.dumps(record.__dict__) + "\n")
|
| 98 |
+
except Exception as e:
|
| 99 |
+
logger.warning(f"Failed to write PBKV log: {e}")
|
| 100 |
+
|
| 101 |
+
async def predict_next_agents(
|
| 102 |
+
self,
|
| 103 |
+
current_agent_id: str,
|
| 104 |
+
current_step: int,
|
| 105 |
+
num_predictions: int = 3,
|
| 106 |
+
) -> PredictionResult:
|
| 107 |
+
"""Predict which agents will likely access KV cache next.
|
| 108 |
+
|
| 109 |
+
STUB IMPLEMENTATION: Uses simple co-occurrence from recent history.
|
| 110 |
+
Real implementation: trained ML model for next-agent prediction.
|
| 111 |
+
"""
|
| 112 |
+
async with self._lock:
|
| 113 |
+
recent_steps = [s for s in self._history if s.step_idx >= current_step - 10]
|
| 114 |
+
|
| 115 |
+
if not recent_steps:
|
| 116 |
+
return PredictionResult(
|
| 117 |
+
predicted_agents=[current_agent_id],
|
| 118 |
+
predicted_anchor_hashes=[],
|
| 119 |
+
confidence=0.0,
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
# Simple co-occurrence: find agents that appear after current agent
|
| 123 |
+
agent_counts: dict[str, int] = {}
|
| 124 |
+
anchor_counts: dict[str, int] = {}
|
| 125 |
+
|
| 126 |
+
for i, step in enumerate(recent_steps[:-1]):
|
| 127 |
+
if step.agent_id == current_agent_id and i + 1 < len(recent_steps):
|
| 128 |
+
next_step = recent_steps[i + 1]
|
| 129 |
+
agent_counts[next_step.agent_id] = agent_counts.get(next_step.agent_id, 0) + 1
|
| 130 |
+
anchor_counts[next_step.anchor_hash] = anchor_counts.get(next_step.anchor_hash, 0) + 1
|
| 131 |
+
|
| 132 |
+
# Rank by frequency
|
| 133 |
+
sorted_agents = sorted(agent_counts.items(), key=lambda x: -x[1])
|
| 134 |
+
sorted_anchors = sorted(anchor_counts.items(), key=lambda x: -x[1])
|
| 135 |
+
|
| 136 |
+
predicted_agents = [a[0] for a in sorted_agents[:num_predictions]]
|
| 137 |
+
predicted_anchors = [a[0] for a in sorted_anchors[:num_predictions]]
|
| 138 |
+
|
| 139 |
+
confidence = 0.5 if sorted_agents else 0.0
|
| 140 |
+
|
| 141 |
+
return PredictionResult(
|
| 142 |
+
predicted_agents=predicted_agents or [current_agent_id],
|
| 143 |
+
predicted_anchor_hashes=predicted_anchors,
|
| 144 |
+
confidence=confidence,
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
async def get_prefetch_candidates(
    self,
    agent_id: str,
    step: int,
) -> list[str]:
    """Get list of block IDs to prefetch for given agent and step."""
    result = await self.predict_next_agents(agent_id, step, num_predictions=3)

    # STUB: anchor hashes stand in for block IDs until a real
    # anchor -> block-ID mapping is implemented.
    block_ids = result.predicted_anchor_hashes

    logger.debug(
        f"PBKV prefetch candidates for agent={agent_id} step={step}: "
        f"{len(block_ids)} candidates, confidence={result.confidence:.2f}"
    )

    return block_ids
|
| 165 |
+
|
| 166 |
+
def get_stats(self) -> dict:
    """Return PBKV predictor statistics.

    Reports the in-memory history size, the JSONL log path, and the
    configured history cap.
    """
    stats: dict = {}
    stats["history_size"] = len(self._history)
    stats["log_file"] = str(self._log_file)
    stats["max_history_steps"] = self._max_history_steps
    return stats
|
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""AgentStepGraph — workflow dependency graph for KV cache eviction priority.
|
| 2 |
+
|
| 3 |
+
Based on KVFlow (NeurIPS 2025, arXiv:2507.07400):
|
| 4 |
+
- Workflow-aware eviction: evict caches of agents with high steps-to-execution
|
| 5 |
+
(agents far from being invoked) before agents about to run.
|
| 6 |
+
- Overlapped KV prefetching: proactively prefetch KV tensors for agents
|
| 7 |
+
scheduled in the next N steps.
|
| 8 |
+
|
| 9 |
+
Result from paper: 1.83x speedup over SGLang, 2.19x for concurrent workflows.
|
| 10 |
+
|
| 11 |
+
V4.0 CHANGES: New module for workflow-aware eviction.
|
| 12 |
+
"""
|
| 13 |
+
import sys
|
| 14 |
+
from dataclasses import dataclass, field
|
| 15 |
+
from typing import Optional
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@dataclass
class AgentStep:
    """A single step in a workflow graph.

    One AgentStep describes one scheduled agent invocation; instances are
    registered with AgentStepGraph.add_step().
    """
    # Identifier of the agent this step invokes (graph key).
    agent_id: str
    # agent_ids of steps that must complete before this one runs.
    depends_on: list[str] = field(default_factory=list)
    # Position of this step in the workflow schedule (compared against
    # current_step by the graph's scheduling queries).
    step_index: int = 0
    # Rough token budget for the step — presumably used for VRAM sizing;
    # not consumed by the graph itself in this module. TODO confirm.
    estimated_tokens: int = 0
    is_optional: bool = False  # True for dynamic conditional agents
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class AgentStepGraph:
    """
    Workflow dependency graph for KV cache eviction priority.

    Based on KVFlow (NeurIPS 2025, arXiv:2507.07400): agents far from
    execution (high steps-to-execution) are evicted first; agents scheduled
    within the lookahead window are prefetch candidates.

    Usage:
        graph = AgentStepGraph()
        graph.add_step(AgentStep(agent_id="retriever", depends_on=[], step_index=0))
        graph.add_step(AgentStep(agent_id="summarizer", depends_on=["retriever"], step_index=1))
        order = graph.get_eviction_priority_order()  # agents far from execution first
    """

    def __init__(self):
        # agent_id -> step; _step_list preserves insertion (topological) order.
        self._steps: dict[str, "AgentStep"] = {}
        self._step_list: list["AgentStep"] = []

    def add_step(self, step: "AgentStep") -> "AgentStepGraph":
        """Add (or replace) a step. Returns self for chaining.

        FIX: re-adding an existing agent_id previously appended a duplicate
        entry to _step_list while overwriting _steps, so list- and
        dict-based views of the graph diverged. The list entry is now
        replaced in place.
        """
        if step.agent_id in self._steps:
            idx = next(
                i for i, s in enumerate(self._step_list) if s.agent_id == step.agent_id
            )
            self._step_list[idx] = step
        else:
            self._step_list.append(step)
        self._steps[step.agent_id] = step
        return self

    def compute_steps_to_execution(self, agent_id: str, current_step: int = 0) -> int:
        """
        Returns how many steps must complete before agent_id is invoked.

        Returns:
            0 if the agent's step_index is at or before current_step.
            sys.maxsize if agent_id not in graph.
        Raises:
            ValueError: if the graph has cycles.
        """
        self.validate_dag()  # Will raise if cycles
        return self._steps_to_execution_unchecked(agent_id, current_step)

    def _steps_to_execution_unchecked(self, agent_id: str, current_step: int) -> int:
        """compute_steps_to_execution without re-running cycle detection."""
        if agent_id not in self._steps:
            return sys.maxsize

        step = self._steps[agent_id]
        if step.step_index <= current_step:
            return 0

        # Longest dependency chain from any root to this step; `seen` guards
        # against revisiting a node reachable via multiple paths.
        def depth(s: "AgentStep", seen: set) -> int:
            if s.agent_id in seen:
                return 0
            seen.add(s.agent_id)
            if not s.depends_on:
                return s.step_index
            best = 0
            for dep_id in s.depends_on:
                dep = self._steps.get(dep_id)
                if dep is not None:
                    best = max(best, depth(dep, seen))
            return best + 1

        return depth(step, set())

    def get_prefetch_candidates(
        self,
        current_step: int,
        lookahead: int = 2,
    ) -> list[str]:
        """Return agent_ids to prefetch within `lookahead` steps."""
        return [
            s.agent_id
            for s in self._step_list
            if current_step < s.step_index <= current_step + lookahead
        ]

    def get_eviction_priority_order(self) -> list[str]:
        """
        Return agent_ids ordered from lowest to highest eviction priority
        (first in list = evict first = highest steps_to_execution).

        FIX: previously re-ran DAG validation once per agent; the graph is
        now validated a single time before ranking.
        """
        self.validate_dag()
        priorities = [
            (s.agent_id, self._steps_to_execution_unchecked(s.agent_id, 0))
            for s in self._step_list
        ]
        # Highest steps_to_execution first = evict first.
        priorities.sort(key=lambda x: x[1], reverse=True)
        return [agent_id for agent_id, _ in priorities]

    def validate_dag(self) -> None:
        """Raise ValueError if graph contains cycles (DFS three-coloring)."""
        WHITE, GRAY, BLACK = 0, 1, 2
        color = {sid: WHITE for sid in self._steps}

        def dfs(node_id: str) -> None:
            color[node_id] = GRAY
            if node_id in self._steps:
                for dep in self._steps[node_id].depends_on:
                    state = color.get(dep, WHITE)
                    if state == GRAY:
                        # Back edge: dep is on the current DFS stack.
                        raise ValueError(f"Cycle detected involving agent '{node_id}'")
                    if state == WHITE:
                        dfs(dep)
            color[node_id] = BLACK

        for sid in self._steps:
            if color[sid] == WHITE:
                dfs(sid)

    @property
    def size(self) -> int:
        """Number of steps in the graph."""
        return len(self._steps)

    def get_step(self, agent_id: str) -> Optional["AgentStep"]:
        """Get step by agent_id, or None if absent."""
        return self._steps.get(agent_id)

    def get_all_agents(self) -> list[str]:
        """Get all agent IDs in the graph."""
        return list(self._steps.keys())
|
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""vLLM-ATOM Plugin for ContextForge V4.0.
|
| 2 |
+
|
| 3 |
+
ATOM (Anchor-driven Tensor Orchestration for Multi-agent) provides:
|
| 4 |
+
- Pre/post attention hooks for RotateKV quantization (INVARIANT 10)
|
| 5 |
+
- Anchor-aware KV block routing
|
| 6 |
+
- CLA metadata injection
|
| 7 |
+
- KV-aware load balancing across workers
|
| 8 |
+
|
| 9 |
+
Usage:
|
| 10 |
+
from contextforge.serving.atom_plugin import vLLMAtomPlugin
|
| 11 |
+
|
| 12 |
+
# Register with vLLM via entry_point in pyproject.toml
|
| 13 |
+
# Plugin auto-initializes on vLLM worker startup
|
| 14 |
+
"""
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import logging
|
| 18 |
+
from dataclasses import dataclass, field
|
| 19 |
+
from typing import Any, Callable, Optional
|
| 20 |
+
|
| 21 |
+
logger = logging.getLogger(__name__)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@dataclass
class ATOMConfig:
    """ATOM plugin configuration.

    Feature flags consumed by the pre/post attention hooks; defaults enable
    all ATOM features.
    """

    enable_quantization: bool = True  # RotateKV pre-RoPE quantization
    enable_anchor_routing: bool = True  # Anchor-based block routing
    enable_cla_injection: bool = True  # CLA metadata in attention
    quantization_mode: str = "rotate_kv"  # or "disabled"
    # Cap on blocks to quantize — presumably per pass; not read by the
    # hooks in this module. TODO confirm intended unit/scope.
    max_quantize_blocks: int = 1024
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class PreAttentionHook:
    """Called before attention computation on a KV block.

    STUB: currently only emits routing/quantization metadata; no tensor
    transformation happens here.
    """

    def __init__(self, config: ATOMConfig):
        self._config = config
        # block_id -> quantized payload; reserved for future use — never
        # populated by this stub.
        self._quantized_blocks: dict[str, Any] = {}

    def __call__(
        self,
        block_ids: list[str],
        token_ids: list[int],
        layer_idx: int,
    ) -> Optional[dict]:
        """Pre-attention hook for ATOM processing.

        Args:
            block_ids: KV block identifiers entering this attention layer.
            token_ids: token ids for the blocks (unused by the stub).
            layer_idx: index of the attention layer being processed.

        Returns metadata dict with:
        - quantized: whether RotateKV quantization was applied
        - anchor_hash: anchor identifier for routing
        - cla_group: CLA group assignment
        - pre_rope: True (INVARIANT 10)
        Returns None when quantization is disabled in config.
        """
        if not self._config.enable_quantization:
            return None

        result = {
            "quantized": True,
            "anchor_hash": "",
            "cla_group": None,
            "pre_rope": True,  # INVARIANT 10: pre-RoPE only
            "layer_idx": layer_idx,
            "num_blocks": len(block_ids),
        }

        # FIX: lazy %-args instead of an eager f-string — this hook runs per
        # attention layer, so skip formatting unless DEBUG is enabled.
        logger.debug(
            "ATOM pre-attention: layer=%s blocks=%s quantized=%s pre_rope=%s",
            layer_idx,
            len(block_ids),
            result["quantized"],
            result["pre_rope"],
        )

        return result
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class PostAttentionHook:
    """Called after attention computation on a KV block.

    Accumulates a running hit counter used for routing decisions.
    """

    def __init__(self, config: ATOMConfig):
        self._config = config
        self._stats = {"hits": 0, "misses": 0}

    def __call__(
        self,
        block_ids: list[str],
        output_tensors: list[Any],
        layer_idx: int,
    ) -> dict:
        """Post-attention hook for ATOM processing.

        Records anchor hit/miss for routing decisions.
        """
        processed = len(block_ids)
        # Every block processed in this call counts as a hit.
        self._stats["hits"] = self._stats["hits"] + processed
        summary = {
            "processed_blocks": processed,
            "layer_idx": layer_idx,
            "total_hits": self._stats["hits"],
        }
        return summary
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class vLLMAtomPlugin:
    """vLLM-ATOM plugin for ContextForge V4.0.

    Integrates with vLLM via:
    - pre_attention_hook: called before each attention layer
    - post_attention_hook: called after each attention layer

    The plugin handles:
    1. RotateKV quantization of pre-RoPE tensors (INVARIANT 10)
    2. Anchor-aware KV block routing
    3. CLA metadata injection
    4. KV-aware worker load balancing
    """

    def __init__(self, config: Optional[ATOMConfig] = None):
        self._config = config or ATOMConfig()
        self._pre_hook = PreAttentionHook(self._config)
        self._post_hook = PostAttentionHook(self._config)
        self._initialized = False
        self._worker_id: Optional[str] = None

    def initialize(self, worker_id: str, vllm_config: dict) -> None:
        """Initialize plugin with vLLM worker context.

        Args:
            worker_id: identifier of the vLLM worker hosting this plugin.
            vllm_config: vLLM engine configuration (unused by the stub).
        """
        self._worker_id = worker_id
        self._initialized = True
        # Lazy %-args avoid eager string formatting when INFO is off.
        logger.info("ATOM plugin initialized: worker=%s", worker_id)

    @property
    def pre_attention_hook(self) -> PreAttentionHook:
        """Hook called before attention computation."""
        return self._pre_hook

    @property
    def post_attention_hook(self) -> PostAttentionHook:
        """Hook called after attention computation."""
        return self._post_hook

    def is_initialized(self) -> bool:
        """Check if plugin is initialized."""
        return self._initialized

    def get_stats(self) -> dict:
        """Return ATOM plugin statistics.

        FIX: previously returned the post hook's private mutable _stats
        dict directly, letting callers mutate hook internals; a shallow
        copy is returned instead.
        """
        return {
            "initialized": self._initialized,
            "worker_id": self._worker_id,
            "config": {
                "enable_quantization": self._config.enable_quantization,
                "enable_anchor_routing": self._config.enable_anchor_routing,
                "enable_cla_injection": self._config.enable_cla_injection,
                "quantization_mode": self._config.quantization_mode,
            },
            "post_stats": dict(self._post_hook._stats),
        }
|
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""LMCache V1 bridge for ContextForge V4.0.
|
| 2 |
+
|
| 3 |
+
Provides transparent bridge between ContextForge's AnchorPool/offset tracking
|
| 4 |
+
and LMCache's distributed KV cache layer. Enables cross-worker KV reuse with
|
| 5 |
+
anchor-aware offset hints.
|
| 6 |
+
|
| 7 |
+
Architecture:
|
| 8 |
+
- LMCache acts as external KV store (separate from VRAMCache)
|
| 9 |
+
- Bridge intercepts save/load events and augments with ContextForge metadata
|
| 10 |
+
- AnchorPool offset hints propagate to LMCache for cross-node alignment
|
| 11 |
+
|
| 12 |
+
INVARIANT 10: Only pre-RoPE tensors are quantized/shared.
|
| 13 |
+
"""
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import asyncio
|
| 17 |
+
import logging
|
| 18 |
+
import weakref
|
| 19 |
+
from dataclasses import dataclass, field
|
| 20 |
+
from typing import Optional
|
| 21 |
+
|
| 22 |
+
logger = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass
class LMCacheMeta:
    """Metadata stored alongside KV blocks in LMCache.

    Populated from the caller-supplied metadata dict in
    LMCacheConnectorV1.on_save_kv_layer.
    """

    # Anchor identifier used for routing/alignment (empty when unknown).
    anchor_hash: str = ""
    # Agent that produced the KV block (empty when unknown).
    agent_id: str = ""
    # Number of tokens the block covers.
    token_length: int = 0
    pre_rope: bool = True  # INVARIANT 10 flag
    # CLA layer-group assignment, when CLA metadata is available.
    cla_group: Optional[int] = None
    # Workflow step at which the block was produced, when known.
    workflow_step: Optional[int] = None
    offset_hint: Optional[list[float]] = None  # from AnchorPool
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class LMCacheConnectorV1:
    """Bridge between ContextForge AnchorPool and LMCache V1.

    Supports:
    - Saving KV layers with anchor-aware metadata
    - Loading with offset_hint injection for RoPE de-rotation
    - Cross-worker block sharing with prefix anchoring

    The bridge degrades gracefully: with no client supplied, every
    operation is a no-op and is_active() reports False.
    """

    def __init__(
        self,
        lmcache_client=None,  # LMCache client instance (optional for graceful degradation)
        enable_offset_hints: bool = True,
        enable_cla_metadata: bool = True,
    ):
        self._client = lmcache_client
        self._enable_offset_hints = enable_offset_hints
        self._enable_cla_metadata = enable_cla_metadata
        # Active only when a real client was supplied.
        self._active = lmcache_client is not None
        # block_id -> completion event for in-flight saves.
        # NOTE(review): never populated in this version; reported by
        # get_stats only.
        self._pending_saves: dict[str, asyncio.Event] = {}

    def is_active(self) -> bool:
        """Check if LMCache bridge is active (a client was provided)."""
        return self._active

    def build_prefix_hint(
        self,
        token_ids: list[int],
        agent_id: str,
        anchor_hash: str,
    ) -> dict:
        """Build prefix hint dict for LMCache save operations.

        This hint is stored alongside the KV data so loading workers
        can reconstruct RoPE-aligned context.
        """
        return {
            "anchor_hash": anchor_hash,
            "agent_id": agent_id,
            "token_length": len(token_ids),
            "pre_rope": True,  # INVARIANT 10
        }

    async def on_save_kv_layer(
        self,
        block_id: str,
        kv_data,  # Pre-RoPE KV tensor
        metadata: dict,
    ) -> None:
        """Called when ContextForge saves a KV layer to LMCache.

        Augments metadata with anchor hash and CLA group info.
        STUB: builds the LMCacheMeta record and logs it; the actual
        client write is not yet wired up.
        """
        if not self._active:
            return

        # INVARIANT 10: Ensure pre-RoPE flag is set
        meta = LMCacheMeta(
            anchor_hash=metadata.get("anchor_hash", ""),
            agent_id=metadata.get("agent_id", ""),
            token_length=metadata.get("token_length", 0),
            pre_rope=True,
            cla_group=metadata.get("cla_group"),
            workflow_step=metadata.get("workflow_step"),
            offset_hint=metadata.get("offset_hint"),
        )

        # FIX: lazy %-args instead of eager f-strings — the save/load paths
        # run per KV layer, so skip formatting unless DEBUG is enabled.
        logger.debug(
            "LMCache save: block=%s anchor=%s pre_rope=%s cla_group=%s",
            block_id,
            meta.anchor_hash,
            meta.pre_rope,
            meta.cla_group,
        )

    async def on_load_kv_layer(
        self,
        block_id: str,
        metadata: dict,
    ) -> Optional[dict]:
        """Called when ContextForge loads a KV layer from LMCache.

        Returns offset_hint if available for RoPE de-rotation alignment;
        None when the bridge is inactive.
        """
        if not self._active:
            return None

        offset_hint = metadata.get("offset_hint")
        anchor_hash = metadata.get("anchor_hash")

        if offset_hint:
            logger.debug(
                "LMCache load: block=%s anchor=%s has_offset_hint len=%s",
                block_id,
                anchor_hash,
                len(offset_hint),
            )

        return {
            "offset_hint": offset_hint,
            "anchor_hash": anchor_hash,
            "pre_rope": metadata.get("pre_rope", True),  # INVARIANT 10
        }

    async def prefetch_blocks(
        self,
        block_ids: list[str],
        priority: Optional[list[int]] = None,
    ) -> None:
        """Prefetch blocks from LMCache into local cache.

        Args:
            block_ids: blocks to fetch.
            priority: unsupported in the V1 fallback; blocks are fetched
                in order.
        """
        if not self._active or not self._client:
            return

        logger.debug("LMCache prefetch: %s blocks", len(block_ids))

    def get_stats(self) -> dict:
        """Return LMCache bridge statistics."""
        return {
            "active": self._active,
            "offset_hints_enabled": self._enable_offset_hints,
            "cla_metadata_enabled": self._enable_cla_metadata,
            "pending_saves": len(self._pending_saves),
        }
|