"""Cross-encoder reranker over ONNX Runtime (local, key-free)."""

import json

import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from tokenizers import Tokenizer

RERANK_REPO = "Xenova/bge-reranker-base"
RERANK_ONNX = "onnx/model_quantized.onnx"  # int8: ~3x faster on CPU, negligible quality loss
MAX_TOKENS = 512
_MAX_DOC_CHARS = 3000  # cap doc text before tokenizing (512 tokens ~ 2-3k chars)


class Reranker:
    """Cross-encoder that scores (query, section) pairs for relevance.

    Loads a BGE reranker as ONNX and runs it on CPU -- no API key; the model is
    downloaded once and cached. A cross-encoder reads the query and section
    jointly, so it judges true relevance far better than the BM25/embedding
    similarities used to build the candidate pool.
    """

    def __init__(self):
        """Fetch the model files and build the ONNX session and tokenizer.

        Sets ``pad_id``, ``session``, ``input_names`` and ``tokenizer``.
        """
        model_path = hf_hub_download(RERANK_REPO, RERANK_ONNX)
        tok_path = hf_hub_download(RERANK_REPO, "tokenizer.json")
        cfg_path = hf_hub_download(RERANK_REPO, "config.json")

        with open(cfg_path, encoding="utf-8") as fh:
            pad_id = json.load(fh).get("pad_token_id")
        # Fall back to 0 when the key is missing *or* explicitly null in the
        # config; a None pad id cannot be used to fill the int64 input array.
        self.pad_id = 0 if pad_id is None else pad_id

        self.session = ort.InferenceSession(
            model_path, providers=["CPUExecutionProvider"])
        # Names of the graph inputs, used to decide whether the exported
        # model takes token_type_ids at all.
        self.input_names = {i.name for i in self.session.get_inputs()}

        self.tokenizer = Tokenizer.from_file(tok_path)
        # only_second keeps the query intact and truncates the (longer) section.
        # NOTE(review): a query that by itself approaches MAX_TOKENS leaves
        # nothing to trim from the section; the tokenizer is expected to raise
        # in that case -- confirm against the tokenizers documentation.
        self.tokenizer.enable_truncation(
            max_length=MAX_TOKENS, strategy="only_second")

    def score(self, query, documents, *, batch_size=None):
        """Return a relevance logit for each document paired with the query.

        Higher means more relevant. The returned list is aligned with
        `documents`.

        Args:
            query: The query text.
            documents: Iterable of document texts. Each one is cut to
                ``_MAX_DOC_CHARS`` characters before tokenizing.
            batch_size: Optional maximum number of documents sent through the
                model at once. ``None`` (the default) scores everything in a
                single batch; a smaller value bounds peak memory for large
                candidate pools.

        Returns:
            A list of floats, one per document, in input order. Empty input
            yields an empty list.

        Raises:
            ValueError: If `batch_size` is less than 1, or if the model does
                not produce exactly one logit per document.
        """
        # Materialise first so generators and array-likes are handled the
        # same way as lists (emptiness check, len(), slicing).
        documents = list(documents)
        if not documents:
            return []

        if batch_size is None:
            batch_size = len(documents)
        elif batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")

        scores = []
        for start in range(0, len(documents), batch_size):
            chunk = documents[start:start + batch_size]
            scores.extend(self._score_batch(query, chunk))
        return scores

    def _score_batch(self, query, documents):
        """Score one non-empty batch of documents in a single model run."""
        encs = self.tokenizer.encode_batch(
            [(query, doc[:_MAX_DOC_CHARS]) for doc in documents])

        # Right-pad every encoding to the longest one in this batch. Padded
        # positions keep attention 0 and token type 0.
        width = max(len(e.ids) for e in encs)
        shape = (len(encs), width)
        input_ids = np.full(shape, self.pad_id, dtype=np.int64)
        attention = np.zeros(shape, dtype=np.int64)
        type_ids = np.zeros(shape, dtype=np.int64)
        for row, enc in enumerate(encs):
            n = len(enc.ids)
            input_ids[row, :n] = enc.ids
            attention[row, :n] = enc.attention_mask
            type_ids[row, :n] = enc.type_ids

        feed = {"input_ids": input_ids, "attention_mask": attention}
        # Only feed token_type_ids when the graph declares that input.
        if "token_type_ids" in self.input_names:
            feed["token_type_ids"] = type_ids

        logits = self.session.run(None, feed)[0]
        flat = np.asarray(logits, dtype=np.float32).reshape(-1)
        # Flattening is only safe with one logit per pair; anything else would
        # silently misalign scores with documents.
        if flat.size != len(encs):
            raise ValueError(
                f"expected {len(encs)} logits (one per document), "
                f"got output of shape {np.shape(logits)}")
        return flat.tolist()


def main():
    """Smoke test: load the model, score three sample sections, print timing."""
    # Imported here because only this demo entry point needs it.
    import time

    print(f"Loading {RERANK_REPO} ({RERANK_ONNX}) ...")
    reranker = Reranker()
    query = "powers of arrest without warrant"
    docs = [
        "Arrest without warrant. A peace officer may arrest without warrant a "
        "person who has committed a criminal offence.",
        "Definitions. In this Act, fish means any fish and includes shellfish, "
        "crustaceans and marine animals.",
        "Importation. It is prohibited to import cannabis except as authorized "
        "under this Act.",
    ]
    start = time.perf_counter()
    scores = reranker.score(query, docs)
    elapsed = (time.perf_counter() - start) * 1000
    print(f"\nQuery: {query!r}")
    # Most relevant first.
    for doc, score in sorted(zip(docs, scores), key=lambda x: x[1], reverse=True):
        print(f" {score:8.3f} {doc[:62]}")
    print(f"\n{elapsed:.0f} ms for {len(docs)} documents")


if __name__ == "__main__":
    main()