"""Cross-encoder reranker over ONNX Runtime (local, key-free)."""
import json
import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from tokenizers import Tokenizer
# Hugging Face Hub repository the model, tokenizer and config files are fetched from.
RERANK_REPO = "Xenova/bge-reranker-base"
# Path of the ONNX weights file inside RERANK_REPO.
RERANK_ONNX = "onnx/model_quantized.onnx"  # int8: ~3x faster on CPU, negligible quality loss
# Maximum tokenized length of a (query, section) pair; used as the tokenizer's
# truncation limit.
MAX_TOKENS = 512
# Character cap applied to each document string before it is tokenized.
_MAX_DOC_CHARS = 3000  # cap doc text before tokenizing (512 tokens ~ 2-3k chars)
class Reranker:
    """Cross-encoder that scores (query, section) pairs for relevance.

    Loads a BGE reranker as ONNX and runs it on CPU -- no API key; the model is
    downloaded once and cached. A cross-encoder reads the query and section
    jointly, so it judges true relevance far better than the BM25/embedding
    similarities used to build the candidate pool.

    Attributes:
        pad_id: Token id used to right-pad the shorter rows of a batch.
        session: ONNX Runtime inference session using the CPU provider.
        input_names: Set of input names declared by the ONNX graph.
        tokenizer: Tokenizer configured to truncate only the section text.
    """

    def __init__(self):
        """Fetch the model files and build the tokenizer and ONNX session."""
        model_path = hf_hub_download(RERANK_REPO, RERANK_ONNX)
        tok_path = hf_hub_download(RERANK_REPO, "tokenizer.json")
        cfg_path = hf_hub_download(RERANK_REPO, "config.json")
        with open(cfg_path, encoding="utf-8") as fh:
            pad_id = json.load(fh).get("pad_token_id")
        # A config with an explicit `"pad_token_id": null` yields None, which
        # cannot be used as an int64 fill value later; fall back to 0 then too.
        self.pad_id = 0 if pad_id is None else int(pad_id)
        self.session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
        # Remembered so score() only feeds inputs the graph actually declares.
        self.input_names = {i.name for i in self.session.get_inputs()}
        self.tokenizer = Tokenizer.from_file(tok_path)
        # only_second keeps the query intact and truncates the (longer) section.
        # NOTE(review): with only_second a query that alone exceeds MAX_TOKENS
        # cannot be shortened -- confirm how the tokenizer reports that case.
        self.tokenizer.enable_truncation(max_length=MAX_TOKENS, strategy="only_second")

    def score(self, query, documents):
        """Return a relevance logit for each document paired with the query.

        Higher means more relevant. The returned list is aligned with `documents`.

        Args:
            query: Query text; paired with every document.
            documents: Iterable of document strings. Any iterable is accepted
                (list, tuple, generator); it is consumed once.

        Returns:
            A list of floats, one per document, in input order (assumes the
            model's first output holds a single logit per row). An empty
            list when `documents` is empty.
        """
        # Materialise first: a generator is always truthy, so an empty one
        # would slip past the guard below and fail at max() on no items.
        documents = list(documents)
        if not documents:
            return []

        encs = self.tokenizer.encode_batch(
            [(query, doc[:_MAX_DOC_CHARS]) for doc in documents])

        # Right-pad every row to the longest encoding in the batch. Padding
        # positions keep attention 0 so the model ignores them.
        shape = (len(encs), max(len(e.ids) for e in encs))
        input_ids = np.full(shape, self.pad_id, dtype=np.int64)
        attention = np.zeros(shape, dtype=np.int64)
        type_ids = np.zeros(shape, dtype=np.int64)
        for row, enc in enumerate(encs):
            n = len(enc.ids)
            input_ids[row, :n] = enc.ids
            attention[row, :n] = enc.attention_mask
            type_ids[row, :n] = enc.type_ids

        feed = {"input_ids": input_ids, "attention_mask": attention}
        # Not every exported graph takes token_type_ids; feed it only if declared.
        if "token_type_ids" in self.input_names:
            feed["token_type_ids"] = type_ids

        logits = self.session.run(None, feed)[0]
        # Flatten to one plain Python float per input row.
        return np.asarray(logits, dtype=np.float32).reshape(-1).tolist()
def main():
    """Demo: load the reranker, score three sample sections, print the ranking."""
    import time

    print(f"Loading {RERANK_REPO} ({RERANK_ONNX}) ...")
    model = Reranker()

    query = "powers of arrest without warrant"
    sections = [
        "Arrest without warrant. A peace officer may arrest without warrant a "
        "person who has committed a criminal offence.",
        "Definitions. In this Act, fish means any fish and includes shellfish, "
        "crustaceans and marine animals.",
        "Importation. It is prohibited to import cannabis except as authorized "
        "under this Act.",
    ]

    # Time only the scoring call; model loading above is not included.
    began = time.perf_counter()
    logits = model.score(query, sections)
    elapsed_ms = (time.perf_counter() - began) * 1000

    print(f"\nQuery: {query!r}")
    # Most relevant first; the sort is stable, so ties keep input order.
    ranked = list(zip(sections, logits))
    ranked.sort(key=lambda pair: pair[1], reverse=True)
    for text, logit in ranked:
        print(f" {logit:8.3f} {text[:62]}")
    print(f"\n{elapsed_ms:.0f} ms for {len(sections)} documents")
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|