| """Granite Embedding Reranker R2 (cross-encoder, 149 M). |
| |
| Sits between the existing Granite Embedding 278 M retriever (top-K=20) |
| and the reconciler (top-3). Sidecar via sentence-transformers |
| CrossEncoder — vLLM `--task score` is explicitly out of scope. |
| |
| License: Apache-2.0 (verified — `ibm-granite/granite-embedding-reranker- |
| english-r2`). |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import os |
| import sys |
| import time |
| from dataclasses import asdict, dataclass |
| from pathlib import Path |
|
|
# Downloads live in a ".cache" directory next to this file. HF_HOME is
# pointed there only when the caller has not already exported one.
CACHE = Path(__file__).parent.joinpath(".cache")
CACHE.mkdir(exist_ok=True)
os.environ.setdefault("HF_HOME", os.fspath(CACHE.joinpath("hf")))

# Hugging Face Hub id of the cross-encoder loaded by load_model().
REPO = "ibm-granite/granite-embedding-reranker-english-r2"
|
|
|
|
@dataclass
class Ranking:
    """One reranked candidate.

    Field order is significant: ``asdict`` preserves it, so the JSON that
    ``main`` prints has its keys in the order ``rank``, ``score``, ``text``.
    """

    rank: int  # 1-based position after sorting by descending score
    score: float  # value from the model's predict(), as a Python float
    text: str  # the candidate string exactly as passed to rerank()
|
|
|
|
def load_model():
    """Build the Granite reranker as a sentence-transformers ``CrossEncoder``.

    sentence-transformers is imported here rather than at module level.
    That way ``main`` can reject bad arguments or an empty candidates file
    without paying for the import. Model files are cached under
    ``CACHE / "hf"``; presumably they are downloaded from the Hub on the
    first call.
    """
    from sentence_transformers import CrossEncoder as _CrossEncoder

    hf_cache_dir = CACHE / "hf"
    return _CrossEncoder(REPO, cache_folder=str(hf_cache_dir))
|
|
|
|
def rerank(model, query: str, candidates: list[str],
           top_k: int = 3) -> list[Ranking]:
    """Score every candidate against *query* and return the best *top_k*.

    Args:
        model: Object with a sentence-transformers-style ``predict(pairs)``
            that returns one score per ``[query, candidate]`` pair, such as
            the CrossEncoder built by ``load_model``.
        query: The search query.
        candidates: Candidate passages to rerank.
        top_k: Maximum number of results. Fewer are returned when there are
            fewer candidates, and ``0`` yields an empty list.

    Returns:
        Up to *top_k* ``Ranking`` objects in descending score order, with
        ``rank`` starting at 1. Candidates with equal scores keep their
        input order.

    Raises:
        ValueError: If *top_k* is negative. A negative slice bound would
            otherwise silently drop results from the tail.
    """
    if top_k < 0:
        raise ValueError(f"top_k must be >= 0, got {top_k}")
    # Nothing to return: skip the model call entirely. CrossEncoder.predict
    # indexes its first input to detect a single pair, so an empty batch can
    # raise IndexError there.
    if not candidates or top_k == 0:
        return []

    scores = model.predict([[query, c] for c in candidates])
    # sorted() stays stable with reverse=True, so ties keep input order.
    ranked = sorted(enumerate(scores), key=lambda pair: pair[1], reverse=True)
    return [Ranking(rank=pos, score=float(score), text=candidates[idx])
            for pos, (idx, score) in enumerate(ranked[:top_k], start=1)]
|
|
|
|
def main() -> int:
    """Command-line entry point: rerank a file of candidates for one query.

    Each non-blank line of ``--candidates-file`` is one candidate; lines are
    stripped of surrounding whitespace. The candidates are scored with the
    Granite reranker, and the top ``--top-k`` are printed to stdout as a JSON
    list of ``{"rank", "score", "text"}`` objects. Progress and timings go to
    stderr so stdout stays machine-readable.

    Returns:
        Exit status. 0 on success. 1 when the candidates file cannot be read
        or has no candidates. Invalid arguments, including a negative
        ``--top-k``, exit with status 2 via ``argparse``.
    """
    ap = argparse.ArgumentParser(
        description="Rerank candidate paragraphs for a query with the "
                    "Granite Embedding Reranker R2.")
    ap.add_argument("--query", required=True)
    ap.add_argument("--candidates-file", required=True,
                    help="One candidate paragraph per line")
    ap.add_argument("--top-k", type=int, default=3)
    args = ap.parse_args()
    # A negative value would turn the top-k slice into "drop the last |k|".
    if args.top_k < 0:
        ap.error("--top-k must be >= 0")

    try:
        # Explicit UTF-8: the locale default varies across platforms.
        raw = Path(args.candidates_file).read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError) as exc:
        print(f"Cannot read candidates file: {exc}", file=sys.stderr)
        return 1
    candidates = [ln.strip() for ln in raw.splitlines() if ln.strip()]
    if not candidates:
        print("No candidates provided", file=sys.stderr)
        return 1

    print("Loading reranker (~600 MB)…", file=sys.stderr)
    t0 = time.perf_counter()
    model = load_model()
    print(f"reranker load: {time.perf_counter() - t0:.2f}s", file=sys.stderr)

    t0 = time.perf_counter()
    ranked = rerank(model, args.query, candidates, top_k=args.top_k)
    # Report the number of results actually produced. It is smaller than
    # --top-k when the file holds fewer candidates.
    print(f"rerank {len(candidates)} -> {len(ranked)}: "
          f"{time.perf_counter() - t0:.3f}s", file=sys.stderr)
    print(json.dumps([asdict(r) for r in ranked], indent=2))
    return 0
|
|
|
|
if __name__ == "__main__":
    # main() returns the exit status; SystemExit hands it to the shell.
    raise SystemExit(main())
|
|