"""Tests for VisualKVCache implementation."""

import hashlib
import time

import numpy as np
import pytest

from apohara_context_forge.multimodal.visual_kv_cache import (
    VisualKVCache,
    VisualEmbeddingBlock,
    VisualCacheResult,
    QueueingController,
)


class TestComputeContentHash:
    """INV-13: content_hash is SHA256 of RAW bytes — never of embeddings."""

    def test_sha256_of_raw_bytes(self):
        """Verify content_hash is SHA256 hexdigest of raw bytes."""
        cache = VisualKVCache()
        raw_bytes = b"test_image_data_12345"
        expected_hash = hashlib.sha256(raw_bytes).hexdigest()

        result = cache.compute_content_hash(raw_bytes)

        assert result == expected_hash
        assert len(result) == 64  # SHA256 hexdigest length

    def test_different_bytes_different_hash(self):
        """Different raw bytes produce different hashes."""
        cache = VisualKVCache()
        hash1 = cache.compute_content_hash(b"image1")
        hash2 = cache.compute_content_hash(b"image2")
        assert hash1 != hash2

    def test_same_bytes_same_hash(self):
        """Identical bytes produce identical hashes (cache key invariance)."""
        cache = VisualKVCache()
        raw = b"identical_content"
        hash1 = cache.compute_content_hash(raw)
        hash2 = cache.compute_content_hash(raw)
        assert hash1 == hash2


class TestVisualKVCacheLookup:
    """O(1) lookup via dict keyed by content_hash."""

    def test_lookup_miss_returns_none(self):
        """Cache miss returns None without error."""
        cache = VisualKVCache()
        result = cache.lookup("nonexistent_hash_12345")
        assert result is None

    def test_lookup_hit_returns_block(self):
        """Cache hit returns VisualEmbeddingBlock."""
        cache = VisualKVCache()
        embedding = np.random.randn(100, 512).astype(np.float32)
        raw_bytes = b"test_image"
        content_hash = cache.compute_content_hash(raw_bytes)

        cache.store(content_hash, "image", embedding, resolution=(512, 512))
        result = cache.lookup(content_hash)

        assert result is not None
        assert isinstance(result, VisualEmbeddingBlock)
        assert result.content_hash == content_hash
        assert result.modality == "image"

    def test_lookup_updates_access_count(self):
        """On hit, access_count is incremented by exactly one."""
        cache = VisualKVCache()
        embedding = np.random.randn(100, 512).astype(np.float32)
        content_hash = cache.compute_content_hash(b"test_image")
        cache.store(content_hash, "image", embedding)

        # Every hit returns the same block object, so read the int value
        # right after each call instead of holding on to the references.
        counts = [cache.lookup(content_hash).access_count for _ in range(3)]

        # A freshly stored block starts at 0; each hit adds one.
        assert counts == [1, 2, 3]

    def test_lookup_moves_to_end_lru(self):
        """Lookup moves accessed item to end (most recently used)."""
        cache = VisualKVCache(max_entries=2)
        embedding = np.random.randn(10, 10).astype(np.float32)
        h1 = cache.compute_content_hash(b"first")
        h2 = cache.compute_content_hash(b"second")
        h3 = cache.compute_content_hash(b"third")
        cache.store(h1, "image", embedding)
        cache.store(h2, "image", embedding)

        # Touching h1 makes it the most recently used entry AND gives it one
        # more access than h2, so h2 is now the victim under either LRU or
        # LFU ordering.
        cache.lookup(h1)

        # Overflow max_entries=2 to force exactly one eviction.
        cache.store(h3, "image", embedding)

        assert cache.lookup(h2) is None
        assert cache.lookup(h1) is not None
        assert cache.lookup(h3) is not None


class TestVisualKVCacheStore:
    """Store embeddings with LFU eviction."""

    def test_store_returns_block(self):
        """Store returns the created VisualEmbeddingBlock."""
        cache = VisualKVCache()
        embedding = np.random.randn(100, 512).astype(np.float32)
        content_hash = cache.compute_content_hash(b"test")

        result = cache.store(content_hash, "image", embedding, resolution=(512, 512))

        assert isinstance(result, VisualEmbeddingBlock)
        assert result.content_hash == content_hash
        assert result.modality == "image"
        assert result.resolution == (512, 512)
        assert result.encoder_model == "Qwen3-VL-235B-A22B-Instruct"

    def test_store_with_custom_encoder_model(self):
        """Store accepts custom encoder model name."""
        cache = VisualKVCache()
        embedding = np.random.randn(100, 512).astype(np.float32)

        result = cache.store(
            cache.compute_content_hash(b"test"),
            "image",
            embedding,
            encoder_model="InternVL3-78B",
        )

        assert result.encoder_model == "InternVL3-78B"

    def test_store_multiple_modalities(self):
        """Store accepts different modalities."""
        cache = VisualKVCache()
        embedding = np.random.randn(100, 512).astype(np.float32)

        h_img = cache.compute_content_hash(b"image")
        h_aud = cache.compute_content_hash(b"audio")
        h_vid = cache.compute_content_hash(b"video")

        cache.store(h_img, "image", embedding)
        cache.store(h_aud, "audio", embedding)
        cache.store(h_vid, "video", embedding)

        img_block = cache.lookup(h_img)
        aud_block = cache.lookup(h_aud)
        vid_block = cache.lookup(h_vid)

        assert img_block is not None
        assert aud_block is not None
        assert vid_block is not None
        assert img_block.modality == "image"
        assert aud_block.modality == "audio"
        assert vid_block.modality == "video"

    def test_store_evicts_on_max_entries(self):
        """Store triggers LFU eviction when max_entries exceeded."""
        cache = VisualKVCache(max_entries=3)
        embedding = np.random.randn(100, 512).astype(np.float32)

        hashes = [cache.compute_content_hash(f"entry_{i}".encode()) for i in range(5)]

        for h in hashes[:3]:
            cache.store(h, "image", embedding)

        assert len(cache._cache) == 3

        # Add 4th entry - should evict one
        cache.store(hashes[3], "image", embedding)
        assert len(cache._cache) == 3

        # No entry has been looked up, so all are tied on frequency and the
        # oldest one (entry_0) is expected to be the victim.
        assert cache.lookup(hashes[0]) is None


class TestVisualKVCacheEviction:
    """LRU/LFU eviction logic."""

    def test_vram_eviction_respects_max(self):
        """Eviction ensures total vram stays within limit."""
        num_entries = 20
        # max_entries is deliberately larger than the number of stores so
        # that any eviction observed below can only come from the VRAM budget.
        cache = VisualKVCache(
            max_entries=num_entries * 2,
            max_vram_bytes=1000,  # 1KB limit
        )

        # A 10x10 float32 array is 400 raw bytes (embedding.nbytes), so 20 of
        # them far exceed the 1KB budget. NOTE(review): assumes the cache's
        # per-entry size estimate is on the order of nbytes.
        embedding = np.random.randn(10, 10).astype(np.float32)

        stored_hashes = []
        for i in range(num_entries):
            h = cache.compute_content_hash(f"entry_{i}".encode())
            cache.store(h, "image", embedding)
            stored_hashes.append(h)

        # Some entries should remain, but the VRAM limit must have evicted some.
        remaining = sum(1 for h in stored_hashes if cache.lookup(h) is not None)
        assert 0 < remaining < num_entries


class TestQueueingControllerIntegration:
    """INV-11: With queueing_controller, visual eviction respects minimum_stable_blocks."""

    def test_eviction_skipped_when_at_min_stable_blocks(self):
        """Eviction does not occur when cache size <= minimum_stable_blocks."""

        class MockQueueingController(QueueingController):
            def __init__(self):
                self.minimum_stable_blocks = 5

            def get_minimum_stable_blocks(self) -> int:
                return self.minimum_stable_blocks

        cache = VisualKVCache(
            max_entries=2,
            queueing_controller=MockQueueingController(),
        )
        embedding = np.random.randn(100, 512).astype(np.float32)
        hashes = [cache.compute_content_hash(f"entry{i}".encode()) for i in range(3)]

        # The 3rd store overflows max_entries=2, but the cache never holds more
        # than minimum_stable_blocks=5 entries, so eviction must be skipped.
        for h in hashes:
            cache.store(h, "image", embedding)

        assert len(cache._cache) == 3
        assert all(cache.lookup(h) is not None for h in hashes)

    def test_eviction_proceeds_above_min_stable_blocks(self):
        """Eviction proceeds normally when above minimum_stable_blocks."""

        class MockQueueingController(QueueingController):
            def get_minimum_stable_blocks(self) -> int:
                return 1

        cache = VisualKVCache(
            max_entries=3,
            queueing_controller=MockQueueingController(),
        )
        embedding = np.random.randn(100, 512).astype(np.float32)

        hashes = [cache.compute_content_hash(f"entry_{i}".encode()) for i in range(5)]
        for h in hashes:
            cache.store(h, "image", embedding)

        # With a floor of 1, the max_entries cap applies as usual.
        assert len(cache._cache) == 3
        # The most recently stored entry is never the eviction victim.
        assert cache.lookup(hashes[-1]) is not None


class TestDPModeRecommendation:
    """Batch-level DP hint based on AMD ROCm benchmarks."""

    def test_dp_mode_recommended_batch_gte_2(self):
        """DP mode recommended when batch_image_count >= 2."""
        cache = VisualKVCache()
        assert cache.get_dp_mode_recommendation(batch_image_count=2) is True
        assert cache.get_dp_mode_recommendation(batch_image_count=5) is True
        assert cache.get_dp_mode_recommendation(batch_image_count=9) is True

    def test_dp_mode_recommended_high_resolution(self):
        """DP mode recommended when resolution >= (512, 512)."""
        cache = VisualKVCache()
        assert cache.get_dp_mode_recommendation(
            batch_image_count=1, image_resolution=(512, 512)
        ) is True
        assert cache.get_dp_mode_recommendation(
            batch_image_count=1, image_resolution=(1024, 1024)
        ) is True

    def test_dp_mode_recommended_deep_encoder(self):
        """DP mode recommended when encoder_depth >= 45 (InternVL)."""
        cache = VisualKVCache()
        assert cache.get_dp_mode_recommendation(
            batch_image_count=1, encoder_depth=45
        ) is True
        assert cache.get_dp_mode_recommendation(
            batch_image_count=1, encoder_depth=78
        ) is True

    def test_dp_mode_not_recommended_small_batch_low_res(self):
        """DP mode not recommended for small batches with low resolution."""
        cache = VisualKVCache()
        assert cache.get_dp_mode_recommendation(
            batch_image_count=1, image_resolution=(256, 256), encoder_depth=27
        ) is False

    def test_dp_mode_not_recommended_large_batch_low_res(self):
        """DP mode not recommended when batch >= 10 AND resolution <= (256, 256)."""
        cache = VisualKVCache()
        assert cache.get_dp_mode_recommendation(
            batch_image_count=10, image_resolution=(256, 256)
        ) is False
        assert cache.get_dp_mode_recommendation(
            batch_image_count=15, image_resolution=(128, 128)
        ) is False

    def test_dp_mode_recommendation_increments_counter(self):
        """Each call to get_dp_mode_recommendation increments the internal counter."""
        cache = VisualKVCache()

        cache.get_dp_mode_recommendation(batch_image_count=5)
        assert cache.get_cache_stats()["dp_mode_recommendations"] == 1

        cache.get_dp_mode_recommendation(batch_image_count=5)
        assert cache.get_cache_stats()["dp_mode_recommendations"] == 2


class TestCacheStats:
    """Prometheus metrics via get_cache_stats()."""

    def test_stats_keys_complete(self):
        """All 6 Prometheus metric keys present."""
        cache = VisualKVCache()
        stats = cache.get_cache_stats()

        expected_keys = {
            "visual_cache_hits",
            "visual_cache_misses",
            "visual_cache_hit_rate",
            "visual_vram_saved_bytes",
            "visual_cache_entries",
            "dp_mode_recommendations",
        }
        assert set(stats.keys()) == expected_keys

    def test_hit_rate_calculation(self):
        """Hit rate computed correctly."""
        cache = VisualKVCache()
        embedding = np.random.randn(100, 512).astype(np.float32)

        # Miss
        cache.lookup("nonexistent")

        # Hit
        h = cache.compute_content_hash(b"test")
        cache.store(h, "image", embedding)
        cache.lookup(h)

        stats = cache.get_cache_stats()
        assert stats["visual_cache_hits"] == 1
        assert stats["visual_cache_misses"] == 1
        # Hit rate is a float ratio; compare with a tolerance.
        assert stats["visual_cache_hit_rate"] == pytest.approx(0.5)

    def test_vram_saved_accumulates_on_hits(self):
        """VRAM saved bytes accumulates across hits."""
        cache = VisualKVCache()
        embedding = np.random.randn(100, 512).astype(np.float32)
        h = cache.compute_content_hash(b"test")
        cache.store(h, "image", embedding)

        cache.lookup(h)
        saved_after_one = cache.get_cache_stats()["visual_vram_saved_bytes"]

        cache.lookup(h)
        cache.lookup(h)
        saved_after_three = cache.get_cache_stats()["visual_vram_saved_bytes"]

        # Each hit adds to the running total rather than overwriting it.
        assert saved_after_one > 0
        assert saved_after_three > saved_after_one

    def test_entries_count(self):
        """visual_cache_entries reflects current cache size."""
        cache = VisualKVCache(max_entries=10)
        embedding = np.random.randn(100, 512).astype(np.float32)

        for i in range(5):
            cache.store(cache.compute_content_hash(f"entry_{i}".encode()), "image", embedding)

        stats = cache.get_cache_stats()
        assert stats["visual_cache_entries"] == 5


class TestClear:
    """Cache clear functionality."""

    def test_clear_resets_all_state(self):
        """Clear removes all entries and resets metrics."""
        cache = VisualKVCache()
        embedding = np.random.randn(100, 512).astype(np.float32)
        h = cache.compute_content_hash(b"test")

        cache.store(h, "image", embedding)
        cache.lookup(h)
        cache.get_dp_mode_recommendation(batch_image_count=5)

        cache.clear()

        stats = cache.get_cache_stats()
        assert stats["visual_cache_entries"] == 0
        assert stats["visual_cache_hits"] == 0
        assert stats["visual_cache_misses"] == 0
        assert stats["visual_vram_saved_bytes"] == 0
        assert stats["dp_mode_recommendations"] == 0
        assert cache.lookup(h) is None