"""Tests for AnchorPool KV offset estimation."""

import numpy as np
import pytest

from apohara_context_forge.kv_offset.anchor_pool import AnchorPool

# =============================================================================
# Fixtures
# =============================================================================


@pytest.fixture
def sample_offset() -> np.ndarray:
    """Return a sample KV offset vector of shape (128,)."""
    return np.random.randn(128).astype(np.float32)


@pytest.fixture
def sample_kv_keys() -> np.ndarray:
    """Return sample KV keys with shape (seq_len=4, head_dim=128)."""
    np.random.seed(42)
    return np.random.randn(4, 128).astype(np.float32)


@pytest.fixture
def pool() -> AnchorPool:
    """Return a fresh AnchorPool instance."""
    return AnchorPool(max_size=20)


# =============================================================================
# predict_shareable() Tests
# =============================================================================


@pytest.mark.asyncio
async def test_predict_shareable_returns_true_for_high_similarity(pool, sample_offset):
    """Returns True when token sequence has high similarity with existing anchors."""
    token_ids = [100, 200, 300, 400]
    agent_a = "agent-a"
    agent_b = "agent-b"

    await pool.update_pool(token_ids, agent_a, sample_offset)

    # Agent B has no offsets yet, but similarity should still be computed
    shareable = await pool.predict_shareable(token_ids, agent_b)
    assert isinstance(shareable, bool)


@pytest.mark.asyncio
async def test_predict_shareable_returns_false_when_pool_empty(pool):
    """Returns False when the anchor pool is empty."""
    token_ids = [100, 200, 300]
    target_agent = "agent-xyz"

    result = await pool.predict_shareable(token_ids, target_agent)
    assert result is False


@pytest.mark.asyncio
async def test_predict_shareable_returns_false_when_target_not_in_offsets(pool, sample_offset):
    """Returns False when target_agent_id is not present in any anchor's offsets."""
    token_ids = [100, 200, 300, 400]
    agent_a = "agent-a"
    agent_b = "agent-b"

    # Add anchor for agent-a only
    await pool.update_pool(token_ids, agent_a, sample_offset)

    # agent-b is not in any anchor's offsets
    shareable = await pool.predict_shareable(token_ids, agent_b)
    assert shareable is False


# =============================================================================
# approximate_offset() Tests
# =============================================================================


@pytest.mark.asyncio
async def test_approximate_offset_returns_ndarray_when_candidates_exist(pool, sample_offset):
    """Returns np.ndarray when candidates exist for target_agent_id."""
    token_ids = [100, 200, 300, 400]
    agent_a = "agent-a"

    await pool.update_pool(token_ids, agent_a, sample_offset)

    result = await pool.approximate_offset(token_ids, agent_a)
    assert result is not None
    assert isinstance(result.placeholder_offset, np.ndarray)
    assert result.placeholder_offset.shape == (128,)


@pytest.mark.asyncio
async def test_approximate_offset_returns_none_when_pool_empty(pool):
    """Returns None when the anchor pool is empty."""
    token_ids = [100, 200, 300]
    target_agent = "agent-xyz"

    result = await pool.approximate_offset(token_ids, target_agent)
    assert result is None


@pytest.mark.asyncio
async def test_approximate_offset_weighted_interpolation_between_min_max(pool):
    """Weighted interpolation should produce values between min and max offsets."""
    token_ids_base = [100, 200, 300, 400]
    agent_a = "agent-a"

    offset_low = np.full(128, 0.0, dtype=np.float32)
    offset_high = np.full(128, 1.0, dtype=np.float32)

    # Add two anchors with distinct offsets
    await pool.update_pool([100, 200, 300, 400], agent_a, offset_low)
    await pool.update_pool([101, 201, 301, 401], agent_a, offset_high)

    # Query with same base token IDs - should interpolate
    result = await pool.approximate_offset(token_ids_base, agent_a)
    assert result is not None
    assert np.all(result.placeholder_offset >= offset_low), "Result should be >= min offset"
    assert np.all(result.placeholder_offset <= offset_high), "Result should be <= max offset"


# =============================================================================
# RoPE De-rotation Tests
# =============================================================================


@pytest.mark.asyncio
async def test_rope_derotation_differs_for_same_key_at_different_positions(pool, sample_kv_keys):
    """apply_rope_derotation() should produce different output for same key at different positions."""
    key_at_pos0 = sample_kv_keys[0:1]  # shape (1, 128)
    key_at_pos2 = sample_kv_keys[2:3]  # shape (1, 128)

    derotated_0 = await pool.apply_rope_derotation(key_at_pos0, np.array([0]))
    derotated_2 = await pool.apply_rope_derotation(key_at_pos2, np.array([2]))

    assert not np.allclose(derotated_0, derotated_2), \
        "De-rotated keys at different positions should differ"


@pytest.mark.asyncio
async def test_rope_derotation_produces_different_keys_for_off_position_tokens(pool):
    """
    De-rotated keys at off-position indices should be more similar (lower cosine
    distance) than raw keys, because de-rotation aligns them to a common
    reference frame.

    Uses kv_keys shape (seq_len=4, head_dim=128) and positions [0, 1, 2, 3].
    """
    np.random.seed(123)
    kv_keys = np.random.randn(4, 128).astype(np.float32)
    positions = np.array([0, 1, 2, 3])

    derotated = await pool.apply_rope_derotation(kv_keys, positions)

    # Compare position 0 vs position 2 (off-position)
    raw_key_0 = kv_keys[0]
    raw_key_2 = kv_keys[2]

    # Cosine similarity for raw keys
    raw_cos_sim = np.dot(raw_key_0, raw_key_2) / (
        np.linalg.norm(raw_key_0) * np.linalg.norm(raw_key_2)
    )

    # Cosine similarity for de-rotated keys
    derot_key_0 = derotated[0]
    derot_key_2 = derotated[2]
    derot_cos_sim = np.dot(derot_key_0, derot_key_2) / (
        np.linalg.norm(derot_key_0) * np.linalg.norm(derot_key_2)
    )

    # De-rotated keys at different positions should have higher cosine similarity
    # because de-rotation removes the position-dependent RoPE rotation
    assert derot_cos_sim > raw_cos_sim, \
        f"De-rotated cosine similarity ({derot_cos_sim:.4f}) should be > raw ({raw_cos_sim:.4f})"


@pytest.mark.asyncio
async def test_rope_derotation_shape_preserved(pool, sample_kv_keys):
    """De-rotation should preserve the shape of kv_keys."""
    positions = np.array([0, 1, 2, 3])

    derotated = await pool.apply_rope_derotation(sample_kv_keys, positions)
    assert derotated.shape == sample_kv_keys.shape


# =============================================================================
# Pool Pruning Tests
# =============================================================================


@pytest.mark.asyncio
async def test_pool_pruning_at_max_size_boundary():
    """Pool size should be <= max_size after adding more anchors than max_size."""
    pool = AnchorPool(max_size=5)

    # Add 8 anchors (more than max_size=5)
    for i in range(8):
        token_ids = [100 + i, 200 + i, 300 + i, 400 + i]
        agent_id = f"agent-{i % 3}"  # Rotate through 3 agents
        offset = np.random.randn(128).astype(np.float32)
        await pool.update_pool(token_ids, agent_id, offset)

    stats = await pool.get_stats()
    assert stats["total_anchors"] <= 5, \
        f"Pool size ({stats['total_anchors']}) should be <= max_size (5)"


@pytest.mark.asyncio
async def test_pool_pruning_evicts_least_frequently_used():
    """Least-frequently-used anchors should be evicted first during pruning."""
    pool = AnchorPool(max_size=5)

    # Add 5 anchors for agent-a
    token_ids_list = [
        [100, 200, 300],
        [101, 201, 301],
        [102, 202, 302],
        [103, 203, 303],
        [104, 204, 304],
    ]
    # NOTE: the original iterated with enumerate() but never used the index;
    # iterate the list directly instead.
    for token_ids in token_ids_list:
        offset = np.random.randn(128).astype(np.float32)
        await pool.update_pool(token_ids, "agent-a", offset)

    # Access first 3 anchors multiple times to increase their access_count
    for _ in range(3):
        await pool.predict_shareable(token_ids_list[0], "agent-b")
        await pool.predict_shareable(token_ids_list[1], "agent-b")
        await pool.predict_shareable(token_ids_list[2], "agent-b")

    # Add 3 more anchors to trigger pruning
    for i in range(3):
        token_ids = [110 + i, 210 + i, 310 + i]
        offset = np.random.randn(128).astype(np.float32)
        await pool.update_pool(token_ids, "agent-a", offset)

    # After pruning, the least-frequently-used (and oldest) anchors should be gone
    stats = await pool.get_stats()
    assert stats["total_anchors"] <= 5

    # The first two anchors (with highest access_count due to 3x access)
    # should still exist, while others may have been evicted
    # We can't deterministically verify which specific ones remain without
    # inspecting internals, but we verify the pool respects max_size


# =============================================================================
# get_stats() Tests
# =============================================================================


@pytest.mark.asyncio
async def test_get_stats_returns_correct_structure(pool, sample_offset):
    """get_stats() should return dict with expected keys and types."""
    token_ids = [100, 200, 300, 400]
    agent_id = "agent-test"

    await pool.update_pool(token_ids, agent_id, sample_offset)

    stats = await pool.get_stats()
    assert "total_anchors" in stats
    assert "total_agent_offsets" in stats
    assert "agents_tracked" in stats
    assert "max_size" in stats
    assert isinstance(stats["total_anchors"], int)
    assert isinstance(stats["total_agent_offsets"], int)
    assert isinstance(stats["agents_tracked"], int)
    assert isinstance(stats["max_size"], int)
    assert stats["max_size"] == 20


@pytest.mark.asyncio
async def test_get_stats_empty_pool():
    """get_stats() should return zeros for an empty pool."""
    pool = AnchorPool(max_size=10)

    stats = await pool.get_stats()
    assert stats["total_anchors"] == 0
    assert stats["total_agent_offsets"] == 0
    assert stats["agents_tracked"] == 0
    assert stats["max_size"] == 10