File size: 3,587 Bytes
21626e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
"""Cross-encoder reranker over ONNX Runtime (local, key-free)."""
import json

import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from tokenizers import Tokenizer

# Hugging Face repository holding the ONNX export, and the model file used from
# it (int8: ~3x faster on CPU, negligible quality loss).
RERANK_REPO: str = "Xenova/bge-reranker-base"
RERANK_ONNX: str = "onnx/model_quantized.onnx"
# Token budget passed to the tokenizer's truncation for each (query, section) pair.
MAX_TOKENS: int = 512
# Characters kept from each section before tokenizing (512 tokens ~ 2-3k chars),
# so very long sections are clipped cheaply before the tokenizer sees them.
_MAX_DOC_CHARS: int = 3000


class Reranker:
    """Cross-encoder that scores (query, section) pairs for relevance.

    Loads a BGE reranker as ONNX and runs it on CPU -- no API key; the model is
    downloaded once and cached. A cross-encoder reads the query and section
    jointly, so it judges true relevance far better than the BM25/embedding
    similarities used to build the candidate pool.
    """

    def __init__(self):
        """Fetch (or reuse the cached) model, tokenizer and config, then load them."""
        model_path = hf_hub_download(RERANK_REPO, RERANK_ONNX)
        tok_path = hf_hub_download(RERANK_REPO, "tokenizer.json")
        cfg_path = hf_hub_download(RERANK_REPO, "config.json")
        with open(cfg_path, encoding="utf-8") as fh:
            cfg = json.load(fh)
        # A config may contain "pad_token_id": null; None is not a valid int64
        # fill value for np.full, so treat a null exactly like a missing key.
        pad_id = cfg.get("pad_token_id")
        self.pad_id = 0 if pad_id is None else int(pad_id)

        self.session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
        # Recorded so scoring can omit token_type_ids when the exported graph
        # does not declare that input.
        self.input_names = {i.name for i in self.session.get_inputs()}

        self.tokenizer = Tokenizer.from_file(tok_path)
        # only_second keeps the query intact and truncates the (longer) section.
        # NOTE(review): with only_second the tokenizers library raises instead of
        # truncating when the query alone exceeds MAX_TOKENS -- confirm callers
        # never send queries that long.
        self.tokenizer.enable_truncation(max_length=MAX_TOKENS, strategy="only_second")

    def score(self, query, documents, *, batch_size=None):
        """Return a relevance logit for each document paired with the query.

        Higher means more relevant. The returned list is aligned with `documents`.

        Args:
            query: The search query. It is paired with every document and is
                not truncated (only the document side is).
            documents: Iterable of section texts. Each is clipped to
                `_MAX_DOC_CHARS` characters, then token-truncated so the pair
                fits in `MAX_TOKENS`.
            batch_size: Maximum number of pairs per model run. The default
                (None) scores all documents in a single run; pass a positive
                int to bound peak memory on large candidate pools.

        Returns:
            A list of floats, one per document, in input order.

        Raises:
            ValueError: if `batch_size` is given and is less than 1, or if the
                model does not emit exactly one logit per pair.
        """
        if batch_size is not None and batch_size < 1:
            raise ValueError(f"batch_size must be a positive int, got {batch_size!r}")
        # Materialise once so generators work and batches can be sliced.
        documents = list(documents)
        if not documents:
            return []
        step = batch_size or len(documents)
        scores = []
        for start in range(0, len(documents), step):
            scores.extend(self._score_batch(query, documents[start:start + step]))
        return scores

    def _score_batch(self, query, documents):
        """Tokenize, pad and run one non-empty batch; return its logits as floats."""
        encs = self.tokenizer.encode_batch(
            [(query, doc[:_MAX_DOC_CHARS]) for doc in documents])
        # Right-pad every row to the longest encoding; padded positions keep an
        # attention-mask value of 0.
        width = max(len(e.ids) for e in encs)
        input_ids = np.full((len(encs), width), self.pad_id, dtype=np.int64)
        attention = np.zeros((len(encs), width), dtype=np.int64)
        type_ids = np.zeros((len(encs), width), dtype=np.int64)
        for row, enc in enumerate(encs):
            n = len(enc.ids)
            input_ids[row, :n] = enc.ids
            attention[row, :n] = enc.attention_mask
            type_ids[row, :n] = enc.type_ids

        feed = {"input_ids": input_ids, "attention_mask": attention}
        if "token_type_ids" in self.input_names:
            feed["token_type_ids"] = type_ids
        logits = np.asarray(self.session.run(None, feed)[0], dtype=np.float32)
        # Flattening an output with several logits per pair would silently break
        # the one-score-per-document alignment, so insist on exactly one.
        if logits.size != len(encs):
            raise ValueError(
                f"expected one logit per (query, document) pair, "
                f"got model output of shape {logits.shape}")
        return logits.reshape(-1).tolist()


def main():
    """Smoke-test the reranker on three sample sections and print them best-first.

    Only the scoring call is timed; model loading happens before the timer starts.
    """
    import time

    query = "powers of arrest without warrant"
    docs = [
        "Arrest without warrant. A peace officer may arrest without warrant a "
        "person who has committed a criminal offence.",
        "Definitions. In this Act, fish means any fish and includes shellfish, "
        "crustaceans and marine animals.",
        "Importation. It is prohibited to import cannabis except as authorized "
        "under this Act.",
    ]

    print(f"Loading {RERANK_REPO} ({RERANK_ONNX}) ...")
    model = Reranker()

    t0 = time.perf_counter()
    relevance = model.score(query, docs)
    elapsed = 1000 * (time.perf_counter() - t0)

    print(f"\nQuery: {query!r}")
    # Highest logit first; sorted() is stable, so ties keep their input order.
    ranked = sorted(zip(docs, relevance), key=lambda pair: pair[1], reverse=True)
    for doc, score in ranked:
        print(f"  {score:8.3f}  {doc[:62]}")
    print(f"\n{elapsed:.0f} ms for {len(docs)} documents")


if __name__ == "__main__":
    main()