| """Cross-encoder reranker over ONNX Runtime (local, key-free).""" |
| import json |
|
|
| import numpy as np |
| import onnxruntime as ort |
| from huggingface_hub import hf_hub_download |
| from tokenizers import Tokenizer |
|
|
# Hugging Face Hub repository the model, tokenizer and config are fetched from.
RERANK_REPO = "Xenova/bge-reranker-base"
# Path of the ONNX weights file inside that repository.
RERANK_ONNX = "onnx/model_quantized.onnx"
# Upper bound on tokens per (query, section) pair, enforced by tokenizer truncation.
MAX_TOKENS = 512
# Sections are cut to this many characters before they are tokenized.
_MAX_DOC_CHARS = 3000
|
|
|
|
class Reranker:
    """Cross-encoder that scores (query, section) pairs for relevance.

    Loads a BGE reranker as ONNX and runs it on CPU -- no API key; the model is
    downloaded once and cached. A cross-encoder reads the query and section
    jointly, so it judges true relevance far better than the BM25/embedding
    similarities used to build the candidate pool.
    """

    def __init__(self):
        """Fetch model, tokenizer and config, then build the ONNX session."""
        model_path = hf_hub_download(RERANK_REPO, RERANK_ONNX)
        tok_path = hf_hub_download(RERANK_REPO, "tokenizer.json")
        cfg_path = hf_hub_download(RERANK_REPO, "config.json")
        with open(cfg_path, encoding="utf-8") as fh:
            pad_id = json.load(fh).get("pad_token_id")
        # A missing key and an explicit JSON null both fall back to 0; a None
        # here would make the int64 padding fill in score() fail.
        self.pad_id = 0 if pad_id is None else pad_id

        self.session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
        # Remembered so score() only feeds tensors the exported graph declares.
        self.input_names = {i.name for i in self.session.get_inputs()}

        self.tokenizer = Tokenizer.from_file(tok_path)
        # Only the section (second sequence) is trimmed to fit MAX_TOKENS.
        # NOTE(review): a query that by itself fills MAX_TOKENS cannot be
        # satisfied by trimming the section alone -- confirm how the tokenizer
        # reacts before passing very long queries.
        self.tokenizer.enable_truncation(max_length=MAX_TOKENS, strategy="only_second")

    def score(self, query, documents, batch_size=None):
        """Return a relevance logit for each document paired with the query.

        Higher means more relevant. The returned list is aligned with
        `documents`.

        Args:
            query: Query text paired with every document.
            documents: Iterable of section texts. Lists, tuples, generators
                and arrays are all accepted; None is treated as empty.
            batch_size: Maximum number of pairs per inference call. None
                (the default) scores everything in a single call. Smaller
                values bound the size of the padded input tensors.

        Returns:
            A list of floats, one per document, in input order.

        Raises:
            ValueError: If `batch_size` is given and is less than 1.
        """
        if batch_size is not None and batch_size < 1:
            raise ValueError("batch_size must be a positive integer or None")
        # Materialize once: a generator is always truthy and can only be
        # consumed a single time, so the emptiness check needs a real list.
        documents = [] if documents is None else list(documents)
        if not documents:
            return []
        step = len(documents) if batch_size is None else batch_size
        scores = []
        for start in range(0, len(documents), step):
            scores.extend(self._score_batch(query, documents[start:start + step]))
        return scores

    def _score_batch(self, query, documents):
        """Tokenize, pad and run one non-empty batch; return its logits as a list."""
        encs = self.tokenizer.encode_batch(
            [(query, doc[:_MAX_DOC_CHARS]) for doc in documents])
        # Pad every row to the longest encoding in this batch.
        shape = (len(encs), max(len(e.ids) for e in encs))
        input_ids = np.full(shape, self.pad_id, dtype=np.int64)
        attention = np.zeros(shape, dtype=np.int64)
        type_ids = np.zeros(shape, dtype=np.int64)
        for row, enc in enumerate(encs):
            n = len(enc.ids)
            input_ids[row, :n] = enc.ids
            attention[row, :n] = enc.attention_mask
            type_ids[row, :n] = enc.type_ids

        feed = {"input_ids": input_ids, "attention_mask": attention}
        if "token_type_ids" in self.input_names:
            feed["token_type_ids"] = type_ids
        logits = self.session.run(None, feed)[0]
        # Flatten the first model output into a plain list of floats.
        return np.asarray(logits, dtype=np.float32).reshape(-1).tolist()
|
|
|
|
def main():
    """Smoke-test the reranker from the command line.

    Loads the model, scores three sample sections against a fixed query, and
    prints them best-first together with the time the scoring call took.
    """
    import time

    print(f"Loading {RERANK_REPO} ({RERANK_ONNX}) ...")
    reranker = Reranker()

    query = "powers of arrest without warrant"
    docs = [
        "Arrest without warrant. A peace officer may arrest without warrant a "
        "person who has committed a criminal offence.",
        "Definitions. In this Act, fish means any fish and includes shellfish, "
        "crustaceans and marine animals.",
        "Importation. It is prohibited to import cannabis except as authorized "
        "under this Act.",
    ]

    # Time only the scoring call, not the model load above.
    began = time.perf_counter()
    scores = reranker.score(query, docs)
    elapsed = (time.perf_counter() - began) * 1000

    print(f"\nQuery: {query!r}")
    # Highest logit first.
    ranked = sorted(zip(docs, scores), key=lambda pair: pair[1], reverse=True)
    for doc, score in ranked:
        print(f" {score:8.3f} {doc[:62]}")
    print(f"\n{elapsed:.0f} ms for {len(docs)} documents")
|
|
|
|
# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|