File size: 2,744 Bytes
16eaadc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
"""Load the published OPTCG embeddings corpus from HF Hub.

Pulls `cards_with_embeddings.parquet` and `provenance.json` from the
configured dataset repo, applies the same numpy-array-to-list coercion
that the upstream CLI uses, and stacks the embedding column into a
single float32 matrix that downstream code reuses without restacking.
"""

from __future__ import annotations

import logging
from pathlib import Path
from typing import Any

import numpy as np
import pandas as pd
from huggingface_hub import hf_hub_download
from optcg_cards.provenance import EmbedProvenance, read_provenance

logger = logging.getLogger(__name__)

REPO_ID = "t22000t/optcg-en-card-embeddings"
PARQUET_FILE = "cards_with_embeddings.parquet"
PROVENANCE_FILE = "provenance.json"


def load_corpus(
    token: str | None,
) -> tuple[list[dict[str, Any]], np.ndarray, EmbedProvenance, dict[str, int]]:
    """Fetch the published embeddings corpus and return
    `(cards, matrix, embed_provenance, id_to_idx)`.

    The `embedding` column is stacked into the float32 `matrix` and then
    removed from each card dict, so `cards` holds only metadata. Every
    list-typed column is coerced to plain Python lists by the parquet
    reader. The token is forwarded to `hf_hub_download` only; it never
    appears in log output.

    Raises:
        RuntimeError: if the parquet is empty or the provenance file has
            no `embed` block.
    """
    logger.info(
        "Loading corpus from %s (authenticated=%s)",
        REPO_ID,
        "yes" if token else "no",
    )

    def _fetch(filename: str) -> str:
        # One artifact from the dataset repo; the token stays out of logs.
        return hf_hub_download(
            repo_id=REPO_ID,
            filename=filename,
            repo_type="dataset",
            token=token,
        )

    parquet_path = _fetch(PARQUET_FILE)
    prov_path = _fetch(PROVENANCE_FILE)

    cards = _read_parquet_records(Path(parquet_path))
    if not cards:
        raise RuntimeError("Embeddings parquet returned 0 rows")

    # Stack once here; downstream code reuses this matrix without restacking.
    vectors = [np.asarray(record["embedding"], dtype=np.float32) for record in cards]
    matrix = np.stack(vectors, axis=0)

    # The raw vectors now live only in `matrix`; drop them from the cards.
    for record in cards:
        record.pop("embedding", None)

    id_to_idx = {record["id"]: idx for idx, record in enumerate(cards)}

    _, embed_prov = read_provenance(Path(prov_path))
    if embed_prov is None:
        raise RuntimeError("Embeddings provenance is missing the `embed` block")

    return cards, matrix, embed_prov, id_to_idx


def _read_parquet_records(path: Path) -> list[dict[str, Any]]:
    # Mirrors the coercion loop in optcg_cards.cli._read_parquet
    # (cli.py:429-443). Pandas materializes list-typed parquet columns
    # as ndarrays; downstream code expects plain Python lists.
    df = pd.read_parquet(str(path))
    records = df.to_dict(orient="records")
    for record in records:
        for key, value in record.items():
            if isinstance(value, np.ndarray):
                record[key] = value.tolist()
    return records