File size: 11,993 Bytes
234574a
 
6d9c72b
 
cf0a8ed
 
 
 
 
 
 
6d9c72b
 
 
234574a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
466cc3d
 
 
 
 
 
 
 
 
234574a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6d9c72b
234574a
 
 
 
 
 
 
 
 
 
6d9c72b
234574a
 
 
 
 
 
6d9c72b
234574a
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
"""Tests for LSHTokenMatcher and FAISSContextIndex - v2.0 deduplication components."""
import importlib.util

import numpy as np
import pytest

from apohara_context_forge.dedup.faiss_index import FAISSContextIndex, FAISSMatch
from apohara_context_forge.dedup.lsh_engine import LSHTokenMatcher, TokenBlockMatch

# Skip the whole module when faiss is unavailable.
#
# Fix: the previous `__import__('importlib').util` relied on the
# `importlib.util` submodule having been imported by someone else (pytest
# happens to do so) — `import importlib` alone does not load it, so the
# expression could raise AttributeError. Import `importlib.util` explicitly
# and compare against None, since find_spec returns a spec or None.
pytestmark = pytest.mark.skipif(
    importlib.util.find_spec("faiss") is None,
    reason="faiss-cpu not installed — run: pip install faiss-cpu",
)


@pytest.fixture
def lsh_matcher():
    """Yield a clean LSHTokenMatcher so tests never share index state."""
    matcher = LSHTokenMatcher()
    return matcher


@pytest.fixture
def faiss_index():
    """Yield a clean 384-dim FAISSContextIndex so tests never share vectors."""
    index = FAISSContextIndex(dim=384)
    return index


class TestLSHTokenMatcher:
    """Exercise LSHTokenMatcher's token-level SimHash block matching."""

    @pytest.mark.asyncio
    async def test_index_prompt(self, lsh_matcher):
        """Indexing a prompt should record at least one token block."""
        # The Qwen3 BPE collapses common English words into single tokens,
        # so a short sentence can fall below block_size (16) tokens. A
        # longer prompt guarantees at least one complete block is produced.
        prompt = (
            "This is a test prompt that should produce multiple token blocks "
            "for indexing across various transformer architectures including "
            "GPT, Llama, Qwen, and Mistral families on AMD MI300X with ROCm."
        )

        block_hashes = await lsh_matcher.index_prompt("agent1", prompt)

        # Indexing returns the list of block hashes it stored.
        assert isinstance(block_hashes, list)

        # The matcher's statistics must reflect the newly indexed agent.
        snapshot = await lsh_matcher.stats()
        assert snapshot["total_blocks"] >= 1
        assert snapshot["total_agents"] == 1
        assert "agent1" in lsh_matcher._agent_blocks

    @pytest.mark.asyncio
    async def test_find_reusable_blocks(self, lsh_matcher):
        """Prompts sharing a prefix should yield reusable block matches."""
        # Two agents whose prompts begin with the same assistant preamble.
        await lsh_matcher.index_prompt(
            "agent1",
            "You are a helpful assistant. You provide accurate and detailed responses.",
        )
        await lsh_matcher.index_prompt(
            "agent2",
            "You are a helpful assistant. Tell me about quantum physics.",
        )

        # Query with a third prompt carrying the identical prefix.
        matches = await lsh_matcher.find_reusable_blocks(
            "You are a helpful assistant. What is machine learning?"
        )

        assert isinstance(matches, list)
        # Results come back ordered best-first by hamming distance.
        if len(matches) > 1:
            first, second = matches[0], matches[1]
            assert first.hamming_distance <= second.hamming_distance

    @pytest.mark.asyncio
    async def test_find_reusable_blocks_exclude_agent(self, lsh_matcher):
        """exclude_agent must drop every match owned by that agent."""
        await lsh_matcher.index_prompt(
            "agent1",
            "You are a helpful assistant. This is agent1's unique content here.",
        )
        await lsh_matcher.index_prompt(
            "agent2",
            "You are a helpful assistant. This is agent2's unique content here.",
        )

        # Query with agent1's exact text while excluding agent1 itself.
        matches = await lsh_matcher.find_reusable_blocks(
            "You are a helpful assistant. This is agent1's unique content here.",
            exclude_agent="agent1",
        )

        # Nothing attributed to agent1 may survive the filter.
        for m in matches:
            assert m.cached_agent_id != "agent1"

    @pytest.mark.asyncio
    async def test_get_shared_prefix_hash(self, lsh_matcher):
        """The shared-prefix hash is deterministic for identical text."""
        sample = "This is a test prompt for hashing."

        first = await lsh_matcher.get_shared_prefix_hash(sample)
        second = await lsh_matcher.get_shared_prefix_hash(sample)

        # Determinism plus expected format: first 32 chars of a SHA256 hex digest.
        assert first == second
        assert isinstance(first, str)
        assert len(first) == 32

    @pytest.mark.asyncio
    async def test_get_shared_prefix_hash_different_texts(self, lsh_matcher):
        """Distinct texts must not collide on the prefix hash."""
        h1 = await lsh_matcher.get_shared_prefix_hash("Hello world")
        h2 = await lsh_matcher.get_shared_prefix_hash("Goodbye world")

        assert h1 != h2

    @pytest.mark.asyncio
    async def test_lsh_stats(self, lsh_matcher):
        """stats() exposes counts and configuration of the index."""
        prompt = "This is a test prompt that should produce multiple token blocks."
        await lsh_matcher.index_prompt("agent1", prompt)
        await lsh_matcher.index_prompt("agent2", prompt)

        stats = await lsh_matcher.stats()

        # Every expected key is present in the snapshot.
        for key in (
            "total_blocks",
            "total_agents",
            "block_size",
            "hash_bits",
            "hamming_threshold",
        ):
            assert key in stats

        # Two agents indexed; configuration matches the defaults.
        assert stats["total_agents"] == 2
        assert stats["block_size"] == 16
        assert stats["hash_bits"] == 64

    @pytest.mark.asyncio
    async def test_clear_agent(self, lsh_matcher):
        """clear_agent drops every block belonging to that agent."""
        await lsh_matcher.index_prompt(
            "agent1", "This is a test prompt for clearing agent blocks."
        )

        before = await lsh_matcher.stats()
        assert before["total_agents"] == 1

        removed = await lsh_matcher.clear_agent("agent1")

        assert removed >= 0
        after = await lsh_matcher.stats()
        assert after["total_agents"] == 0
        assert after["total_blocks"] == 0

    @pytest.mark.asyncio
    async def test_clear_agent_not_found(self, lsh_matcher):
        """clear_agent on an unknown agent is a no-op returning 0."""
        assert await lsh_matcher.clear_agent("nonexistent") == 0


class TestFAISSContextIndex:
    """Exercise FAISSContextIndex approximate nearest-neighbour search."""

    @staticmethod
    def _unit_vector():
        """Return a random 384-dim float32 vector normalised to unit length."""
        v = np.random.randn(384).astype(np.float32)
        return v / np.linalg.norm(v)

    @pytest.mark.asyncio
    async def test_add_and_search(self, faiss_index):
        """Added embeddings are retrievable above the similarity threshold."""
        v1 = self._unit_vector()
        v2 = self._unit_vector()

        pos1 = await faiss_index.add("agent1", v1.tolist())
        pos2 = await faiss_index.add("agent2", v2.tolist())

        # Insertion positions are assigned sequentially from zero.
        assert pos1 == 0
        assert pos2 == 1

        # Querying with agent1's own embedding must surface agent1 first.
        matches = await faiss_index.search(v1.tolist(), k=10, threshold=0.85)

        assert isinstance(matches, list)
        assert len(matches) >= 1

        top = matches[0]
        assert isinstance(top, FAISSMatch)
        assert top.agent_id == "agent1"
        assert top.similarity > 0.99

    @pytest.mark.asyncio
    async def test_search_with_threshold(self, faiss_index):
        """A high threshold filters out dissimilar results."""
        await faiss_index.add("agent1", self._unit_vector().tolist())

        # An unrelated random query should rarely clear a 0.99 threshold.
        probe = self._unit_vector()
        matches = await faiss_index.search(probe.tolist(), k=5, threshold=0.99)

        # Anything returned must itself satisfy the threshold.
        for m in matches:
            assert m.similarity >= 0.99

    @pytest.mark.asyncio
    async def test_search_returns_sorted_by_similarity(self, faiss_index):
        """Search results arrive in descending similarity order."""
        for i in range(5):
            await faiss_index.add(f"agent{i}", self._unit_vector().tolist())

        matches = await faiss_index.search(
            self._unit_vector().tolist(), k=5, threshold=0.0
        )

        # Each result is at least as similar as the one that follows it.
        if len(matches) > 1:
            for left, right in zip(matches, matches[1:]):
                assert left.similarity >= right.similarity

    @pytest.mark.asyncio
    async def test_remove(self, faiss_index):
        """Removing an agent succeeds even though FAISS keeps the vector."""
        await faiss_index.add("agent1", self._unit_vector().tolist())

        assert faiss_index.size == 1

        assert await faiss_index.remove("agent1") is True

        # FAISS cannot shrink a flat index, so size is unchanged; the agent
        # simply stops being resolvable through search.
        assert faiss_index.size == 1

    @pytest.mark.asyncio
    async def test_remove_not_found(self, faiss_index):
        """Removing an unknown agent reports False."""
        assert await faiss_index.remove("nonexistent") is False

    @pytest.mark.asyncio
    async def test_size(self, faiss_index):
        """size tracks additions but not removals (FAISS limitation)."""
        assert faiss_index.size == 0

        shared = self._unit_vector()

        await faiss_index.add("agent1", shared.tolist())
        assert faiss_index.size == 1

        await faiss_index.add("agent2", shared.tolist())
        assert faiss_index.size == 2

        await faiss_index.remove("agent1")
        assert faiss_index.size == 2  # vectors stay in the flat index

    @pytest.mark.asyncio
    async def test_multiple_searches(self, faiss_index):
        """Repeated searches against the same index all succeed."""
        vectors = []
        for i in range(3):
            v = self._unit_vector()
            vectors.append(v)
            await faiss_index.add(f"agent{i}", v.tolist())

        # Each stored vector should find at least itself.
        for v in vectors:
            hits = await faiss_index.search(v.tolist(), k=3, threshold=0.5)
            assert len(hits) >= 1


class TestTokenBlockMatch:
    """Sanity checks for the TokenBlockMatch dataclass."""

    def test_token_block_match_creation(self):
        """Every constructor field round-trips onto the instance."""
        fields = {
            "block_index": 0,
            "cached_block_hash": 12345,
            "hamming_distance": 2,
            "reuse_confidence": 0.97,
            "cached_agent_id": "agent1",
        }

        match = TokenBlockMatch(**fields)

        for name, expected in fields.items():
            assert getattr(match, name) == expected


class TestFAISSMatch:
    """Sanity checks for the FAISSMatch dataclass."""

    def test_faiss_match_creation(self):
        """Every constructor field round-trips onto the instance."""
        fields = {
            "agent_id": "agent1",
            "similarity": 0.95,
            "index_position": 5,
        }

        match = FAISSMatch(**fields)

        for name, expected in fields.items():
            assert getattr(match, name) == expected