File size: 16,274 Bytes
24d9eca
 
cf0a8ed
24d9eca
 
 
 
 
 
cf0a8ed
 
 
 
24d9eca
 
 
 
 
 
 
cf0a8ed
24d9eca
 
 
 
466cc3d
 
 
 
 
 
 
 
 
24d9eca
466cc3d
24d9eca
466cc3d
 
24d9eca
 
 
 
 
 
 
 
cf0a8ed
24d9eca
cf0a8ed
24d9eca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cf0a8ed
24d9eca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cf0a8ed
24d9eca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
466cc3d
 
 
 
 
 
 
 
 
 
 
 
 
24d9eca
 
 
 
 
 
 
 
 
 
 
466cc3d
 
 
 
 
 
 
24d9eca
 
 
466cc3d
24d9eca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cf0a8ed
24d9eca
cf0a8ed
 
 
 
24d9eca
 
 
 
 
cf0a8ed
24d9eca
cf0a8ed
 
 
 
24d9eca
 
 
 
 
cf0a8ed
24d9eca
cf0a8ed
 
 
 
24d9eca
 
 
 
 
cf0a8ed
 
24d9eca
 
cf0a8ed
24d9eca
cf0a8ed
 
 
 
24d9eca
 
cf0a8ed
24d9eca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
466cc3d
24d9eca
 
 
 
 
 
466cc3d
24d9eca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cf0a8ed
24d9eca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
"""End-to-end integration tests for ContextRegistry with LSH + FAISS + VRAMAwareCache."""
import asyncio
import importlib.util
import pytest
import pytest_asyncio
from unittest.mock import patch

from prometheus_client import REGISTRY

# Probe for the optional faiss package without importing it, so that
# faiss-dependent tests can be skipped when it is missing.
FAISS_AVAILABLE = importlib.util.find_spec("faiss") is not None

from apohara_context_forge import (
    ContextRegistry,
    SharedContextResult,
    LSHTokenMatcher,
    FAISSContextIndex,
    VRAMAwareCache,
    EvictionMode,
)
from apohara_context_forge.metrics.prometheus_metrics import cache_hits, cache_misses


@pytest_asyncio.fixture
async def registry():
    """Yield a started ContextRegistry wired to LSH, FAISS and the VRAM cache.

    The configuration deviates from production in two places:
      - The FAISS index is created with dim=512 so it matches the
        EmbeddingEngine output; with a mismatch,
        faiss.IndexFlatIP.add() trips an assertion at runtime.
      - block_size is lowered to 4 so the short prompts used in these
        tests still produce at least one LSH block. Production runs at
        block_size=16 (vLLM PagedAttention page boundary) with much
        longer system prompts.

    The registry is stopped again after the requesting test has finished.
    """
    # One value feeds both the matcher and the registry so they stay in sync.
    test_block_size = 4

    matcher = LSHTokenMatcher(block_size=test_block_size)
    cache = VRAMAwareCache(max_token_budget=50_000_000)
    index = FAISSContextIndex(dim=512)

    context_registry = ContextRegistry(
        lsh_matcher=matcher,
        vram_cache=cache,
        faiss_index=index,
        block_size=test_block_size,
    )
    await context_registry.start()
    yield context_registry
    await context_registry.stop()


class TestSharedContextWithSharedSystemPrompt:
    """Test 1: Register 3 agents with shared system prompt → get_shared_context()."""

    # Single skip marker reused by every faiss-dependent test in this class.
    requires_faiss = pytest.mark.skipif(not FAISS_AVAILABLE, reason="faiss not installed")

    @requires_faiss
    @pytest.mark.asyncio
    async def test_shared_system_prompt_returns_non_empty_blocks(self, registry):
        """Register three agents on one system prompt and inspect the shared context.

        Checks that every registration reports a positive token count, that
        get_shared_context() returns a non-empty list, and - when any result
        carries shared blocks - that token savings and reuse confidence are
        both positive.
        """
        # Shared system prompt for all 3 agents
        system_prompt = (
            "You are a helpful AI assistant running on AMD MI300X. "
            "Your role is to provide accurate and concise responses."
        )

        # Agent id -> role prompt; insertion order is the registration order.
        role_prompts = {
            "agent1": "You are a retriever agent specializing in finding relevant documents.",
            "agent2": "You are a summarizer agent that condenses information.",
            "agent3": "You are a translator agent that adapts content across languages.",
        }

        for agent_id, role_prompt in role_prompts.items():
            entry = await registry.register_agent(agent_id, system_prompt, role_prompt)
            assert entry.agent_id == agent_id
            assert entry.token_count > 0

        # Get shared context across all 3 agents
        results = await registry.get_shared_context(list(role_prompts))

        # The result must be a real, non-empty list. Without the emptiness
        # check an empty list would skip every assertion below and the test
        # would pass without verifying anything.
        assert isinstance(results, list)
        assert results, "Expected get_shared_context() to return results for the registered agents"

        # At least one result should have shared blocks (system prompt blocks should match)
        has_shared_blocks = any(len(r.shared_blocks) > 0 for r in results)

        # NOTE(review): the checks below only run when shared blocks were
        # found, so a regression that stops producing matches would not fail
        # this test. Consider asserting has_shared_blocks directly once it is
        # confirmed that this prompt always yields a match.
        if has_shared_blocks:
            total_tokens_saved = sum(r.total_tokens_saved for r in results)
            assert total_tokens_saved > 0, "Expected token savings from shared blocks"

            max_confidence = max(r.reuse_confidence for r in results)
            assert max_confidence > 0.0, "Expected positive reuse confidence"

    @requires_faiss
    @pytest.mark.asyncio
    async def test_shared_context_contains_all_requested_agents(self, registry):
        """Verify all requested agents are present in results."""
        system_prompt = "Shared system prompt for testing."

        await registry.register_agent("agent1", system_prompt, "Role 1")
        await registry.register_agent("agent2", system_prompt, "Role 2")
        await registry.register_agent("agent3", system_prompt, "Role 3")

        results = await registry.get_shared_context(["agent1", "agent2", "agent3"])

        # Compare as sets: only membership matters here, not result order.
        result_agent_ids = {r.agent_id for r in results}
        assert result_agent_ids == {"agent1", "agent2", "agent3"}


@pytest.mark.skipif(not FAISS_AVAILABLE, reason="faiss not installed")
class TestPrometheusMetricsEmission:
    """Test 2: Prometheus metrics are emitted after get_shared_context()."""

    @pytest.mark.asyncio
    async def test_cache_hits_metric_incremented(self, registry):
        """Verify get_shared_context() increments cache_hits or cache_misses for agent1."""
        system_prompt = "Test system prompt for metrics verification."

        await registry.register_agent("agent1", system_prompt, "Role 1")
        await registry.register_agent("agent2", system_prompt, "Role 2")

        # Capture baseline values first: nothing is reset here, so the
        # assertion below compares before/after deltas rather than absolutes.
        initial_hits = self._get_metric_value(cache_hits, "agent1", "system_prompt")
        initial_misses = self._get_metric_value(cache_misses, "agent1")

        # Trigger get_shared_context
        await registry.get_shared_context(["agent1", "agent2"])

        # Either counter moving counts as proof that a metric was emitted.
        final_hits = self._get_metric_value(cache_hits, "agent1", "system_prompt")
        final_misses = self._get_metric_value(cache_misses, "agent1")

        metric_incremented = (
            (final_hits > initial_hits) or (final_misses > initial_misses)
        )
        assert metric_incremented, (
            f"Expected cache_hits or cache_misses to increment. "
            f"Hits: {initial_hits} -> {final_hits}, Misses: {initial_misses} -> {final_misses}"
        )

    @pytest.mark.asyncio
    async def test_cache_misses_metric_incremented_for_no_match(self, registry):
        """Verify cache_misses is incremented when no reusable blocks found."""
        # Use completely different prompts to ensure no matches.
        # Use orthogonal token sets so the SimHash fingerprints land far
        # apart — anything sharing common token sequences (e.g. "prompt for
        # agent") collapses to similar hashes inside the hamming threshold.
        await registry.register_agent(
            "agent1",
            "Quantum chromodynamics describes strong nuclear interactions in baryons",
            "alpha beta gamma",
        )
        await registry.register_agent(
            "agent2",
            "Photosynthesis converts solar irradiance into glucose via chloroplast",
            "delta epsilon zeta",
        )

        initial_misses = self._get_metric_value(cache_misses, "agent1")

        # Get shared context - should have no matches due to different prompts
        await registry.get_shared_context(["agent1", "agent2"])

        final_misses = self._get_metric_value(cache_misses, "agent1")
        assert final_misses > initial_misses, "Expected cache_misses to increment for non-matching prompts"

    @staticmethod
    def _get_metric_value(counter, *label_values):
        """Return the current value of a Prometheus counter for the given labels.

        Args:
            counter: The metric object to look up; its ``_name`` is matched
                against the metric family names yielded by REGISTRY.collect().
            *label_values: Label values, in the order the metric declares
                its labels.

        Returns:
            The value of the first matching non-``_created`` sample, or 0
            when no sample with these label values exists yet.

        Counters live as `<name>_total` samples in REGISTRY.collect(); we
        compare label values as a tuple (dict_values views never compare
        equal to a tuple under ==).
        """
        target = tuple(label_values)
        for metric_family in REGISTRY.collect():
            if metric_family.name != counter._name:
                continue
            for sample in metric_family.samples:
                # A counter family can also carry a `<name>_created`
                # timestamp sample with the same labels; skip it explicitly
                # instead of relying on sample ordering.
                if sample.name.endswith("_created"):
                    continue
                if tuple(sample.labels.values()) == target:
                    return sample.value
        return 0


class TestVRAMModeTransitions:
    """Test 3: VRAM mode transitions from RELAXED to higher modes under pressure."""

    @staticmethod
    async def _mode_after_pressure(registry, pressure):
        """Apply the eviction policy at ``pressure`` and return the resulting mode."""
        await registry._vram_cache._apply_eviction_policy(pressure=pressure)
        return await registry.get_vram_mode()

    @pytest.mark.asyncio
    async def test_mode_transitions_to_pressure_under_high_vram(self, registry):
        """A pressure of 0.88 moves the registry from RELAXED to PRESSURE."""
        # A freshly started registry is expected to report RELAXED.
        starting_mode = await registry.get_vram_mode()
        assert starting_mode == EvictionMode.RELAXED.value

        # 0.88 is meant to sit inside the PRESSURE band (0.85-0.92).
        current_mode = await self._mode_after_pressure(registry, 0.88)
        assert current_mode == EvictionMode.PRESSURE.value, (
            f"Expected PRESSURE mode at 0.88 pressure, got {current_mode}"
        )

    @pytest.mark.asyncio
    async def test_mode_transitions_to_critical_under_high_vram(self, registry):
        """A pressure of 0.94 puts the registry into CRITICAL mode."""
        # 0.94 is meant to sit inside the CRITICAL band (0.92-0.96).
        current_mode = await self._mode_after_pressure(registry, 0.94)
        assert current_mode == EvictionMode.CRITICAL.value, (
            f"Expected CRITICAL mode at 0.94 pressure, got {current_mode}"
        )

    @pytest.mark.asyncio
    async def test_mode_transitions_to_emergency_at_saturation(self, registry):
        """A pressure of 0.97 puts the registry into EMERGENCY mode."""
        # 0.97 is meant to be above the EMERGENCY threshold (>= 0.96).
        current_mode = await self._mode_after_pressure(registry, 0.97)
        assert current_mode == EvictionMode.EMERGENCY.value, (
            f"Expected EMERGENCY mode at 0.97 pressure, got {current_mode}"
        )

    @pytest.mark.asyncio
    async def test_mode_reverts_to_relaxed_when_pressure_drops(self, registry):
        """The mode falls back to RELAXED once pressure drops again."""
        # Raise the mode first so there is something to revert from.
        elevated_mode = await self._mode_after_pressure(registry, 0.88)
        assert elevated_mode == EvictionMode.PRESSURE.value

        # Then lower the pressure to a RELAXED-level value.
        current_mode = await self._mode_after_pressure(registry, 0.50)
        assert current_mode == EvictionMode.RELAXED.value, (
            f"Expected RELAXED mode after pressure drop, got {current_mode}"
        )


@pytest.mark.skipif(not FAISS_AVAILABLE, reason="faiss not installed")
class TestClearAgent:
    """Test 4: clear_agent() removes agent from registry."""

    @pytest.mark.asyncio
    async def test_clear_agent_removes_from_registry(self, registry):
        """A cleared agent no longer shows up in get_all_agents()."""
        agent_id = "agent_to_clear"
        await registry.register_agent(
            agent_id, "Test system prompt for clear operation.", "Role prompt"
        )

        # Sanity check: the agent is listed before it is cleared.
        listed_before = await registry.get_all_agents()
        assert agent_id in listed_before

        was_cleared = await registry.clear_agent(agent_id)
        assert was_cleared is True

        listed_after = await registry.get_all_agents()
        assert agent_id not in listed_after

    @pytest.mark.asyncio
    async def test_clear_nonexistent_agent_returns_false(self, registry):
        """Clearing an agent that was never registered reports False."""
        outcome = await registry.clear_agent("nonexistent_agent")
        assert outcome is False

    @pytest.mark.asyncio
    async def test_clear_agent_clears_from_all_stores(self, registry):
        """After clear_agent() the agent is gone from LSH, FAISS and the cache."""
        agent_id = "agent_to_clear"
        await registry.register_agent(
            agent_id, "Test system prompt for complete clearing.", "Role prompt"
        )

        # Registration must have produced LSH blocks for the agent.
        lsh_blocks = registry._lsh._agent_blocks
        assert lsh_blocks.get(agent_id) is not None

        await registry.clear_agent(agent_id)

        # LSH: the per-agent block entry is gone.
        assert registry._lsh._agent_blocks.get(agent_id) is None

        # FAISS: no embedding is stored for the agent any more.
        remaining_embedding = await registry._faiss.get_embedding(agent_id)
        assert remaining_embedding is None

        # VRAM cache: the agent's context key no longer resolves.
        cached_context = await registry._vram_cache.get("context:agent_to_clear")
        assert cached_context is None

    @pytest.mark.asyncio
    async def test_multiple_agents_cleared_selectively(self, registry):
        """Clearing one of several agents leaves the others registered."""
        system_prompt = "Shared system prompt."

        # Register three agents in a fixed order.
        for agent_id, role_prompt in (
            ("agent1", "Role 1"),
            ("agent2", "Role 2"),
            ("agent3", "Role 3"),
        ):
            await registry.register_agent(agent_id, system_prompt, role_prompt)

        # Remove only the middle one.
        await registry.clear_agent("agent2")

        survivors = await registry.get_all_agents()
        assert "agent1" in survivors
        assert "agent2" not in survivors
        assert "agent3" in survivors


@pytest.mark.skipif(not FAISS_AVAILABLE, reason="faiss not installed")
class TestEndToEndWorkflow:
    """Full end-to-end workflow tests combining all components."""

    @pytest.mark.asyncio
    async def test_full_workflow_register_query_clear(self, registry):
        """Complete workflow: register → query → check result coverage → clear."""
        system_prompt = (
            "You are an AI assistant on AMD MI300X. "
            "Provide accurate and helpful responses."
        )

        # Agent id -> role prompt; all agents share the system prompt above.
        role_prompts = {
            "retriever": "Find relevant docs",
            "summarizer": "Summarize content",
            "translator": "Translate content",
        }
        for agent_id, role_prompt in role_prompts.items():
            await registry.register_agent(agent_id, system_prompt, role_prompt)

        # Query shared context for every registered agent.
        results = await registry.get_shared_context(list(role_prompts))
        assert len(results) == 3

        # Every requested agent must be represented in the results.
        expected_ids = set(role_prompts)
        returned_ids = {r.agent_id for r in results}
        assert returned_ids == expected_ids

        # Clear one agent
        was_cleared = await registry.clear_agent("summarizer")
        assert was_cleared is True

        # The other two agents must still be registered.
        still_registered = await registry.get_all_agents()
        assert "retriever" in still_registered
        assert "translator" in still_registered
        assert "summarizer" not in still_registered

    @pytest.mark.asyncio
    async def test_shared_context_with_empty_role_prompts(self, registry):
        """Agents registered with empty role prompts still yield one result each."""
        system_prompt = "System prompt only."

        for agent_id in ("agent1", "agent2"):
            await registry.register_agent(agent_id, system_prompt, "")

        results = await registry.get_shared_context(["agent1", "agent2"])
        assert len(results) == 2

    @pytest.mark.asyncio
    async def test_get_shared_context_with_single_agent_returns_empty(self, registry):
        """A lone registered agent produces an empty shared-context list."""
        await registry.register_agent("solo_agent", "System", "Role")

        solo_results = await registry.get_shared_context(["solo_agent"])
        assert solo_results == []

    @pytest.mark.asyncio
    async def test_get_shared_context_with_unregistered_agent_returns_empty(self, registry):
        """Asking for an agent that was never registered yields an empty list."""
        missing_results = await registry.get_shared_context(["nonexistent"])
        assert missing_results == []