"""TDD for spaceutil.cards.top_k_with_matrix. The local fast-path must rank identically to the upstream `top_k` from optcg_cards.search on the same inputs - that's the load-bearing invariant. The only difference is that top_k_with_matrix takes a pre-stacked matrix to avoid the per-call _vectors_to_array restack. """ from __future__ import annotations import numpy as np from optcg_cards.search import SearchHit, top_k def _matrix_from_cards(cards: list[dict]) -> np.ndarray: return np.stack( [np.asarray(c["embedding"], dtype=np.float32) for c in cards], axis=0, ) def test_top_k_with_matrix_matches_upstream(synthetic_cards): from spaceutil.cards import top_k_with_matrix matrix = _matrix_from_cards(synthetic_cards) cards_no_emb = [{k: v for k, v in c.items() if k != "embedding"} for c in synthetic_cards] query = matrix[3] upstream = top_k(query, synthetic_cards, k=5) local = top_k_with_matrix(query, matrix, cards_no_emb, k=5) assert [h.card_id for h in local] == [h.card_id for h in upstream] np.testing.assert_allclose( [h.score for h in local], [h.score for h in upstream], atol=1e-6 ) def test_top_k_with_matrix_returns_searchhit(synthetic_cards): from spaceutil.cards import top_k_with_matrix matrix = _matrix_from_cards(synthetic_cards) cards_no_emb = [{k: v for k, v in c.items() if k != "embedding"} for c in synthetic_cards] hits = top_k_with_matrix(matrix[0], matrix, cards_no_emb, k=3) assert all(isinstance(h, SearchHit) for h in hits) assert len(hits) == 3 assert all(h.rank == i + 1 for i, h in enumerate(hits)) def test_exclude_idx_removes_query_card(synthetic_cards): from spaceutil.cards import top_k_with_matrix matrix = _matrix_from_cards(synthetic_cards) cards_no_emb = [{k: v for k, v in c.items() if k != "embedding"} for c in synthetic_cards] hits = top_k_with_matrix(matrix[7], matrix, cards_no_emb, k=5, exclude_idx=7) assert all(h.card_id != cards_no_emb[7]["id"] for h in hits) def test_k_larger_than_corpus_clamped(synthetic_cards): from spaceutil.cards import top_k_with_matrix matrix = _matrix_from_cards(synthetic_cards) cards_no_emb = [{k: v for k, v in c.items() if k != "embedding"} for c in synthetic_cards] hits = top_k_with_matrix(matrix[0], matrix, cards_no_emb, k=100) assert len(hits) == 20 def test_empty_corpus_returns_empty(synthetic_cards): from spaceutil.cards import top_k_with_matrix matrix = np.empty((0, 1024), dtype=np.float32) query = np.zeros(1024, dtype=np.float32) hits = top_k_with_matrix(query, matrix, [], k=10) assert hits == [] def test_scores_ordered_descending(synthetic_cards): from spaceutil.cards import top_k_with_matrix matrix = _matrix_from_cards(synthetic_cards) cards_no_emb = [{k: v for k, v in c.items() if k != "embedding"} for c in synthetic_cards] hits = top_k_with_matrix(matrix[0], matrix, cards_no_emb, k=10) scores = [h.score for h in hits] assert scores == sorted(scores, reverse=True)