# Initial commit: optcg-explorer Gradio Space (3ab07bd)
"""Cosine top-k against a pre-stacked embedding matrix.
`top_k_with_matrix` mirrors `optcg_cards.search.top_k` (search.py:52-109)
but skips the per-call `_vectors_to_array` restack. The Space stacks
once at startup; doing it again on every search burns ~50 ms for no
reason.
The cards list passed here has the `embedding` key removed (it lives in
the matrix). Lookups go through the corpus index.
"""
from __future__ import annotations
from typing import Any
import numpy as np
from optcg_cards.search import SearchHit
def top_k_with_matrix(
    query_vector: np.ndarray | list[float],
    matrix: np.ndarray,
    cards: list[dict[str, Any]],
    k: int = 10,
    exclude_idx: int | None = None,
) -> list[SearchHit]:
    """Return the ``k`` best-scoring cards for ``query_vector``.

    Scores are the raw dot products ``matrix @ query``. They only equal
    cosine similarity if both the matrix rows and the query are already
    L2-normalised -- presumably done upstream; confirm against the caller.

    Args:
        query_vector: Query embedding; converted to a float32 array.
        matrix: Pre-stacked embeddings, one row per card. Row ``i`` is
            looked up as ``cards[i]``, so the two must be index-aligned.
        cards: Card dicts. Only ``id``, ``name``, ``card_type``,
            ``colors``, ``set_code`` and ``rarity`` are read, all via
            ``dict.get`` so missing keys are tolerated.
        k: Maximum number of hits to return. Non-positive values yield
            an empty list.
        exclude_idx: Optional row index to skip (e.g. the query card
            itself). The skipped row does not consume a rank.

    Returns:
        Up to ``k`` ``SearchHit`` objects ordered by descending score,
        with ``rank`` starting at 1.

    Raises:
        ValueError: If the query's last dimension differs from the
            matrix's column count.
    """
    if matrix.shape[0] == 0 or not cards:
        return []

    query = np.asarray(query_vector, dtype=np.float32)
    if query.shape[-1] != matrix.shape[1]:
        raise ValueError(
            f"Dimension mismatch: query has {query.shape[-1]}, "
            f"matrix has {matrix.shape[1]}"
        )

    # Without this guard the loop below would append one hit before it
    # ever compared the count against k, so k=0 returned a result.
    if k <= 0:
        return []

    scores = matrix @ query
    # Negate so that ascending argsort yields descending scores.
    order = np.argsort(-scores)

    hits: list[SearchHit] = []
    for idx in order:
        idx_i = int(idx)
        if exclude_idx is not None and idx_i == exclude_idx:
            continue
        card = cards[idx_i]
        hits.append(
            SearchHit(
                # Rank counts emitted hits only, so it stays contiguous
                # even when a row was skipped.
                rank=len(hits) + 1,
                card_id=str(card.get("id", "")),
                name=str(card.get("name", "")),
                score=float(scores[idx_i]),
                metadata={
                    "card_type": card.get("card_type"),
                    "colors": card.get("colors"),
                    "set_code": card.get("set_code"),
                    "rarity": card.get("rarity"),
                },
            )
        )
        if len(hits) >= k:
            break
    return hits