"""LSH Token-Level Matching Engine - IMPROVEMENT-001.
Token-level fuzzy matching using SimHash for KV cache block reuse.
Operates on actual token IDs from Qwen3 tokenizer, not word-level strings.
Aligns to vLLM PagedAttention block boundaries (default block_size=16).
Architecture:
Incoming prompt (text)
│
▼
Qwen3 Tokenizer ← Real token IDs, not word splits
│
▼
LSH Block Hashing ← SimHash on token blocks
│
▼
Block Alignment ← Align to PagedAttention blocks (16 tokens)
│
▼
Match Candidates ← Find blocks with hamming distance < threshold
│
▼
Reuse Decision → List of reusable block indices
Usage:
matcher = LSHTokenMatcher()
await matcher.index_prompt("agent1", "shared system prompt...")
matches = await matcher.find_reusable_blocks("new incoming prompt...")
"""
import asyncio
import hashlib
import logging
from dataclasses import dataclass
from typing import Optional
import numpy as np
from apohara_context_forge.token_counter import TokenCounter
logger = logging.getLogger(__name__)
# vLLM PagedAttention default block size, in tokens per KV-cache block.
# Prompt blocks shorter than this are skipped during indexing because a
# partial block carries no cache-reuse guarantee (see index_prompt below).
VLLM_BLOCK_SIZE = 16
@dataclass
class TokenBlockMatch:
    """One candidate match between a new-prompt block and the LSH index.

    Plain data record produced by ``LSHTokenMatcher.find_reusable_blocks``;
    results are sorted by ``hamming_distance`` (lower is better).
    """

    # Position of the block within the new prompt (0-based block ordinal).
    block_index: int
    # 64-bit SimHash of the cached block this one resembles.
    cached_block_hash: int
    # Number of differing fingerprint bits; 0 = identical fingerprint.
    hamming_distance: int
    # 0.0-1.0 confidence derived from the hamming distance.
    reuse_confidence: float
    # Agent that owns the cached block.
    cached_agent_id: str
class LSHTokenMatcher:
    """
    Token-level fuzzy matching using SimHash for KV cache block reuse.

    Operates on actual token IDs from the Qwen3 tokenizer.

    Key insight: vLLM PagedAttention shares KV cache for identical token
    blocks.  Two prompts with 95% SBERT similarity but different wording may
    share ZERO cache.  LSH finds actual token-level matches at block
    boundaries.

    Usage:
        matcher = LSHTokenMatcher()
        await matcher.index_prompt("agent1", system_prompt)
        matches = await matcher.find_reusable_blocks(new_prompt)
    """

    def __init__(
        self,
        block_size: int = VLLM_BLOCK_SIZE,
        hash_bits: int = 64,
        hamming_threshold: int = 8,  # <8 bits different = high confidence
    ):
        self._block_size = block_size
        self._hash_bits = hash_bits
        self._hamming_threshold = hamming_threshold
        self._token_counter = TokenCounter.get()
        # hash → list of (tokens, agent_id). A list (not a single tuple) so
        # that multiple agents sharing the same prefix do not overwrite each
        # other — the last writer would otherwise erase the earlier owners
        # and `find_reusable_blocks` would miss legitimate cross-agent reuse.
        self._block_store: dict[int, list[tuple[tuple[int, ...], str]]] = {}
        # agent_id → list of block hashes, used by clear_agent and re-index.
        self._agent_blocks: dict[str, list[int]] = {}
        # Guards _block_store and _agent_blocks.  Previously only stats() and
        # clear_agent() took the lock while index_prompt mutated the store and
        # find_reusable_blocks iterated it unlocked; all shared-state access
        # now goes through this lock.
        self._lock = asyncio.Lock()

    @staticmethod
    def _hamming(a: int, b: int) -> int:
        """Compute Hamming distance between two 64-bit integers."""
        return bin(a ^ b).count('1')

    def _blockify(self, token_ids: list[int]) -> list[tuple[int, ...]]:
        """Split token IDs into full blocks aligned to vLLM boundaries.

        Partial trailing blocks are dropped: there is no cache-reuse
        guarantee for blocks shorter than block_size.
        """
        size = self._block_size
        return [
            tuple(token_ids[i:i + size])
            for i in range(0, len(token_ids) - size + 1, size)
        ]

    async def _encode(self, text: str) -> list[int]:
        """Tokenize *text* off the event loop (tokenization is CPU-bound)."""
        # get_running_loop() instead of the deprecated get_event_loop(),
        # which warns (and may create a loop) when called from a coroutine.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None, self._token_counter.encode, text
        )

    def _remove_agent_blocks_locked(self, agent_id: str) -> int:
        """Drop all of *agent_id*'s entries; caller must hold self._lock.

        Blocks shared with other agents are kept for those agents.

        Returns:
            Number of block hashes the agent had indexed.
        """
        hashes = self._agent_blocks.pop(agent_id, [])
        for h in hashes:
            owners = self._block_store.get(h)
            if not owners:
                continue
            remaining = [
                (toks, aid) for (toks, aid) in owners if aid != agent_id
            ]
            if remaining:
                self._block_store[h] = remaining
            else:
                del self._block_store[h]
        return len(hashes)

    async def index_prompt(
        self,
        agent_id: str,
        text: str,
    ) -> list[int]:
        """
        Tokenize, blockify, and index a prompt for future reuse.

        Re-indexing an agent first removes its previous blocks; the old code
        only overwrote the agent's hash list, so blocks from an earlier
        prompt leaked in the store and could never be cleared.

        Args:
            agent_id: Owner of this prompt
            text: Full prompt text

        Returns:
            List of block hashes that were indexed, in prompt order
            (repeated blocks yield repeated hashes)
        """
        token_ids = await self._encode(text)
        # Hash outside the lock — pure CPU work on local data.
        hashed = [
            (self._simhash_block(block), block)
            for block in self._blockify(token_ids)
        ]
        async with self._lock:
            self._remove_agent_blocks_locked(agent_id)
            hashes: list[int] = []
            for block_hash, block in hashed:
                owners = self._block_store.setdefault(block_hash, [])
                # Avoid duplicating the same owner when the prompt repeats a
                # block (idempotent within one prompt and across re-indexes).
                if not any(aid == agent_id for _, aid in owners):
                    owners.append((block, agent_id))
                hashes.append(block_hash)
            self._agent_blocks[agent_id] = hashes
        # Lazy %-args: no string formatting unless DEBUG is enabled.
        logger.debug("Indexed %d blocks for agent %s", len(hashes), agent_id)
        return hashes

    async def find_reusable_blocks(
        self,
        text: str,
        exclude_agent: Optional[str] = None,
    ) -> list[TokenBlockMatch]:
        """
        Find cached blocks that can be reused for this prompt.

        Args:
            text: New prompt text
            exclude_agent: Optionally exclude blocks from a specific agent

        Returns:
            List of TokenBlockMatch sorted by hamming distance (best first)
        """
        token_ids = await self._encode(text)
        # Exclusion matches both the bare agent_id ("agent1") and any
        # role-suffixed variant ("agent1:system") because the registry
        # indexes the system prompt under "<agent_id>:system" — without
        # this an agent finds matches against its own system blocks and
        # the cross-agent dedup path looks artificially busy.
        # Computed once here; it was loop-invariant in the old per-block loop.
        exclude_prefix = f"{exclude_agent}:" if exclude_agent else None
        new_hashes = [
            (block_index, self._simhash_block(block))
            for block_index, block in enumerate(self._blockify(token_ids))
        ]
        # Snapshot the store under the lock so a concurrent index_prompt /
        # clear_agent cannot mutate the dict while we scan it.
        async with self._lock:
            store = list(self._block_store.items())
        matches: list[TokenBlockMatch] = []
        for block_index, new_hash in new_hashes:
            # Each stored hash may have multiple owners (agents that all
            # indexed the same block).
            for cached_hash, owners in store:
                hd = self._hamming(new_hash, cached_hash)
                if hd > self._hamming_threshold:
                    continue
                confidence = 1.0 - (hd / self._hash_bits)
                for _cached_tokens, agent_id in owners:
                    if exclude_agent and (
                        agent_id == exclude_agent
                        or agent_id.startswith(exclude_prefix)
                    ):
                        continue
                    matches.append(TokenBlockMatch(
                        block_index=block_index,
                        cached_block_hash=cached_hash,
                        hamming_distance=hd,
                        reuse_confidence=confidence,
                        cached_agent_id=agent_id,
                    ))
        # Best (lowest hamming distance) first; sort is stable, so ties keep
        # prompt order.
        matches.sort(key=lambda m: m.hamming_distance)
        return matches

    async def get_shared_prefix_hash(self, text: str) -> str:
        """
        Compute a stable hash of the shared prefix (first block).

        Used for routing hints to llm-d/vLLM.

        Args:
            text: Prompt text

        Returns:
            First 32 hex chars of the SHA256 of the first block's tokens
            (the whole prompt if it is shorter than one block)
        """
        token_ids = await self._encode(text)
        # Slicing already handles prompts shorter than one block.
        first_block = token_ids[:self._block_size]
        # Deterministic: repr of a tuple of ints is stable across runs.
        hash_input = str(tuple(first_block)).encode()
        return hashlib.sha256(hash_input).hexdigest()[:32]

    def _simhash_block(self, token_ids: tuple[int, ...]) -> int:
        """
        Compute a SimHash fingerprint (``hash_bits`` wide) for a token block.

        Deterministic: the same block always produces the same hash.

        Each token ID is expanded into a 64-bit pseudo-random word with a
        splitmix64-style finalizer; bit *b* of that word votes +1/-1 on bit
        *b* of the fingerprint.  The previous xorshift mix masked the state
        to 32 bits and indexed with ``bit % 32``, so fingerprint bits 32-63
        duplicated bits 0-31 — halving the effective hash space and doubling
        every hamming distance.

        Args:
            token_ids: Tuple of token IDs

        Returns:
            ``hash_bits``-wide integer hash (64-bit by default)
        """
        mask64 = 0xFFFFFFFFFFFFFFFF
        v = np.zeros(self._hash_bits, dtype=np.float32)
        for tid in token_ids:
            # splitmix64 finalizer: cheap, deterministic, full 64-bit
            # avalanche.  The additive constant keeps tid == 0 from mapping
            # to the all-zero word.
            h = (int(tid) + 0x9E3779B97F4A7C15) & mask64
            h = ((h ^ (h >> 30)) * 0xBF58476D1CE4E5B9) & mask64
            h = ((h ^ (h >> 27)) * 0x94D049BB133111EB) & mask64
            h ^= h >> 31
            # Project onto hash bits (wraps for hash_bits > 64).
            for bit in range(self._hash_bits):
                if (h >> (bit % 64)) & 1:
                    v[bit] += 1
                else:
                    v[bit] -= 1
        # Binarize the vote vector, then pack LSB-first into an int.
        bits = (v > 0).astype(np.uint8)
        result = 0
        for i, b in enumerate(bits):
            result |= (int(b) << i)
        return result

    async def stats(self) -> dict:
        """Return index statistics (sizes and configuration)."""
        async with self._lock:
            return {
                "total_blocks": len(self._block_store),
                "total_agents": len(self._agent_blocks),
                "block_size": self._block_size,
                "hash_bits": self._hash_bits,
                "hamming_threshold": self._hamming_threshold,
            }

    async def clear_agent(self, agent_id: str) -> int:
        """
        Remove all blocks indexed for an agent.

        Blocks shared with other agents stay in the store for those agents.

        Args:
            agent_id: Agent to clear

        Returns:
            Number of blocks removed
        """
        async with self._lock:
            return self._remove_agent_blocks_locked(agent_id)