"""ContextRegistry v3.0 - Wired to LSH + FAISS + VRAMAwareCache.

Replaces the old Python-loop dedup and static TTLCache with:
- LSHTokenMatcher: SimHash on actual Qwen3 token IDs, PagedAttention block alignment
- FAISSContextIndex: O(log n) ANN search vs O(n) linear scan
- VRAMAwareCache: 5-mode LRU/LFU hybrid with VRAM-pressure-responsive eviction

Dependency injection - no hardcoded imports of stale modules.
"""
import asyncio
import hashlib
import logging
from dataclasses import dataclass, field
from typing import Any, Optional

from apohara_context_forge.dedup.faiss_index import FAISSContextIndex, FAISSMatch
from apohara_context_forge.dedup.lsh_engine import LSHTokenMatcher, TokenBlockMatch
from apohara_context_forge.embeddings.embedding_engine import EmbeddingEngine
from apohara_context_forge.kv_offset.anchor_pool import AnchorPool
from apohara_context_forge.metrics.prometheus_metrics import (
    cache_hits,
    cache_misses,
    cache_registry_size,
    cache_evictions_total,
)
from apohara_context_forge.models import ContextEntry, ContextMatch
from apohara_context_forge.registry.vram_aware_cache import VRAMAwareCache
from apohara_context_forge.token_counter import TokenCounter

logger = logging.getLogger(__name__)

# vLLM PagedAttention block size
VLLM_BLOCK_SIZE = 16


@dataclass
class SharedContextResult:
    """Result of get_shared_context() - contains reusable blocks with metadata."""
    agent_id: str
    shared_blocks: list[TokenBlockMatch]
    faiss_matches: list[FAISSMatch]
    total_tokens_saved: int
    reuse_confidence: float  # 0.0-1.0 weighted by hamming distance
    offset_hints: dict[str, list[float]] = field(default_factory=dict)  # agent_id -> offset vector


@dataclass
class RegisteredAgent:
    """Internal record of a registered agent."""
    agent_id: str
    system_prompt: str
    role_prompt: str
    token_count: int
    block_hashes: list[int]  # LSH block hashes for this agent


class ContextRegistry:
    """
    Production-grade context registry with LSH + FAISS + VRAM-aware cache.

    Usage:
        registry = ContextRegistry(
            lsh_matcher=LSHTokenMatcher(),
            vram_cache=VRAMAwareCache(max_token_budget=50_000_000),
            faiss_index=FAISSContextIndex(dim=512),
        )
        await registry.start()

        # Register agents with shared system prompt
        await registry.register_agent("agent1", system_prompt, "retriever role")
        await registry.register_agent("agent2", system_prompt, "summarizer role")

        # Query for reusable context across agents
        result = await registry.get_shared_context(["agent1", "agent2"])

        await registry.stop()

    Key design decisions:
    - Dependency injection for all core components (testable, swappable)
    - LSH operates on token IDs, not text - aligns to vLLM PagedAttention blocks
    - FAISS provides ANN candidates; LSH filters for actual token-level reuse
    - VRAMAwareCache manages eviction based on real GPU memory pressure
    """

    def __init__(
        self,
        lsh_matcher: Optional[LSHTokenMatcher] = None,
        vram_cache: Optional[VRAMAwareCache] = None,
        faiss_index: Optional[FAISSContextIndex] = None,
        token_counter: Optional[TokenCounter] = None,
        anchor_pool: Optional[AnchorPool] = None,
        vram_budget_tokens: int = 50_000_000,
        block_size: int = VLLM_BLOCK_SIZE,
        hamming_threshold: int = 8,
        faiss_nlist: int = 100,
        dedup: Any = None,
    ):
        # Dependency injection with lazy defaults
        self._lsh = lsh_matcher or LSHTokenMatcher(
            block_size=block_size,
            hamming_threshold=hamming_threshold,
        )
        self._vram_cache = vram_cache or VRAMAwareCache(max_token_budget=vram_budget_tokens)
        # FAISS index dim must match the EmbeddingEngine output dimension
        # (we instantiate EmbeddingEngine with dim=512 in register_agent).
        # A 384-dim default crashes faiss.IndexFlatIP.add() at runtime —
        # see the cascade of test_integration failures pre-fix.
        self._faiss = faiss_index or FAISSContextIndex(dim=512)
        self._token_counter = token_counter or TokenCounter.get()
        self._anchor_pool = anchor_pool or AnchorPool()
        self._embedding_engine: Optional[EmbeddingEngine] = None
        self._block_size = block_size

        # `dedup` is a hermetic-test escape hatch — when set, register() short-
        # circuits the LSH+FAISS+ANN heavy path and uses the provided engine
        # instead. The engine only needs `embed`, `similarity`,
        # `find_shared_prefix`, and `count_prefix_tokens` — see FakeDedupEngine
        # in tests/test_mcp_server.py for the contract.
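        # Illustrative sketch of that duck-typed contract (the method names come
        # from the note above; the signatures are assumptions for documentation
        # only — FakeDedupEngine in tests/test_mcp_server.py is authoritative):
        #
        #   class DedupEngineLike(Protocol):
        #       def embed(self, text: str): ...
        #       def similarity(self, a, b) -> float: ...
        #       def find_shared_prefix(self, a: str, b: str) -> str: ...
        #       def count_prefix_tokens(self, prefix: str) -> int: ...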
        self._dedup = dedup

        # Lightweight in-memory store for `register(agent_id, context)`. This
        # is independent from `register_agent(...)` (which exercises the full
        # KV-aware pipeline) — it backs the simple MCP /tools/register_context
        # endpoint and the test_full_flow scenario.
        self._simple_entries: dict[str, ContextEntry] = {}

        # Internal state
        self._agents: dict[str, RegisteredAgent] = {}
        self._system_prompt_hash: Optional[str] = None
        self._lock = asyncio.Lock()
        self._started = False

    async def start(self) -> None:
        """Start background VRAM monitor and cache."""
        if self._started:
            return
        await self._vram_cache.start()
        self._started = True
        logger.info("ContextRegistry started with LSH+FAISS+VRAM cache")

    async def stop(self) -> None:
        """Stop background monitoring and flush cache."""
        if not self._started:
            return
        await self._vram_cache.stop()
        self._started = False
        logger.info("ContextRegistry stopped")

    async def register_agent(
        self,
        agent_id: str,
        system_prompt: str,
        role_prompt: str,
    ) -> ContextEntry:
        """
        Register an agent with tokenization and LSH indexing.

        Args:
            agent_id: Unique agent identifier
            system_prompt: Shared system prompt (must be byte-identical across agents)
            role_prompt: Agent-specific role/instruction text

        Returns:
            ContextEntry with accurate token count
        """
        loop = asyncio.get_running_loop()

        # Tokenize full context
        full_context = f"{system_prompt}\n\n{role_prompt}"
        token_ids = await loop.run_in_executor(
            None, self._token_counter.encode, full_context
        )
        token_count = len(token_ids)

        # Index system prompt for LSH (critical for prefix caching)
        system_block_hashes = await self._lsh.index_prompt(
            f"{agent_id}:system",
            system_prompt
        )

        # Index full prompt for cross-agent dedup
        full_block_hashes = await self._lsh.index_prompt(
            agent_id,
            full_context
        )

        # Generate real embedding via EmbeddingEngine (replaces pseudo-embedding)
        if self._embedding_engine is None:
            self._embedding_engine = await EmbeddingEngine.get_instance(dim=512, use_onnx=True)
        embedding = await self._embedding_engine.encode(full_context)

        # Update AnchorPool — use embedding as kv_offset_approx until
        # LMCacheConnectorV1 bridge (TASK-007) provides real KV offset vectors
        await self._anchor_pool.update_pool(
            token_ids=token_ids,
            agent_id=agent_id,
            real_kv_offset=embedding.copy(),
            neighbor_prefix_offset=None,  # populated by TASK-007
        )

        # Store in VRAM-aware cache
        cache_key = f"context:{agent_id}"
        cache_value = {
            "system_prompt": system_prompt,
            "role_prompt": role_prompt,
            "full_context": full_context,
            "token_ids": token_ids,
        }
        stored = await self._vram_cache.set(
            cache_key,
            cache_value,
            token_count=token_count,
        )
        if not stored:
            logger.warning(f"VRAM cache blocked registration for {agent_id}")

        # Add to FAISS index for ANN search
        # Use real embedding from EmbeddingEngine (replaces pseudo-embedding)
        await self._faiss.add(agent_id, embedding.tolist())

        # Track registered agent
        async with self._lock:
            # Validate system prompt consistency (byte-identical for vLLM prefix caching)
            if self._system_prompt_hash is None:
                self._system_prompt_hash = self._sha256_prefix(system_prompt)
            else:
                incoming_hash = self._sha256_prefix(system_prompt)
                if incoming_hash != self._system_prompt_hash:
                    logger.warning(
                        f"Agent {agent_id} has DIFFERENT system prompt hash. "
                        f"vLLM prefix caching will NOT work. "
                        f"Expected {self._system_prompt_hash[:16]}, got {incoming_hash[:16]}"
                    )

            self._agents[agent_id] = RegisteredAgent(
                agent_id=agent_id,
                system_prompt=system_prompt,
                role_prompt=role_prompt,
                token_count=token_count,
                block_hashes=full_block_hashes,
            )

        logger.debug(f"Registered agent {agent_id}, tokens={token_count}, blocks={len(full_block_hashes)}")

        return ContextEntry(
            agent_id=agent_id,
            context=full_context,
            token_count=token_count,
            compressed_token_count=None,
            ttl_seconds=0,  # VRAM cache handles TTL
        )

    async def get_shared_context(
        self,
        agent_ids: list[str],
        target_agent_id: Optional[str] = None,
    ) -> list[SharedContextResult]:
        """
        Query for reusable context across multiple agents.

        Uses FAISS ANN to find candidate matches, then LSH to validate
        actual token-level reuse at PagedAttention block granularity.

        Args:
            agent_ids: Agents whose context to search
            target_agent_id: Optional target for offset hints

        Returns:
            List of SharedContextResult sorted by reuse confidence
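
        Example (illustrative sketch; `registry` as in the class docstring):
            results = await registry.get_shared_context(["agent1", "agent2"])
            best = results[0] if results else None
            # best.reuse_confidence lies in [0.0, 1.0]; total_tokens_saved is an
            # estimate at PagedAttention block granularity, not a measured saving.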
        """
        if len(agent_ids) < 2:
            return []

        # Gather all registered agents
        agents_to_search = []
        async with self._lock:
            for aid in agent_ids:
                if aid in self._agents:
                    agents_to_search.append(self._agents[aid])

        if not agents_to_search:
            return []

        results: list[SharedContextResult] = []

        # For each agent, find matches in other agents
        for agent in agents_to_search:
            # Get full context for LSH matching
            cache_key = f"context:{agent.agent_id}"
            cache_val = await self._vram_cache.get(cache_key)
            if not cache_val:
                continue

            full_context = cache_val["full_context"]
            system_prompt = cache_val["system_prompt"]

            # Find reusable blocks via LSH
            matches = await self._lsh.find_reusable_blocks(
                full_context,
                exclude_agent=agent.agent_id,
            )

            # Filter matches by hamming threshold and compute confidence
            valid_matches = []
            total_hamming = 0
            for match in matches:
                if match.hamming_distance <= self._lsh._hamming_threshold:
                    valid_matches.append(match)
                    total_hamming += match.hamming_distance

            if not valid_matches:
                cache_misses.labels(agent_id=agent.agent_id).inc()
                continue

            avg_hamming = total_hamming / len(valid_matches)
            reuse_confidence = 1.0 - (avg_hamming / self._lsh._hash_bits)
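            # Worked example (illustrative): with 64-bit SimHash signatures and
            # avg_hamming = 4, reuse_confidence = 1.0 - 4 / 64 = 0.9375.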

            # Get FAISS ANN candidates for the system prompt
            # Use real embedding from EmbeddingEngine (replaces pseudo-embedding)
            if self._embedding_engine is None:
                self._embedding_engine = await EmbeddingEngine.get_instance(dim=512, use_onnx=True)
            system_embedding = await self._embedding_engine.encode(system_prompt)
            faiss_matches = await self._faiss.search(
                system_embedding.tolist(),
                k=5,
                threshold=0.7,
            )

            # Estimate tokens saved: each validated LSH match covers one
            # PagedAttention block of `block_size` tokens.
            tokens_saved = len(valid_matches) * self._block_size

            # AnchorPool shareability prediction
            is_shareable = await self._anchor_pool.predict_shareable(
                token_ids=cache_val["token_ids"],
                target_agent_id=target_agent_id or agent.agent_id,
            )

            offset_vector = None
            if is_shareable:
                offset_result = await self._anchor_pool.approximate_offset(
                    token_ids=cache_val["token_ids"],
                    target_agent_id=target_agent_id or agent.agent_id,
                )
                if offset_result is not None:
                    offset_vector = offset_result.placeholder_offset

            # Populate offset_hints — this field was ALWAYS empty in V3
            result = SharedContextResult(
                agent_id=agent.agent_id,
                shared_blocks=valid_matches,
                faiss_matches=faiss_matches,
                total_tokens_saved=tokens_saved,
                reuse_confidence=reuse_confidence,
            )
            if offset_vector is not None:
                result.offset_hints[agent.agent_id] = offset_vector.tolist()

            results.append(result)

            cache_hits.labels(
                agent_id=agent.agent_id,
                segment_type="system_prompt",
            ).inc()

        # Sort by reuse confidence descending
        results.sort(key=lambda r: r.reuse_confidence, reverse=True)
        return results

    async def get_agent_context(self, agent_id: str) -> Optional[str]:
        """Get the full context for an agent."""
        cache_key = f"context:{agent_id}"
        cache_val = await self._vram_cache.get(cache_key)
        if cache_val:
            return cache_val["full_context"]
        return None

    async def register(self, agent_id: str, context: str) -> ContextEntry:
        """Lightweight register used by the MCP /tools/register_context endpoint.

        This is intentionally separate from `register_agent(...)`, which also
        indexes the system prompt for cross-agent KV reuse. The MCP endpoint
        deals with single opaque contexts, so we tokenize via TokenCounter,
        keep a `ContextEntry` in `_simple_entries`, and stop there.
        """
        from datetime import datetime as _dt, timedelta as _td, timezone as _tz

        loop = asyncio.get_running_loop()
        try:
            token_count = await loop.run_in_executor(
                None, self._token_counter.count, context
            )
        except Exception:
            token_count = max(1, len(context.split()))

        now = _dt.now(_tz.utc)
        entry = ContextEntry(
            agent_id=agent_id,
            context=context,
            token_count=token_count,
            created_at=now,
            expires_at=now + _td(seconds=300),
        )
        async with self._lock:
            self._simple_entries[agent_id] = entry
        return entry

    async def clear(self) -> None:
        """Drop all simple-register state. Called by the MCP server lifespan
        on shutdown so a fresh process starts from a clean registry. We do
        NOT touch LSH/FAISS here — those have their own lifecycle hooks."""
        async with self._lock:
            self._simple_entries.clear()

    async def clear_agent(self, agent_id: str) -> bool:
        """Clear an agent's context from all stores."""
        async with self._lock:
            if agent_id not in self._agents:
                return False

        # Remove from LSH
        await self._lsh.clear_agent(agent_id)
        await self._lsh.clear_agent(f"{agent_id}:system")

        # Remove from FAISS
        await self._faiss.remove(agent_id)

        # Remove from VRAM cache
        cache_key = f"context:{agent_id}"
        await self._vram_cache.delete(cache_key)

        # Remove from agents dict
        async with self._lock:
            del self._agents[agent_id]

        cache_evictions_total.labels(reason="manual").inc()
        return True

    async def get_all_agents(self) -> list[str]:
        """Get list of all registered agent IDs."""
        async with self._lock:
            return list(self._agents.keys())

    async def get_vram_mode(self) -> str:
        """Get current VRAM eviction mode."""
        return self._vram_cache.mode.value

    async def get_vram_pressure(self) -> float:
        """Get current VRAM pressure (0.0-1.0)."""
        return self._vram_cache._vram.get_pressure()

    @staticmethod
    def _sha256_prefix(text: str) -> str:
        """SHA256 of text for prefix validation."""
        return hashlib.sha256(text.encode()).hexdigest()

    @property
    def lsh_matcher(self) -> LSHTokenMatcher:
        """Direct access to LSH matcher for advanced queries."""
        return self._lsh

    @property
    def faiss_index(self) -> FAISSContextIndex:
        """Direct access to FAISS index for advanced queries."""
        return self._faiss

    @property
    def vram_cache(self) -> VRAMAwareCache:
        """Direct access to VRAM cache for advanced queries."""
        return self._vram_cache

    @property
    def registry_size(self) -> int:
        """Number of registered agents."""
        return len(self._agents)

    @property
    def is_started(self) -> bool:
        """Whether the registry is running."""
        return self._started