"""Tests for LSHTokenMatcher and FAISSContextIndex - v2.0 deduplication components."""
import importlib.util

import numpy as np
import pytest
from apohara_context_forge.dedup.faiss_index import FAISSContextIndex, FAISSMatch
from apohara_context_forge.dedup.lsh_engine import LSHTokenMatcher, TokenBlockMatch
pytestmark = pytest.mark.skipif(
    importlib.util.find_spec("faiss") is None,
    reason="faiss-cpu not installed; run: pip install faiss-cpu",
)
@pytest.fixture
def lsh_matcher():
"""Create a fresh LSHTokenMatcher for each test."""
return LSHTokenMatcher()
@pytest.fixture
def faiss_index():
"""Create a fresh FAISSContextIndex for each test."""
return FAISSContextIndex(dim=384)
class TestLSHTokenMatcher:
"""Tests for LSHTokenMatcher - token-level SimHash matching."""
@pytest.mark.asyncio
async def test_index_prompt(self, lsh_matcher):
"""Index a prompt, verify blocks are stored."""
# Need >= block_size (16) tokens after tokenization. The Qwen3 BPE
# collapses common English words to one token each, so a short
# sentence may yield <16 tokens. Use a longer prompt to guarantee
# at least one full block.
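        # Sketch of the assumed blocking scheme (the real windowing lives in
        # LSHTokenMatcher): tokens are grouped into fixed 16-token windows,
        # e.g. [tokens[i:i + 16] for i in range(0, len(tokens) - 15, 16)],
        # and each window is SimHashed to a 64-bit fingerprint.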
text = (
"This is a test prompt that should produce multiple token blocks "
"for indexing across various transformer architectures including "
"GPT, Llama, Qwen, and Mistral families on AMD MI300X with ROCm."
)
hashes = await lsh_matcher.index_prompt("agent1", text)
# Verify blocks were indexed
assert isinstance(hashes, list)
# Check stats reflect the indexing
stats = await lsh_matcher.stats()
assert stats["total_blocks"] >= 1
assert stats["total_agents"] == 1
assert "agent1" in lsh_matcher._agent_blocks
@pytest.mark.asyncio
async def test_find_reusable_blocks(self, lsh_matcher):
"""Index one prompt, find matches in another with similar tokens."""
# Index a prompt for agent1
text1 = "You are a helpful assistant. You provide accurate and detailed responses."
await lsh_matcher.index_prompt("agent1", text1)
# Index another prompt for agent2 with identical beginning
text2 = "You are a helpful assistant. Tell me about quantum physics."
await lsh_matcher.index_prompt("agent2", text2)
# Find reusable blocks in a new prompt with same prefix
text3 = "You are a helpful assistant. What is machine learning?"
matches = await lsh_matcher.find_reusable_blocks(text3)
# Should find some matches since the prefix is the same
assert isinstance(matches, list)
# Matches should be sorted by hamming distance (best first)
if len(matches) > 1:
assert matches[0].hamming_distance <= matches[1].hamming_distance
@pytest.mark.asyncio
async def test_find_reusable_blocks_exclude_agent(self, lsh_matcher):
"""Verify exclude_agent parameter filters correctly."""
text1 = "You are a helpful assistant. This is agent1's unique content here."
await lsh_matcher.index_prompt("agent1", text1)
text2 = "You are a helpful assistant. This is agent2's unique content here."
await lsh_matcher.index_prompt("agent2", text2)
# Search excluding agent1
text3 = "You are a helpful assistant. This is agent1's unique content here."
matches = await lsh_matcher.find_reusable_blocks(text3, exclude_agent="agent1")
# Should not find any matches from agent1
for match in matches:
assert match.cached_agent_id != "agent1"
@pytest.mark.asyncio
async def test_get_shared_prefix_hash(self, lsh_matcher):
"""Compute stable hash of shared prefix."""
text = "This is a test prompt for hashing."
hash1 = await lsh_matcher.get_shared_prefix_hash(text)
hash2 = await lsh_matcher.get_shared_prefix_hash(text)
# Same text should produce same hash
assert hash1 == hash2
assert isinstance(hash1, str)
        assert len(hash1) == 32  # first 32 hex chars (128 bits) of the SHA-256 digest
@pytest.mark.asyncio
async def test_get_shared_prefix_hash_different_texts(self, lsh_matcher):
"""Different texts should produce different hashes."""
text1 = "Hello world"
text2 = "Goodbye world"
hash1 = await lsh_matcher.get_shared_prefix_hash(text1)
hash2 = await lsh_matcher.get_shared_prefix_hash(text2)
assert hash1 != hash2
@pytest.mark.asyncio
async def test_lsh_stats(self, lsh_matcher):
"""Verify index statistics."""
text = "This is a test prompt that should produce multiple token blocks."
await lsh_matcher.index_prompt("agent1", text)
await lsh_matcher.index_prompt("agent2", text)
stats = await lsh_matcher.stats()
assert "total_blocks" in stats
assert "total_agents" in stats
assert "block_size" in stats
assert "hash_bits" in stats
assert "hamming_threshold" in stats
assert stats["total_agents"] == 2
assert stats["block_size"] == 16
assert stats["hash_bits"] == 64
@pytest.mark.asyncio
async def test_clear_agent(self, lsh_matcher):
"""Remove all blocks for an agent."""
text = "This is a test prompt for clearing agent blocks."
await lsh_matcher.index_prompt("agent1", text)
stats_before = await lsh_matcher.stats()
assert stats_before["total_agents"] == 1
removed_count = await lsh_matcher.clear_agent("agent1")
assert removed_count >= 0
stats_after = await lsh_matcher.stats()
assert stats_after["total_agents"] == 0
assert stats_after["total_blocks"] == 0
@pytest.mark.asyncio
async def test_clear_agent_not_found(self, lsh_matcher):
"""Clearing non-existent agent returns 0."""
removed = await lsh_matcher.clear_agent("nonexistent")
assert removed == 0
class TestFAISSContextIndex:
"""Tests for FAISSContextIndex - approximate nearest neighbor search."""
@pytest.mark.asyncio
async def test_add_and_search(self, faiss_index):
"""Add embeddings, search, verify matches above threshold."""
# Add two agents with embeddings
emb1 = np.random.randn(384).astype(np.float32)
emb1 = emb1 / np.linalg.norm(emb1) # Normalize
emb2 = np.random.randn(384).astype(np.float32)
emb2 = emb2 / np.linalg.norm(emb2)
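        # With unit-normalized vectors, inner product equals cosine
        # similarity, so agent1's own embedding should score ~1.0 below.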
idx1 = await faiss_index.add("agent1", emb1.tolist())
idx2 = await faiss_index.add("agent2", emb2.tolist())
assert idx1 == 0
assert idx2 == 1
        # Search with a query identical to agent1's embedding
        query = emb1.tolist()
matches = await faiss_index.search(query, k=10, threshold=0.85)
assert isinstance(matches, list)
assert len(matches) >= 1
# Best match should be agent1 (highest similarity to itself)
best = matches[0]
assert isinstance(best, FAISSMatch)
assert best.agent_id == "agent1"
assert best.similarity > 0.99
@pytest.mark.asyncio
async def test_search_with_threshold(self, faiss_index):
"""Verify threshold filtering works."""
# Add an agent
emb = np.random.randn(384).astype(np.float32)
emb = emb / np.linalg.norm(emb)
await faiss_index.add("agent1", emb.tolist())
# Search with very different query
random_query = np.random.randn(384).astype(np.float32)
random_query = random_query / np.linalg.norm(random_query)
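        # Two independent random unit vectors in 384 dimensions are almost
        # orthogonal, so their similarity should sit near 0, well under 0.99.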
# High threshold should filter out dissimilar results
matches = await faiss_index.search(random_query.tolist(), k=5, threshold=0.99)
# Should either be empty or only contain very high similarity matches
for match in matches:
assert match.similarity >= 0.99
@pytest.mark.asyncio
async def test_search_returns_sorted_by_similarity(self, faiss_index):
"""Verify results are sorted by descending similarity."""
# Add multiple agents with different embeddings
for i in range(5):
emb = np.random.randn(384).astype(np.float32)
emb = emb / np.linalg.norm(emb)
await faiss_index.add(f"agent{i}", emb.tolist())
# Search
query = np.random.randn(384).astype(np.float32)
query = query / np.linalg.norm(query)
matches = await faiss_index.search(query, k=5, threshold=0.0)
# Should be sorted by similarity descending
if len(matches) > 1:
for i in range(len(matches) - 1):
assert matches[i].similarity >= matches[i + 1].similarity
@pytest.mark.asyncio
async def test_remove(self, faiss_index):
"""Remove agent from index."""
emb = np.random.randn(384).astype(np.float32)
emb = emb / np.linalg.norm(emb)
await faiss_index.add("agent1", emb.tolist())
assert faiss_index.size == 1
removed = await faiss_index.remove("agent1")
assert removed is True
# Size stays the same (FAISS limitation), but agent should not be found
assert faiss_index.size == 1
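        # (Likely the index only drops its agent-id mapping rather than
        # rebuilding the underlying FAISS index, hence the unchanged size.)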
@pytest.mark.asyncio
async def test_remove_not_found(self, faiss_index):
"""Removing non-existent agent returns False."""
removed = await faiss_index.remove("nonexistent")
assert removed is False
@pytest.mark.asyncio
async def test_size(self, faiss_index):
"""Verify index size tracking."""
assert faiss_index.size == 0
emb = np.random.randn(384).astype(np.float32)
emb = emb / np.linalg.norm(emb)
await faiss_index.add("agent1", emb.tolist())
assert faiss_index.size == 1
await faiss_index.add("agent2", emb.tolist())
assert faiss_index.size == 2
await faiss_index.remove("agent1")
assert faiss_index.size == 2 # FAISS doesn't actually remove
@pytest.mark.asyncio
async def test_multiple_searches(self, faiss_index):
"""Verify multiple searches work correctly."""
# Add multiple agents
embeddings = []
for i in range(3):
emb = np.random.randn(384).astype(np.float32)
emb = emb / np.linalg.norm(emb)
embeddings.append(emb)
await faiss_index.add(f"agent{i}", emb.tolist())
# Multiple searches should all work
for emb in embeddings:
matches = await faiss_index.search(emb.tolist(), k=3, threshold=0.5)
assert len(matches) >= 1
class TestTokenBlockMatch:
"""Tests for TokenBlockMatch dataclass."""
def test_token_block_match_creation(self):
"""Verify TokenBlockMatch has all required fields."""
match = TokenBlockMatch(
block_index=0,
cached_block_hash=12345,
hamming_distance=2,
reuse_confidence=0.97,
cached_agent_id="agent1"
)
assert match.block_index == 0
assert match.cached_block_hash == 12345
assert match.hamming_distance == 2
assert match.reuse_confidence == 0.97
assert match.cached_agent_id == "agent1"
class TestFAISSMatch:
"""Tests for FAISSMatch dataclass."""
def test_faiss_match_creation(self):
"""Verify FAISSMatch has all required fields."""
match = FAISSMatch(
agent_id="agent1",
similarity=0.95,
index_position=5
)
assert match.agent_id == "agent1"
assert match.similarity == 0.95
assert match.index_position == 5
|