Spaces:
Sleeping
ContextForge V4.0: Benchmark V4 + 9 test files
Browse files- demo/benchmark_v4.py: 10 scenarios, new V4 metrics
(anchor_pool_hit_rate, cla_vram_reduction_pct, quantization_active,
rotate_kv_blocks, prefetch_hit_rate, pbkv_accuracy,
anchor_locality_score, router_confidence_avg, lmcache_bridge_active,
atom_plugin_initialized)
- tests/test_embedding_engine.py: EmbeddingEngine encode/encode_batch/simhash
- tests/test_cla_metadata.py: CLAMetadataLayer compute_layer_groups/emit_hint
- tests/test_rotate_kv.py: RotateKVQuantizer quantize_pre_rope/dequantize
- tests/test_step_graph.py: AgentStepGraph compute_steps/get_eviction_priority
- tests/test_lmcache_bridge.py: LMCacheConnectorV1 save/load hooks
- tests/test_atom_plugin.py: vLLMAtomPlugin pre/post attention hooks
- tests/test_kv_aware_router.py: KVAwareRouter select_worker/broadcast
- tests/test_pbkv_predictor.py: PBKVPredictor log_workflow_step/predict
INVARIANT 10: pre-RoPE quantization in RotateKV tests.
- demo/benchmark_v4.py +524 -0
- tests/test_atom_plugin.py +86 -0
- tests/test_cla_metadata.py +82 -0
- tests/test_embedding_engine.py +71 -0
- tests/test_kv_aware_router.py +88 -0
- tests/test_lmcache_bridge.py +58 -0
- tests/test_pbkv_predictor.py +113 -0
- tests/test_rotate_kv.py +84 -0
- tests/test_step_graph.py +79 -0
|
@@ -0,0 +1,524 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ContextForge V4.0 Benchmark - 10 scenarios, new V4 metrics.
|
| 2 |
+
|
| 3 |
+
New V4.0 metrics:
|
| 4 |
+
- anchor_pool_hit_rate
|
| 5 |
+
- cla_vram_reduction_pct
|
| 6 |
+
- quantization_active
|
| 7 |
+
- rotate_kv_blocks
|
| 8 |
+
- prefetch_hit_rate
|
| 9 |
+
- pbkv_accuracy
|
| 10 |
+
|
| 11 |
+
INVARIANT 10: Only pre-RoPE tensors are quantized/shared.
|
| 12 |
+
"""
|
| 13 |
+
import asyncio
|
| 14 |
+
import json
|
| 15 |
+
import time
|
| 16 |
+
from dataclasses import dataclass, field
|
| 17 |
+
from datetime import datetime
|
| 18 |
+
from typing import Any, Optional
|
| 19 |
+
|
| 20 |
+
import numpy as np
|
| 21 |
+
|
| 22 |
+
# V4.0 imports
|
| 23 |
+
from contextforge.embeddings.embedding_engine import EmbeddingEngine
|
| 24 |
+
from contextforge.kv_offset.anchor_pool import AnchorPool, AnchorOffsetResult
|
| 25 |
+
from contextforge.kv_offset.cla_metadata import CLAMetadataLayer, CLAGroupConfig, CLAHint
|
| 26 |
+
from contextforge.quantization.rotate_kv import RotateKVQuantizer, RotateKVConfig, QuantizedKVBlock
|
| 27 |
+
from contextforge.routing.kv_aware_router import KVAwareRouter, RouteDecision
|
| 28 |
+
from contextforge.scheduling.step_graph import AgentStepGraph, AgentStep
|
| 29 |
+
from contextforge.scheduling.pbkv_predictor import PBKVPredictor
|
| 30 |
+
from contextforge.serving.lmcache_bridge import LMCacheConnectorV1
|
| 31 |
+
from contextforge.serving.atom_plugin import vLLMAtomPlugin, ATOMConfig
|
| 32 |
+
from contextforge.registry.vram_aware_cache import EvictionMode, VRAMAwareCache
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@dataclass
class V4Metrics:
    """V4.0 benchmark metrics."""
    # Fraction of anchor lookups served from the AnchorPool (0.0-1.0).
    anchor_pool_hit_rate: float = 0.0
    # Estimated VRAM saved by CLA layer-group sharing, in percent.
    cla_vram_reduction_pct: float = 0.0
    # True when RotateKV quantization ran during the scenario.
    quantization_active: bool = False
    # Number of KV blocks processed by the quantizer.
    rotate_kv_blocks: int = 0
    # Fraction of prefetch candidates surfaced (scenario-specific meaning).
    prefetch_hit_rate: float = 0.0
    # Average PBKV prediction confidence, used as an accuracy proxy.
    pbkv_accuracy: float = 0.0
    # Fraction of routing decisions with confidence >= 0.9.
    anchor_locality_score: float = 0.0
    # Mean confidence across all router decisions.
    router_confidence_avg: float = 0.0
    # True when the LMCache bridge reported an active backend.
    lmcache_bridge_active: bool = False
    # True when the ATOM plugin completed initialize().
    atom_plugin_initialized: bool = False
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@dataclass
class ScenarioResult:
    """Result for a single benchmark scenario."""
    # 1-based scenario index matching SCENARIOS.
    scenario_id: int
    # Machine-friendly scenario name matching SCENARIOS.
    scenario_name: str
    # Wall-clock duration of the timed section, in milliseconds.
    duration_ms: float
    # Total tokens (or token-equivalents) processed in the timed section.
    tokens_processed: int
    # Peak VRAM estimate for the scenario, in gigabytes.
    vram_peak_gb: float
    # Throughput in tokens per second (tokens_processed / duration).
    throughput_tps: float
    # V4.0-specific metrics; defaults to all-zero metrics on failure.
    v4: V4Metrics = field(default_factory=V4Metrics)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# Static registry of the 10 benchmark scenarios; order must match the
# scenario_N_* coroutine functions collected in run_all_scenarios().
SCENARIOS = [
    {"id": 1, "name": "anchor_pool_resolution", "description": "Test AnchorPool offset approximation"},
    {"id": 2, "name": "cla_metadata_layer", "description": "Test CLA group computation and VRAM reduction"},
    {"id": 3, "name": "rotate_kv_quantization", "description": "Test RotateKV pre-RoPE quantization (INVARIANT 10)"},
    {"id": 4, "name": "step_graph_execution", "description": "Test AgentStepGraph compute_steps_to_execution"},
    {"id": 5, "name": "kv_aware_routing", "description": "Test KVAwareRouter select_worker + anchor locality"},
    {"id": 6, "name": "lmcache_bridge_save_load", "description": "Test LMCacheConnectorV1 on_save/on_load hooks"},
    {"id": 7, "name": "atom_plugin_hooks", "description": "Test vLLMAtomPlugin pre/post attention hooks"},
    {"id": 8, "name": "pbkv_prediction", "description": "Test PBKVPredictor log_workflow_step + predict_next_agents"},
    {"id": 9, "name": "workflow_aware_eviction", "description": "Test _pressure_to_mode WORKFLOW_AWARE at high pressure"},
    {"id": 10, "name": "embedding_engine_encoding", "description": "Test EmbeddingEngine.encode_batch + simhash"},
]
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def tokens_to_text(token_ids: list[int]) -> str:
    """Render a token-ID sequence as a space-separated string for encoding."""
    return " ".join(map(str, token_ids))
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def tokens_to_text_batch(sequences: list[list[int]]) -> list[str]:
    """Render each token-ID sequence as a space-separated string."""
    # Inline the per-sequence join so the block is self-contained.
    return [" ".join(str(token) for token in sequence) for sequence in sequences]
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
async def scenario_1_anchor_pool_resolution() -> ScenarioResult:
    """Scenario 1: AnchorPool offset resolution.

    Seeds the pool with three per-agent offsets for one token sequence,
    times 100 approximate_offset lookups, and derives a hit rate from
    the pool's own stats.
    """
    pool = AnchorPool(max_size=20)
    token_ids = [101, 2003, 1996, 3007, 102]

    # Use np.ndarray for real_kv_offset as per API.
    offsets = [
        np.array([1.0, 2.0, 3.0], dtype=np.float32),
        np.array([1.1, 2.1, 3.1], dtype=np.float32),
        np.array([0.9, 1.9, 2.9], dtype=np.float32),
    ]
    for i, offset in enumerate(offsets):
        await pool.update_pool(token_ids, f"agent_{i+1}", offset)
        await asyncio.sleep(0.001)

    start = time.perf_counter()
    for _ in range(100):
        # Result intentionally discarded: this loop only measures latency
        # (fix: the original bound it to an unused local).
        await pool.approximate_offset(token_ids, "agent_1")
    duration = (time.perf_counter() - start) * 1000

    stats = await pool.get_stats()
    # Ratio of distinct anchors to stored agent offsets, clamped to 1.0 below.
    hit_rate = stats["total_anchors"] / max(stats["total_agent_offsets"], 1)

    return ScenarioResult(
        scenario_id=1,
        scenario_name="anchor_pool_resolution",
        duration_ms=duration,
        tokens_processed=len(token_ids) * 100,
        vram_peak_gb=0.1,
        throughput_tps=(len(token_ids) * 100) / (duration / 1000),
        v4=V4Metrics(anchor_pool_hit_rate=min(hit_rate, 1.0)),
    )
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
async def scenario_2_cla_metadata_layer() -> ScenarioResult:
    """Scenario 2: CLA metadata layer VRAM reduction.

    Repeatedly computes CLA layer groups and emits a hint for a
    non-thinking retriever agent, then reports the estimated VRAM saving.
    """
    cla_config = CLAGroupConfig(
        group_size=2,
        sharing_direction="upper",
        thinking_mode_bypass=True,
        min_layer=0,
        max_layer=64,
    )
    metadata_layer = CLAMetadataLayer(cla_config)

    iterations = 50
    layer_count = 32
    groups = []
    start = time.perf_counter()
    for _ in range(iterations):
        groups = metadata_layer.compute_layer_groups(
            model_layer_count=layer_count, agent_role="retriever"
        )
        # Hint is emitted for its side effects; the return value is unused here.
        metadata_layer.emit_hint(
            agent_id="test_agent",
            model_id="Qwen3.6-35B-A22B",
            is_thinking_mode=False,
            model_layer_count=layer_count,
            agent_role="retriever",
        )
    duration = (time.perf_counter() - start) * 1000

    reduction_fraction = metadata_layer.estimated_vram_reduction(groups)
    work_units = layer_count * iterations

    return ScenarioResult(
        scenario_id=2,
        scenario_name="cla_metadata_layer",
        duration_ms=duration,
        tokens_processed=work_units,
        vram_peak_gb=0.05,
        throughput_tps=work_units / (duration / 1000),
        v4=V4Metrics(cla_vram_reduction_pct=reduction_fraction * 100),
    )
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
async def scenario_3_rotate_kv_quantization() -> ScenarioResult:
    """Scenario 3: RotateKV pre-RoPE quantization (INVARIANT 10).

    Builds random pre-RoPE K/V tensors and times one quantize_pre_rope call.
    """
    quantizer = RotateKVQuantizer(
        RotateKVConfig(
            bits=4,
            group_size=64,
            sink_tokens=4,
            use_fwht=True,
            grouped_heads=2,
        )
    )

    # Pre-RoPE tensors only (INVARIANT 10: quantization happens before RoPE).
    num_blocks, hidden_dim = 64, 512
    keys = np.random.randn(num_blocks, hidden_dim).astype(np.float32)
    values = np.random.randn(num_blocks, hidden_dim).astype(np.float32)
    block_positions = np.arange(num_blocks, dtype=np.float32)

    start = time.perf_counter()
    quantized_block = quantizer.quantize_pre_rope(keys, values, block_positions)
    duration = (time.perf_counter() - start) * 1000

    elements = num_blocks * hidden_dim
    return ScenarioResult(
        scenario_id=3,
        scenario_name="rotate_kv_quantization",
        duration_ms=duration,
        tokens_processed=elements,
        vram_peak_gb=0.2,
        throughput_tps=elements / (duration / 1000),
        v4=V4Metrics(quantization_active=True, rotate_kv_blocks=num_blocks),
    )
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
async def scenario_4_step_graph_execution() -> ScenarioResult:
    """Scenario 4: AgentStepGraph compute_steps_to_execution.

    Builds a 4-step linear workflow and times 100 depth computations for
    the final step, then samples prefetch candidates from step 0.
    """
    graph = AgentStepGraph()

    # Linear workflow: retriever -> summarizer -> critic -> responder.
    workflow = [
        ("retriever", [], 0, 100),
        ("summarizer", ["retriever"], 1, 150),
        ("critic", ["summarizer"], 2, 200),
        ("responder", ["critic"], 3, 300),
    ]
    for agent, deps, idx, tokens in workflow:
        graph.add_step(
            AgentStep(agent_id=agent, depends_on=deps, step_index=idx, estimated_tokens=tokens)
        )

    iterations = 100
    depths = []
    start = time.perf_counter()
    for _ in range(iterations):
        depths.append(graph.compute_steps_to_execution("responder", current_step=0))
    duration = (time.perf_counter() - start) * 1000

    prefetch_candidates = graph.get_prefetch_candidates(current_step=0)

    return ScenarioResult(
        scenario_id=4,
        scenario_name="step_graph_execution",
        duration_ms=duration,
        tokens_processed=iterations,
        vram_peak_gb=0.3,
        throughput_tps=iterations / (duration / 1000),
        v4=V4Metrics(prefetch_hit_rate=len(prefetch_candidates) / 4.0),
    )
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
async def scenario_5_kv_aware_routing() -> ScenarioResult:
    """Scenario 5: KVAwareRouter anchor locality + CLA affinity.

    Routes 10 requests across 4 workers and reports average confidence
    plus the fraction of high-confidence (>= 0.9) decisions.
    """
    router = KVAwareRouter(num_workers=4, enable_cla_affinity=True)

    for i in range(4):
        router.register_worker(f"worker_{i}")

    anchor_hashes = [f"anchor_{i % 3}" for i in range(10)]
    cla_groups = [i % 4 for i in range(10)]

    start = time.perf_counter()
    decisions = []
    for i, (ah, cg) in enumerate(zip(anchor_hashes, cla_groups)):
        decision = await router.select_worker(ah, cla_group=cg, workflow_step=i)
        decisions.append(decision)
    duration = (time.perf_counter() - start) * 1000

    # Fix: guard BOTH aggregates against an empty decision list (the original
    # guarded avg_confidence but still divided by len(decisions) below).
    if decisions:
        avg_confidence = sum(d.confidence for d in decisions) / len(decisions)
        anchor_locality = sum(1 for d in decisions if d.confidence >= 0.9) / len(decisions)
    else:
        avg_confidence = 0.0
        anchor_locality = 0.0

    return ScenarioResult(
        scenario_id=5,
        scenario_name="kv_aware_routing",
        duration_ms=duration,
        tokens_processed=len(anchor_hashes),
        vram_peak_gb=0.1,
        throughput_tps=len(anchor_hashes) / (duration / 1000),
        v4=V4Metrics(anchor_locality_score=anchor_locality, router_confidence_avg=avg_confidence),
    )
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
async def scenario_6_lmcache_bridge_save_load() -> ScenarioResult:
    """Scenario 6: LMCacheConnectorV1 save/load hooks.

    Exercises 100 save/load round-trips through the bridge with no LMCache
    backend attached, verifying graceful degradation.
    """
    bridge = LMCacheConnectorV1(enable_offset_hints=True, enable_cla_metadata=True)

    # Fix: use an explicit check instead of `assert`, which is stripped
    # under `python -O` and would silently skip this validation.
    if bridge.is_active():
        raise RuntimeError(
            "LMCache client unexpectedly active; expected graceful degradation"
        )

    metadata = {
        "anchor_hash": "test_anchor",
        "agent_id": "agent_1",
        "token_length": 100,
        "cla_group": 2,
        "offset_hint": [1.0, 2.0, 3.0],
    }

    start = time.perf_counter()
    for _ in range(100):
        await bridge.on_save_kv_layer("block_0", None, metadata)
        # Load result is intentionally discarded; only latency is measured.
        await bridge.on_load_kv_layer("block_0", metadata)
    duration = (time.perf_counter() - start) * 1000

    stats = bridge.get_stats()

    return ScenarioResult(
        scenario_id=6,
        scenario_name="lmcache_bridge_save_load",
        duration_ms=duration,
        tokens_processed=100,
        vram_peak_gb=0.05,
        throughput_tps=100 / (duration / 1000),
        v4=V4Metrics(lmcache_bridge_active=stats["active"]),
    )
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
async def scenario_7_atom_plugin_hooks() -> ScenarioResult:
    """Scenario 7: vLLMAtomPlugin pre/post attention hooks.

    Initializes the plugin and times 50 pre/post attention hook pairs.
    """
    config = ATOMConfig(
        enable_quantization=True,
        enable_anchor_routing=True,
        enable_cla_injection=True,
    )
    plugin = vLLMAtomPlugin(config)
    plugin.initialize("worker_0", {})

    block_ids = [f"b_{i}" for i in range(16)]
    token_ids = [101, 2003, 1996, 3007] * 4

    start = time.perf_counter()
    for _ in range(50):
        # Fix: hook results are intentionally discarded (the original bound
        # them to unused locals); only hook latency is being measured.
        plugin.pre_attention_hook(block_ids, token_ids, layer_idx=0)
        plugin.post_attention_hook(block_ids, [], layer_idx=0)
    duration = (time.perf_counter() - start) * 1000

    stats = plugin.get_stats()

    return ScenarioResult(
        scenario_id=7,
        scenario_name="atom_plugin_hooks",
        duration_ms=duration,
        tokens_processed=len(token_ids) * 50,
        vram_peak_gb=0.1,
        throughput_tps=(len(token_ids) * 50) / (duration / 1000),
        v4=V4Metrics(atom_plugin_initialized=stats["initialized"]),
    )
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
async def scenario_8_pbkv_prediction() -> ScenarioResult:
    """Scenario 8: PBKVPredictor log + predict.

    Logs 20 workflow steps, times 50 next-agent predictions, and reports
    the mean prediction confidence as an accuracy proxy.
    """
    predictor = PBKVPredictor(log_dir="/tmp/.pbkv_test_logs", max_history_steps=100)

    # Seed the predictor's history with a cyclic agent/anchor pattern.
    for step in range(20):
        await predictor.log_workflow_step(
            step_idx=step,
            agent_id=f"agent_{step % 3}",
            anchor_hash=f"anchor_{step % 5}",
            token_length=100 + step,
            cla_group=step % 4,
        )

    prediction_rounds = 50
    predictions = []
    start = time.perf_counter()
    for _ in range(prediction_rounds):
        predictions.append(
            await predictor.predict_next_agents("agent_0", current_step=10, num_predictions=3)
        )
    duration = (time.perf_counter() - start) * 1000

    mean_confidence = sum(p.confidence for p in predictions) / len(predictions)

    # Exercised for side effects; result not folded into the metrics.
    await predictor.get_prefetch_candidates("agent_0", step=10)

    operations = 20 + prediction_rounds
    return ScenarioResult(
        scenario_id=8,
        scenario_name="pbkv_prediction",
        duration_ms=duration,
        tokens_processed=operations,
        vram_peak_gb=0.05,
        throughput_tps=operations / (duration / 1000),
        v4=V4Metrics(pbkv_accuracy=mean_confidence),
    )
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
async def scenario_9_workflow_aware_eviction() -> ScenarioResult:
    """Scenario 9: _pressure_to_mode WORKFLOW_AWARE at high pressure.

    Builds a 3-step graph and checks that pressure >= 0.96 with a step
    graph selects WORKFLOW_AWARE eviction.
    """
    # Fix: use the module-level AgentStepGraph import directly instead of
    # the original's redundant local re-import under an alias.
    graph = AgentStepGraph()
    graph.add_step(AgentStep(agent_id="a", step_index=0))
    graph.add_step(AgentStep(agent_id="b", step_index=1, depends_on=["a"]))
    graph.add_step(AgentStep(agent_id="c", step_index=2, depends_on=["b"]))

    start = time.perf_counter()
    modes = []
    for _ in range(100):
        # WORKFLOW_AWARE is expected at pressure >= 0.96 when a step graph
        # is supplied.
        modes.append(VRAMAwareCache._pressure_to_mode(0.97, graph))
    duration = (time.perf_counter() - start) * 1000

    workflow_aware_count = sum(1 for m in modes if m == EvictionMode.WORKFLOW_AWARE)

    return ScenarioResult(
        scenario_id=9,
        scenario_name="workflow_aware_eviction",
        duration_ms=duration,
        tokens_processed=100,
        vram_peak_gb=0.1,
        throughput_tps=100 / (duration / 1000),
        v4=V4Metrics(prefetch_hit_rate=workflow_aware_count / 100.0),
    )
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
async def scenario_10_embedding_engine_encoding() -> ScenarioResult:
    """Scenario 10: EmbeddingEngine encode_batch + simhash.

    Encodes 10 token sequences of increasing length for 20 rounds and
    simhashes each sequence per round.
    """
    engine = await EmbeddingEngine.get_instance()

    base = [101, 2003, 1996, 3007, 102]
    sequences = [base * (length + 1) for length in range(10)]

    rounds = 20
    start = time.perf_counter()
    for _ in range(rounds):
        batch = tokens_to_text_batch(sequences)
        embeddings = await engine.encode_batch(batch)
        hashes = [await engine.simhash(seq) for seq in sequences]
    duration = (time.perf_counter() - start) * 1000

    total_tokens = rounds * sum(len(seq) for seq in sequences)

    return ScenarioResult(
        scenario_id=10,
        scenario_name="embedding_engine_encoding",
        duration_ms=duration,
        tokens_processed=total_tokens,
        vram_peak_gb=0.1,
        throughput_tps=total_tokens / (duration / 1000),
        v4=V4Metrics(anchor_pool_hit_rate=1.0),
    )
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
async def run_all_scenarios() -> list[ScenarioResult]:
    """Run all 10 benchmark scenarios in order.

    A scenario that raises is reported as FAILED and recorded as a
    zeroed-out ScenarioResult so the summary always covers all 10 slots.
    """
    scenario_funcs = [
        scenario_1_anchor_pool_resolution,
        scenario_2_cla_metadata_layer,
        scenario_3_rotate_kv_quantization,
        scenario_4_step_graph_execution,
        scenario_5_kv_aware_routing,
        scenario_6_lmcache_bridge_save_load,
        scenario_7_atom_plugin_hooks,
        scenario_8_pbkv_prediction,
        scenario_9_workflow_aware_eviction,
        scenario_10_embedding_engine_encoding,
    ]

    results: list[ScenarioResult] = []
    for idx, (spec, func) in enumerate(zip(SCENARIOS, scenario_funcs)):
        print(f" Scenario {idx+1}/10: {spec['name']}...", end=" ")
        try:
            outcome = await func()
        except Exception as exc:
            # Keep going: one failing scenario must not abort the benchmark.
            print(f"FAILED: {exc}")
            results.append(ScenarioResult(
                scenario_id=idx+1,
                scenario_name=spec['name'],
                duration_ms=0, tokens_processed=0, vram_peak_gb=0, throughput_tps=0,
            ))
        else:
            results.append(outcome)
            print(f"OK ({outcome.duration_ms:.2f}ms, {outcome.throughput_tps:.0f} tok/s)")

    return results
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
def print_summary(results: list[ScenarioResult]) -> None:
    """Print the per-scenario timing table followed by the V4.0 metric dump."""
    banner = "=" * 80
    rule = "-" * 80

    print("\n" + banner)
    print("CONTEXTFORGE V4.0 BENCHMARK SUMMARY")
    print(banner)
    print(f"{'#':<3} {'Scenario':<35} {'Time(ms)':<10} {'TPS':<12} {'VRAM(GB)':<10}")
    print(rule)

    total_vram = 0.0
    for entry in results:
        print(f"{entry.scenario_id:<3} {entry.scenario_name:<35} {entry.duration_ms:<10.2f} {entry.throughput_tps:<12.0f} {entry.vram_peak_gb:<10.2f}")
        total_vram += entry.vram_peak_gb

    print(rule)
    print(f"{'TOTAL':<38} {'':<10} {'':<12} {total_vram:<10.2f}")

    print("\n" + banner)
    print("V4.0 NEW METRICS")
    print(banner)
    for entry in results:
        metrics = entry.v4
        print(f"\n{entry.scenario_name}:")
        print(f" anchor_pool_hit_rate: {metrics.anchor_pool_hit_rate:.3f}")
        print(f" cla_vram_reduction_pct: {metrics.cla_vram_reduction_pct:.2f}%")
        print(f" quantization_active: {metrics.quantization_active}")
        print(f" rotate_kv_blocks: {metrics.rotate_kv_blocks}")
        print(f" prefetch_hit_rate: {metrics.prefetch_hit_rate:.3f}")
        print(f" pbkv_accuracy: {metrics.pbkv_accuracy:.3f}")
        print(f" anchor_locality_score: {metrics.anchor_locality_score:.3f}")
        print(f" router_confidence_avg: {metrics.router_confidence_avg:.3f}")
        print(f" lmcache_bridge_active: {metrics.lmcache_bridge_active}")
        print(f" atom_plugin_init: {metrics.atom_plugin_initialized}")
|
| 475 |
+
|
| 476 |
+
async def main():
    """Run the V4.0 benchmark suite and persist results as JSON.

    Prints a header, runs all scenarios, prints the summary, and writes
    a structured results file next to this script.
    """
    # Local import keeps this fix self-contained in the entry point.
    from pathlib import Path

    print("\n" + "=" * 80)
    print("CONTEXTFORGE V4.0 BENCHMARK")
    print("=" * 80)
    print(f"Date: {datetime.now().isoformat()}")
    print(f"Scenarios: {len(SCENARIOS)}")
    # Fix: plain string — the original used an f-string with no placeholders.
    print("INVARIANT 10: pre-RoPE quantization only\n")

    results = await run_all_scenarios()
    print_summary(results)

    output = {
        "timestamp": datetime.now().isoformat(),
        "version": "4.0",
        "scenarios": [
            {
                "id": r.scenario_id,
                "name": r.scenario_name,
                "duration_ms": r.duration_ms,
                "tokens_processed": r.tokens_processed,
                "vram_peak_gb": r.vram_peak_gb,
                "throughput_tps": r.throughput_tps,
                "v4_metrics": {
                    "anchor_pool_hit_rate": r.v4.anchor_pool_hit_rate,
                    "cla_vram_reduction_pct": r.v4.cla_vram_reduction_pct,
                    "quantization_active": r.v4.quantization_active,
                    "rotate_kv_blocks": r.v4.rotate_kv_blocks,
                    "prefetch_hit_rate": r.v4.prefetch_hit_rate,
                    "pbkv_accuracy": r.v4.pbkv_accuracy,
                    "anchor_locality_score": r.v4.anchor_locality_score,
                    "router_confidence_avg": r.v4.router_confidence_avg,
                    "lmcache_bridge_active": r.v4.lmcache_bridge_active,
                    "atom_plugin_initialized": r.v4.atom_plugin_initialized,
                },
            }
            for r in results
        ],
    }

    # Fix: write next to this script instead of a hard-coded absolute path
    # under one developer's home directory (non-portable).
    output_path = Path(__file__).resolve().parent / "benchmark_v4_results.json"
    with open(output_path, "w") as f:
        json.dump(output, f, indent=2)

    print(f"\nResults saved to: {output_path}")
    print("=" * 80 + "\n")
|
| 521 |
+
|
| 522 |
+
|
| 523 |
+
# Script entry point: drive the async benchmark with asyncio.run.
if __name__ == "__main__":
    asyncio.run(main())
|
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for vLLMAtomPlugin — TASK-008."""
|
| 2 |
+
import pytest
|
| 3 |
+
from contextforge.serving.atom_plugin import vLLMAtomPlugin, ATOMConfig, PreAttentionHook, PostAttentionHook
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class TestATOMConfig:
    """Tests for ATOMConfig."""

    def test_atom_config_defaults(self):
        """ATOMConfig has sensible defaults."""
        config = ATOMConfig()
        # Fix: assert truthiness directly instead of `== True` (PEP 8 / E712).
        assert config.enable_quantization
        assert config.enable_anchor_routing
        assert config.enable_cla_injection
        assert config.quantization_mode == "rotate_kv"
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class TestvLLMAtomPlugin:
    """Tests for vLLMAtomPlugin."""

    # Fix throughout: boolean results asserted directly / with `not`
    # instead of `== True` / `== False` (PEP 8 / E712).

    def test_plugin_initialization(self):
        """Plugin initializes with ATOMConfig and starts uninitialized."""
        config = ATOMConfig()
        plugin = vLLMAtomPlugin(config)
        assert plugin._config is config
        assert not plugin.is_initialized()

    def test_initialize_sets_worker_id(self):
        """initialize() sets worker_id and marks the plugin initialized."""
        config = ATOMConfig()
        plugin = vLLMAtomPlugin(config)
        plugin.initialize("worker_0", {})
        assert plugin.is_initialized()
        stats = plugin.get_stats()
        assert stats["worker_id"] == "worker_0"
        assert stats["initialized"]

    def test_pre_attention_hook_returns_dict(self):
        """pre_attention_hook returns a metadata dict with pre-RoPE flag."""
        config = ATOMConfig(enable_quantization=True)
        hook = PreAttentionHook(config)
        result = hook(["b0", "b1"], [101, 2003], layer_idx=0)
        assert isinstance(result, dict)
        assert result["quantized"]
        assert result["pre_rope"]  # INVARIANT 10
        assert result["layer_idx"] == 0

    def test_post_attention_hook_returns_dict(self):
        """post_attention_hook returns a stats dict."""
        config = ATOMConfig()
        hook = PostAttentionHook(config)
        result = hook(["b0", "b1"], [], layer_idx=0)
        assert isinstance(result, dict)
        assert result["processed_blocks"] == 2
        assert result["layer_idx"] == 0

    def test_plugin_pre_attention_hook_property(self):
        """Plugin exposes pre_attention_hook as a callable attribute."""
        config = ATOMConfig()
        plugin = vLLMAtomPlugin(config)
        assert hasattr(plugin, "pre_attention_hook")
        assert callable(plugin.pre_attention_hook)

    def test_plugin_post_attention_hook_property(self):
        """Plugin exposes post_attention_hook as a callable attribute."""
        config = ATOMConfig()
        plugin = vLLMAtomPlugin(config)
        assert hasattr(plugin, "post_attention_hook")
        assert callable(plugin.post_attention_hook)

    def test_get_stats_returns_config_and_state(self):
        """get_stats returns configuration and runtime state."""
        config = ATOMConfig(
            enable_quantization=True,
            enable_anchor_routing=False,
            enable_cla_injection=True,
            quantization_mode="rotate_kv",
        )
        plugin = vLLMAtomPlugin(config)
        plugin.initialize("worker_test", {})

        stats = plugin.get_stats()
        assert stats["initialized"]
        assert stats["worker_id"] == "worker_test"
        assert stats["config"]["enable_quantization"]
        assert stats["config"]["quantization_mode"] == "rotate_kv"
|
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for CLAMetadataLayer — TASK-004."""
|
| 2 |
+
import pytest
|
| 3 |
+
from contextforge.kv_offset.cla_metadata import CLAMetadataLayer, CLAGroupConfig, CLAHint, NON_THOUGHT_ROLES
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class TestCLAMetadataLayer:
    """Behavioral checks for the CLA metadata layer."""

    def test_non_thought_roles_frozenset(self):
        """NON_THOUGHT_ROLES is an immutable frozenset with expected members."""
        assert isinstance(NON_THOUGHT_ROLES, frozenset)
        # Retrieval/summarization agents are non-thinking; critics think.
        assert "retriever" in NON_THOUGHT_ROLES
        assert "summarizer" in NON_THOUGHT_ROLES
        assert "critic" not in NON_THOUGHT_ROLES

    def test_cla_group_config_defaults(self):
        """A bare CLAGroupConfig ships with the documented defaults."""
        cfg = CLAGroupConfig()
        assert cfg.group_size == 2
        assert cfg.sharing_direction == "upper"
        assert cfg.thinking_mode_bypass == True

    @pytest.mark.asyncio
    async def test_compute_layer_groups_upper_direction(self):
        """With "upper" sharing, each group borrows KV from a higher layer."""
        cfg = CLAGroupConfig(group_size=2, sharing_direction="upper", min_layer=0, max_layer=64)
        meta = CLAMetadataLayer(cfg)
        groups = meta.compute_layer_groups(model_layer_count=32, agent_role="retriever")
        assert len(groups) > 0
        for start, shared in groups:
            # Each group is (start, shared_kv_layer); the shared layer sits above start.
            assert shared > start

    @pytest.mark.asyncio
    async def test_compute_layer_groups_non_thinking_only(self):
        """Grouping applies to non-thinking roles and is bypassed for thinkers."""
        meta = CLAMetadataLayer(CLAGroupConfig(group_size=2, thinking_mode_bypass=True))
        non_thinking = meta.compute_layer_groups(model_layer_count=32, agent_role="retriever")
        assert len(non_thinking) > 0  # retriever is non-thinking
        thinking = meta.compute_layer_groups(model_layer_count=32, agent_role="critic")
        assert len(thinking) == 0  # critic is a thinking agent

    def test_emit_hint_returns_cla_hint(self):
        """emit_hint produces a populated CLAHint for a non-thinking agent."""
        meta = CLAMetadataLayer(CLAGroupConfig(group_size=2))
        hint = meta.emit_hint(
            agent_id="agent1",
            model_id="Qwen3.6-35B-A22B",
            is_thinking_mode=False,
            model_layer_count=32,
            agent_role="retriever",
        )
        assert isinstance(hint, CLAHint)
        assert hint.agent_id == "agent1"
        assert hint.model_id == "Qwen3.6-35B-A22B"
        assert hint.is_thinking_mode == False
        assert len(hint.layer_groups) > 0

    def test_emit_hint_thinking_mode_bypass(self):
        """Thinking mode with bypass enabled yields an empty, zero-savings hint."""
        meta = CLAMetadataLayer(CLAGroupConfig(group_size=2, thinking_mode_bypass=True))
        hint = meta.emit_hint(
            agent_id="agent1",
            model_id="Qwen3.6-35B-A22B",
            is_thinking_mode=True,
            model_layer_count=32,
            agent_role="critic",
        )
        assert len(hint.layer_groups) == 0
        assert hint.estimated_vram_reduction_pct == 0.0
        assert hint.is_thinking_mode == True

    def test_estimated_vram_reduction(self):
        """group_size=2 halves the KV footprint: (2 - 1) / 2 == 0.5."""
        meta = CLAMetadataLayer(CLAGroupConfig(group_size=2))
        reduction = meta.estimated_vram_reduction([(0, 1), (2, 3), (4, 5)])
        assert reduction == 0.5
|
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for EmbeddingEngine — TASK-001."""
|
| 2 |
+
import pytest
|
| 3 |
+
import numpy as np
|
| 4 |
+
from contextforge.embeddings.embedding_engine import EmbeddingEngine
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@pytest.fixture
async def engine():
    """Get EmbeddingEngine singleton."""
    # NOTE(review): async fixture declared with plain @pytest.fixture — under
    # pytest-asyncio "strict" mode this injects an un-awaited coroutine instead
    # of the engine; confirm the suite runs in "auto" mode, or switch to
    # @pytest_asyncio.fixture.
    # dim=512 and use_onnx=False are relied on by the shape assertions below.
    return await EmbeddingEngine.get_instance(dim=512, use_onnx=False)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class TestEmbeddingEngine:
    """Exercises the EmbeddingEngine singleton, encoding, and simhash paths."""

    @pytest.mark.asyncio
    async def test_get_instance_returns_singleton(self, engine):
        """Repeated get_instance() calls hand back one shared object."""
        second = await EmbeddingEngine.get_instance()
        assert engine is second

    @pytest.mark.asyncio
    async def test_encode_returns_normalized_vector(self, engine):
        """encode() yields a unit-length (L2-normalized) 512-dim vector."""
        vec = await engine.encode("test prompt")
        assert isinstance(vec, np.ndarray)
        assert vec.shape[0] == 512  # fixture requested dim=512
        assert abs(np.linalg.norm(vec) - 1.0) < 1e-6

    @pytest.mark.asyncio
    async def test_encode_batch_returns_list(self, engine):
        """encode_batch() maps N texts to a list of N 512-dim arrays."""
        batch = await engine.encode_batch(["prompt one", "prompt two", "prompt three"])
        assert isinstance(batch, list)
        assert len(batch) == 3
        for vec in batch:
            assert isinstance(vec, np.ndarray)
            assert vec.shape[0] == 512

    @pytest.mark.asyncio
    async def test_simhash_returns_int(self, engine):
        """simhash() produces a non-negative integer fingerprint."""
        fingerprint = await engine.simhash([101, 2003, 1996, 3007, 102])
        assert isinstance(fingerprint, int)
        assert fingerprint >= 0

    @pytest.mark.asyncio
    async def test_simhash_deterministic(self, engine):
        """The same token sequence always hashes to the same value."""
        tokens = [101, 2003, 1996, 3007, 102]
        first = await engine.simhash(tokens)
        second = await engine.simhash(tokens)
        assert first == second

    @pytest.mark.asyncio
    async def test_simhash_different_for_different_inputs(self, engine):
        """Distinct token sequences hash to distinct values."""
        h_a = await engine.simhash([101, 2003, 1996])
        h_b = await engine.simhash([101, 3007, 102])
        assert h_a != h_b

    @pytest.mark.asyncio
    async def test_encode_caching(self, engine):
        """Encoding the same text twice returns equal (cached) vectors."""
        first = await engine.encode("shared system prompt")
        second = await engine.encode("shared system prompt")
        assert np.allclose(first, second)
|
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for KVAwareRouter — TASK-009."""
|
| 2 |
+
import pytest
|
| 3 |
+
from contextforge.routing.kv_aware_router import KVAwareRouter, RouteDecision, WorkerState
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class TestKVAwareRouter:
    """Routing-mesh behavior: registration, locality, balancing, broadcast."""

    def test_register_worker(self):
        """Registering a worker is reflected in the router stats."""
        mesh = KVAwareRouter(num_workers=2)
        mesh.register_worker("worker_0")
        assert mesh.get_stats()["num_workers"] == 1

    def test_get_worker_for_anchor_unknown(self):
        """An anchor never seen before maps to no worker."""
        assert KVAwareRouter().get_worker_for_anchor("unknown_anchor") is None

    @pytest.mark.asyncio
    async def test_select_worker_returns_route_decision(self):
        """select_worker() hands back a RouteDecision carrying the pre-RoPE flag."""
        mesh = KVAwareRouter(num_workers=2)
        mesh.register_worker("worker_0")
        mesh.register_worker("worker_1")

        decision = await mesh.select_worker("anchor_hash", cla_group=1)
        assert isinstance(decision, RouteDecision)
        assert decision.anchor_hash == "anchor_hash"
        # INVARIANT 10: routed KV is always pre-RoPE.
        assert decision.pre_rope == True

    @pytest.mark.asyncio
    async def test_select_worker_anchor_locality(self):
        """Repeated lookups of one anchor stick to a single worker."""
        mesh = KVAwareRouter(num_workers=2, enable_anchor_locality=True)
        mesh.register_worker("worker_0")
        mesh.register_worker("worker_1")

        first = await mesh.select_worker("anchor_x", cla_group=1)
        second = await mesh.select_worker("anchor_x", cla_group=1)
        assert first.target_worker_id == second.target_worker_id

    @pytest.mark.asyncio
    async def test_select_worker_load_balancing(self):
        """A fresh anchor still lands on one of the registered workers."""
        mesh = KVAwareRouter(num_workers=3)
        for idx in range(3):
            mesh.register_worker(f"worker_{idx}")

        decision = await mesh.select_worker("new_anchor", cla_group=None)
        assert decision.target_worker_id.startswith("worker_")

    @pytest.mark.asyncio
    async def test_update_worker_state(self):
        """update_worker_state() feeds load numbers into the stats view."""
        mesh = KVAwareRouter(num_workers=2)
        mesh.register_worker("worker_0")

        await mesh.update_worker_state("worker_0", load=0.75, cla_group=2, workflow_step=5)
        assert mesh.get_stats()["worker_loads"]["worker_0"]["load"] == 0.75

    @pytest.mark.asyncio
    async def test_broadcast_new_blocks(self):
        """broadcast_new_blocks() records the anchor-to-worker mapping."""
        mesh = KVAwareRouter(num_workers=2)
        mesh.register_worker("worker_0")

        await mesh.broadcast_new_blocks("anchor_abc", ["b0", "b1"], "worker_0")
        assert mesh.get_worker_for_anchor("anchor_abc") == "worker_0"

    def test_get_stats_returns_worker_states(self):
        """get_stats() exposes a per-worker load table."""
        mesh = KVAwareRouter(num_workers=2)
        mesh.register_worker("worker_0")
        mesh.register_worker("worker_1")

        table = mesh.get_stats()
        assert "worker_loads" in table
        assert "worker_0" in table["worker_loads"]
        assert "worker_1" in table["worker_loads"]
|
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for LMCacheConnectorV1 — TASK-007."""
|
| 2 |
+
import pytest
|
| 3 |
+
from contextforge.serving.lmcache_bridge import LMCacheConnectorV1, LMCacheMeta
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class TestLMCacheConnectorV1:
    """LMCache bridge: activation, prefix hints, and graceful no-op hooks."""

    def test_lmcache_meta_defaults(self):
        """INVARIANT 10: metadata defaults to pre-RoPE KV."""
        assert LMCacheMeta().pre_rope == True

    def test_is_active_without_client(self):
        """With no client wired in, the bridge reports inactive."""
        assert LMCacheConnectorV1(lmcache_client=None).is_active() == False

    def test_is_active_with_client(self):
        """Any non-None client activates the bridge."""
        assert LMCacheConnectorV1(lmcache_client=object()).is_active() == True

    def test_build_prefix_hint(self):
        """build_prefix_hint packs anchor, agent, token length, and pre-RoPE flag."""
        bridge = LMCacheConnectorV1()
        hint = bridge.build_prefix_hint(
            token_ids=[101, 2003, 1996],
            agent_id="agent_1",
            anchor_hash="anchor_abc",
        )
        assert hint["anchor_hash"] == "anchor_abc"
        assert hint["agent_id"] == "agent_1"
        assert hint["token_length"] == 3
        # INVARIANT 10: hints always describe pre-RoPE KV.
        assert hint["pre_rope"] == True

    @pytest.mark.asyncio
    async def test_on_save_kv_layer_noop_when_inactive(self):
        """Saving through an inactive bridge must not raise."""
        bridge = LMCacheConnectorV1(lmcache_client=None)
        await bridge.on_save_kv_layer("block_0", None, {"anchor_hash": "test"})

    @pytest.mark.asyncio
    async def test_on_load_kv_layer_returns_none_when_inactive(self):
        """Loading through an inactive bridge yields None."""
        bridge = LMCacheConnectorV1(lmcache_client=None)
        loaded = await bridge.on_load_kv_layer("block_0", {"offset_hint": [1.0, 2.0]})
        assert loaded is None

    def test_get_stats_returns_dict(self):
        """get_stats reflects the feature toggles passed at construction."""
        bridge = LMCacheConnectorV1(enable_offset_hints=True, enable_cla_metadata=False)
        stats = bridge.get_stats()
        assert isinstance(stats, dict)
        assert stats["active"] == False
        assert stats["offset_hints_enabled"] == True
        assert stats["cla_metadata_enabled"] == False
|
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for PBKVPredictor — TASK-013."""
|
| 2 |
+
import pytest
|
| 3 |
+
import json
|
| 4 |
+
import tempfile
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from contextforge.scheduling.pbkv_predictor import PBKVPredictor, WorkflowStepRecord, PredictionResult
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class TestPBKVPredictor:
    """PBKV predictor stub: logging, prediction, prefetch, and stats."""

    @pytest.mark.asyncio
    async def test_log_workflow_step(self):
        """Logged steps land in the in-memory history (and the JSONL log)."""
        with tempfile.TemporaryDirectory() as tmpdir:
            predictor = PBKVPredictor(log_dir=tmpdir, max_history_steps=10)
            await predictor.log_workflow_step(
                step_idx=0,
                agent_id="agent_retriever",
                anchor_hash="anchor_0",
                token_length=100,
                cla_group=1,
            )
            assert len(predictor._history) == 1
            assert predictor._history[0].agent_id == "agent_retriever"

    @pytest.mark.asyncio
    async def test_predict_next_agents_returns_prediction_result(self):
        """After some history, prediction returns a bounded-confidence result."""
        with tempfile.TemporaryDirectory() as tmpdir:
            predictor = PBKVPredictor(log_dir=tmpdir, max_history_steps=10)
            # Seed an alternating two-agent workflow.
            for step in range(5):
                await predictor.log_workflow_step(
                    step_idx=step,
                    agent_id=f"agent_{step % 2}",
                    anchor_hash=f"anchor_{step}",
                    token_length=100,
                    cla_group=step % 2,
                )

            outcome = await predictor.predict_next_agents("agent_0", current_step=3, num_predictions=2)
            assert isinstance(outcome, PredictionResult)
            assert isinstance(outcome.predicted_agents, list)
            assert 0.0 <= outcome.confidence <= 1.0

    @pytest.mark.asyncio
    async def test_predict_next_agents_empty_history(self):
        """With nothing logged, confidence collapses to zero."""
        with tempfile.TemporaryDirectory() as tmpdir:
            predictor = PBKVPredictor(log_dir=tmpdir, max_history_steps=10)
            outcome = await predictor.predict_next_agents("agent_0", current_step=0, num_predictions=3)
            assert isinstance(outcome, PredictionResult)
            assert outcome.confidence == 0.0

    @pytest.mark.asyncio
    async def test_get_prefetch_candidates(self):
        """get_prefetch_candidates() always yields a list."""
        with tempfile.TemporaryDirectory() as tmpdir:
            predictor = PBKVPredictor(log_dir=tmpdir, max_history_steps=10)
            for step in range(5):
                await predictor.log_workflow_step(
                    step_idx=step,
                    agent_id=f"agent_{step % 2}",
                    anchor_hash=f"anchor_{step}",
                    token_length=100,
                    cla_group=step % 2,
                )

            candidates = await predictor.get_prefetch_candidates("agent_0", step=3)
            assert isinstance(candidates, list)

    def test_workflow_step_record(self):
        """WorkflowStepRecord stores the fields it was given."""
        record = WorkflowStepRecord(
            step_idx=0,
            agent_id="test_agent",
            anchor_hash="anchor_x",
            token_length=100,
            cla_group=2,
        )
        assert record.step_idx == 0
        assert record.agent_id == "test_agent"
        assert record.cla_group == 2

    def test_prediction_result_defaults(self):
        """prefetch_block_ids defaults to an empty list."""
        outcome = PredictionResult(
            predicted_agents=["a1"],
            predicted_anchor_hashes=["h1"],
            confidence=0.5,
        )
        assert outcome.prefetch_block_ids == []
        assert outcome.confidence == 0.5

    def test_get_stats(self):
        """get_stats() reports history size, cap, and the log-file location."""
        with tempfile.TemporaryDirectory() as tmpdir:
            predictor = PBKVPredictor(log_dir=tmpdir, max_history_steps=50)
            stats = predictor.get_stats()
            assert stats["history_size"] == 0
            assert stats["max_history_steps"] == 50
            assert "_pbkv_logs" in stats["log_file"]
|
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for RotateKVQuantizer — TASK-005."""
|
| 2 |
+
import pytest
|
| 3 |
+
import numpy as np
|
| 4 |
+
from contextforge.quantization.rotate_kv import RotateKVQuantizer, RotateKVConfig, QuantizedKVBlock
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class TestRotateKVQuantizer:
    """RotateKV quantization tests (INVARIANT 10: pre-RoPE tensors only)."""

    def test_rotate_kv_config_defaults(self):
        """Default config: 4-bit, group size 64, 4 sink tokens."""
        cfg = RotateKVConfig()
        assert cfg.bits == 4
        assert cfg.group_size == 64
        assert cfg.sink_tokens == 4

    def test_quantized_kv_block_has_pre_rope_metadata(self):
        """A QuantizedKVBlock can be constructed and records its bit width."""
        # The block stores pre-RoPE tensors — the invariant this suite pins.
        qblock = QuantizedKVBlock(
            keys_int4=np.zeros((10, 8, 64), dtype=np.float32),
            values_int4=np.zeros((10, 8, 64), dtype=np.float32),
            keys_sink_fp16=np.zeros((4, 8, 128), dtype=np.float16),
            values_sink_fp16=np.zeros((4, 8, 128), dtype=np.float16),
            scales_k=np.ones((1, 8, 64), dtype=np.float32),
            zero_points_k=np.zeros((1, 8, 64), dtype=np.float32),
            scales_v=np.ones((1, 8, 128), dtype=np.float32),
            zero_points_v=np.zeros((1, 8, 128), dtype=np.float32),
            channel_order=np.arange(128, dtype=np.int32),
            positions=np.arange(14, dtype=np.float32),
            bits=4,
        )
        assert qblock.bits == 4

    @pytest.mark.asyncio
    async def test_quantize_pre_rope_returns_quantized_block(self):
        """quantize_pre_rope() yields a (QuantizedKVBlock, ndarray) pair (INVARIANT 10)."""
        quantizer = RotateKVQuantizer(RotateKVConfig(bits=4, group_size=64, sink_tokens=4))
        # Pre-RoPE layout: (batch=1, seq_len, num_heads, head_dim).
        keys = np.random.randn(1, 64, 8, 64).astype(np.float32)
        values = np.random.randn(1, 64, 8, 64).astype(np.float32)
        pos = np.arange(64, dtype=np.float32)

        out = quantizer.quantize_pre_rope(keys, values, pos)
        assert isinstance(out, tuple)
        qblock, remaining = out
        assert isinstance(qblock, QuantizedKVBlock)
        assert qblock.keys_int4.shape[0] > 0
        assert qblock.values_int4.shape[0] > 0

    @pytest.mark.asyncio
    async def test_quantize_pre_rope_sink_tokens_preserved(self):
        """The leading sink_tokens stay in FP16 rather than being quantized."""
        quantizer = RotateKVQuantizer(RotateKVConfig(bits=4, sink_tokens=4))
        keys = np.random.randn(1, 64, 8, 64).astype(np.float32)
        values = np.random.randn(1, 64, 8, 64).astype(np.float32)
        pos = np.arange(64, dtype=np.float32)

        qblock, _ = quantizer.quantize_pre_rope(keys, values, pos)
        assert qblock.keys_sink_fp16.shape == (1, 4, 8, 64)
        assert qblock.values_sink_fp16.shape == (1, 4, 8, 64)

    @pytest.mark.asyncio
    async def test_dequantize_returns_fp32_tensors(self):
        """A quantize → dequantize round trip hands back FP32 ndarrays."""
        quantizer = RotateKVQuantizer(RotateKVConfig(bits=4, group_size=64, sink_tokens=4))
        keys = np.random.randn(1, 64, 8, 64).astype(np.float32)
        values = np.random.randn(1, 64, 8, 64).astype(np.float32)
        pos = np.arange(64, dtype=np.float32)

        qblock, _ = quantizer.quantize_pre_rope(keys, values, pos)
        k_out, v_out = quantizer.dequantize(qblock)
        assert isinstance(k_out, np.ndarray)
        assert isinstance(v_out, np.ndarray)
        assert k_out.dtype == np.float32
        assert v_out.dtype == np.float32
|
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for AgentStepGraph — TASK-006."""
|
| 2 |
+
import pytest
|
| 3 |
+
import sys
|
| 4 |
+
from contextforge.scheduling.step_graph import AgentStepGraph, AgentStep
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class TestAgentStepGraph:
    """Workflow DAG tests: chaining, distances, prefetch, eviction, cycles."""

    @pytest.mark.asyncio
    async def test_add_step_returns_self_for_chaining(self):
        """add_step() is fluent: it hands the graph back."""
        graph = AgentStepGraph()
        assert graph.add_step(AgentStep(agent_id="a", step_index=0)) is graph

    @pytest.mark.asyncio
    async def test_compute_steps_to_execution_simple(self):
        """Distance to a downstream agent is non-negative."""
        graph = AgentStepGraph()
        graph.add_step(AgentStep(agent_id="retriever", step_index=0))
        graph.add_step(AgentStep(agent_id="summarizer", step_index=1, depends_on=["retriever"]))
        graph.add_step(AgentStep(agent_id="critic", step_index=2, depends_on=["summarizer"]))

        # "critic" sits two steps past "retriever" at the start of the run.
        assert graph.compute_steps_to_execution("critic", current_step=0) >= 0

    @pytest.mark.asyncio
    async def test_compute_steps_to_execution_unknown_agent(self):
        """An agent absent from the graph is infinitely far away (sys.maxsize)."""
        graph = AgentStepGraph()
        graph.add_step(AgentStep(agent_id="retriever", step_index=0))
        assert graph.compute_steps_to_execution("unknown_agent", current_step=0) == sys.maxsize

    @pytest.mark.asyncio
    async def test_get_prefetch_candidates(self):
        """Prefetch candidates come back as a list of upcoming agents."""
        graph = AgentStepGraph()
        graph.add_step(AgentStep(agent_id="retriever", step_index=0))
        graph.add_step(AgentStep(agent_id="summarizer", step_index=1, depends_on=["retriever"]))
        graph.add_step(AgentStep(agent_id="critic", step_index=2, depends_on=["summarizer"]))
        graph.add_step(AgentStep(agent_id="responder", step_index=3, depends_on=["critic"]))

        assert isinstance(graph.get_prefetch_candidates(current_step=0, lookahead=2), list)

    @pytest.mark.asyncio
    async def test_get_eviction_priority_order(self):
        """Eviction order runs farthest-first; the next agent to run comes last."""
        graph = AgentStepGraph()
        graph.add_step(AgentStep(agent_id="retriever", step_index=0))
        graph.add_step(AgentStep(agent_id="summarizer", step_index=1, depends_on=["retriever"]))
        graph.add_step(AgentStep(agent_id="critic", step_index=2, depends_on=["summarizer"]))

        order = graph.get_eviction_priority_order()
        assert isinstance(order, list)
        if len(order) >= 2:
            # "retriever" executes first, so it should be evicted last.
            assert order[-1] == "retriever"

    @pytest.mark.asyncio
    async def test_validate_dag_detects_cycle(self):
        """A mutual dependency (a <-> b) must raise ValueError."""
        graph = AgentStepGraph()
        graph.add_step(AgentStep(agent_id="a", step_index=0, depends_on=["b"]))
        graph.add_step(AgentStep(agent_id="b", step_index=1, depends_on=["a"]))
        with pytest.raises(ValueError):
            graph.validate_dag()

    @pytest.mark.asyncio
    async def test_validate_dag_accepts_valid_graph(self):
        """A linear three-step chain validates without raising."""
        graph = AgentStepGraph()
        graph.add_step(AgentStep(agent_id="retriever", step_index=0))
        graph.add_step(AgentStep(agent_id="summarizer", step_index=1, depends_on=["retriever"]))
        graph.add_step(AgentStep(agent_id="critic", step_index=2, depends_on=["summarizer"]))
        graph.validate_dag()
|