"""Cached Qwen3-Embedding encoder for free-text queries.

The upstream `optcg_cards.embed.encode_query` (embed.py:172-202) loads
the model fresh on every call. That's fine for the CLI but unusable
inside a Gradio Space (~2-5 s of model construction per query on CPU).

This module wraps it with a module-level cache: the model is loaded
once on first use, warmed up with a single encode pass, and reused for
all subsequent queries. The task instruction and matryoshka truncation
still come from the `EmbedProvenance`, so the encoded query stays
comparable to the published vectors.

The deviation from CLAUDE.md's "lazy-import heavy deps" rule is
intentional: eager loading at module import lets HF Spaces' "Building"
indicator absorb the ~30-60 s cold-start, sparing the first user.
"""

from __future__ import annotations

import logging
from typing import Any

import numpy as np
from optcg_cards.provenance import EmbedProvenance

logger = logging.getLogger(__name__)

# Keyed by model_id; one loaded SentenceTransformer per process, reused for
# every query (see module docstring).
_model_cache: dict[str, Any] = {}


def get_encoder(embed_prov: EmbedProvenance):
    """Load + cache a SentenceTransformer for the given model id."""
    key = embed_prov.model_id
    cached = _model_cache.get(key)
    if cached is not None:
        return cached

    from sentence_transformers import SentenceTransformer

    logger.info("Loading encoder %s (first use; subsequent calls reuse)", key)
    model = SentenceTransformer(key)
    # Warmup: first encode pass is always slower (graph build + kernel
    # selection). Pay it now so the first user query doesn't.
    _ = model.encode(["warmup"], normalize_embeddings=True, show_progress_bar=False)
    _model_cache[key] = model
    return model


def encode_query_via_optcg(query: str, embed_prov: EmbedProvenance) -> np.ndarray:
    """Encode a free-text query using the cached model.

    Mirrors the logic of optcg_cards.embed.encode_query (embed.py:172-202)
    so query vectors are comparable to the published corpus, but reuses
    the cached model instead of re-instantiating SentenceTransformer.
    """
    model = get_encoder(embed_prov)
    # Apply the same task-instruction template used when the corpus was
    # embedded, so the query vector stays comparable to the published vectors.
    text = embed_prov.task_instruction.format(card_document=query)
    vector = model.encode(
        [text],
        normalize_embeddings=True,
        show_progress_bar=False,
        convert_to_numpy=True,
    )[0]

    # Matryoshka truncation: keep only the leading dims, then re-normalize so
    # the vector stays unit-length and cosine similarity remains valid.
    if embed_prov.matryoshka_dim is not None:
        vector = vector[: embed_prov.matryoshka_dim]
        norm = float(np.linalg.norm(vector))
        if norm > 0:
            vector = vector / norm

    return np.asarray(vector, dtype=np.float32)
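

# ---------------------------------------------------------------------------
# Illustrative usage sketch, not part of the upstream optcg_cards API: shows
# how a query vector from encode_query_via_optcg can be scored against a
# matrix of published card vectors. Both sides are unit-normalized above, so
# a plain dot product equals cosine similarity. How the corpus matrix is
# loaded, and the names below, are assumptions for the example only.
def _rank_by_similarity(
    query: str,
    embed_prov: EmbedProvenance,
    corpus_vectors: np.ndarray,
    top_k: int = 10,
) -> np.ndarray:
    """Return indices of the ``top_k`` corpus rows most similar to ``query``."""
    query_vec = encode_query_via_optcg(query, embed_prov)
    scores = corpus_vectors @ query_vec  # cosine similarity on unit vectors
    return np.argsort(-scores)[:top_k]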