# seriffic's picture
# Backend evolution: Phases 1-10 specialists + agentic FSM + Mellea + LiteLLM router
# 6a82282
"""Granite Embedding Reranker R2 (cross-encoder, 149 M).
Sits between the existing Granite Embedding 278 M retriever (top-K=20)
and the reconciler (top-3). Sidecar via sentence-transformers
CrossEncoder — vLLM `--task score` is explicitly out of scope.
License: Apache-2.0 (verified — `ibm-granite/granite-embedding-reranker-
english-r2`).
"""
from __future__ import annotations
import argparse
import json
import os
import sys
import time
from dataclasses import asdict, dataclass
from pathlib import Path
# Hugging Face model id of the cross-encoder reranker.
REPO = "ibm-granite/granite-embedding-reranker-english-r2"

# Cache directory living next to this file; created at import time.
CACHE = Path(__file__).parent.joinpath(".cache")
CACHE.mkdir(exist_ok=True)

# Default HF_HOME to <CACHE>/hf. An HF_HOME value already present in the
# environment is left untouched.
os.environ.setdefault("HF_HOME", str(CACHE.joinpath("hf")))
@dataclass
class Ranking:
    """One reranked candidate, as produced by ``rerank``."""

    rank: int  # 1-based position after sorting by score, best first
    score: float  # model score for the (query, text) pair, as a Python float
    text: str  # the candidate text that was scored
def load_model():
    """Build the cross-encoder reranker for ``REPO``.

    ``sentence_transformers`` is imported here rather than at module top, so
    the import only happens when a model is actually requested.

    Returns:
        A ``CrossEncoder`` constructed for ``REPO`` with ``CACHE / "hf"``
        passed as its ``cache_folder``.
    """
    from sentence_transformers import CrossEncoder

    hf_cache_dir = str(CACHE / "hf")
    reranker = CrossEncoder(REPO, cache_folder=hf_cache_dir)
    return reranker
def rerank(model, query: str, candidates: list[str],
           top_k: int = 3) -> list[Ranking]:
    """Score *candidates* against *query* and return the best *top_k*.

    Args:
        model: Object exposing ``predict(pairs)`` that returns one score per
            ``[query, candidate]`` pair, in input order.
        query: The query text every candidate is paired with.
        candidates: Candidate texts to score.
        top_k: Maximum number of results to return. Values ``<= 0`` yield an
            empty list.

    Returns:
        Up to ``top_k`` ``Ranking`` entries ordered by descending score, with
        ``rank`` starting at 1. Candidates with equal scores keep their input
        order (the sort is stable). Returns ``[]`` without calling the model
        when there is nothing to rank.
    """
    # Guard first: a negative top_k would otherwise slice from the end
    # (``indexed[:-1]`` returns everything except the worst hit), and there
    # is no reason to invoke the model when no result can be returned.
    if top_k <= 0 or not candidates:
        return []

    pairs = [[query, candidate] for candidate in candidates]
    scores = model.predict(pairs)
    # Keep the original index alongside each score so the candidate text can
    # be recovered after sorting.
    indexed = sorted(enumerate(scores), key=lambda item: item[1],
                     reverse=True)
    return [
        # float() normalises whatever scalar type the model returns so the
        # result is JSON-serialisable.
        Ranking(rank=position, score=float(score), text=candidates[idx])
        for position, (idx, score) in enumerate(indexed[:top_k], start=1)
    ]
def main() -> int:
    """Command-line entry point: rerank candidates from a file, print JSON.

    Reads one candidate per non-blank line from ``--candidates-file``
    (decoded as UTF-8), scores them against ``--query`` and prints up to
    ``--top-k`` results to stdout as a JSON array. Progress messages and
    timings go to stderr so stdout stays machine-readable.

    Returns:
        0 on success; 1 if the candidates file cannot be read or contains no
        non-blank lines. Invalid arguments (including ``--top-k`` below 1)
        exit with status 2 through argparse.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--query", required=True)
    ap.add_argument("--candidates-file", required=True,
                    help="One candidate paragraph per line")
    ap.add_argument("--top-k", type=int, default=3)
    args = ap.parse_args()
    # Reject non-positive values up front: a negative top-k would otherwise
    # be used as a slice bound and silently return the wrong rows.
    if args.top_k < 1:
        ap.error("--top-k must be a positive integer")

    # Explicit encoding: the platform default is not UTF-8 everywhere.
    try:
        raw = Path(args.candidates_file).read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError) as exc:
        print(f"Cannot read candidates file {args.candidates_file!r}: {exc}",
              file=sys.stderr)
        return 1

    stripped = (ln.strip() for ln in raw.splitlines())
    candidates = [ln for ln in stripped if ln]
    if not candidates:
        print("No candidates provided", file=sys.stderr)
        return 1

    print("Loading reranker (~600 MB)…", file=sys.stderr)
    # perf_counter is monotonic, so the elapsed time cannot be skewed by
    # wall-clock adjustments.
    t0 = time.perf_counter()
    model = load_model()
    print(f"reranker load: {time.perf_counter() - t0:.2f}s", file=sys.stderr)

    t0 = time.perf_counter()
    ranked = rerank(model, args.query, candidates, top_k=args.top_k)
    # Report how many results were actually returned; this can be fewer than
    # --top-k when the file has fewer candidates.
    print(f"rerank {len(candidates)} -> {len(ranked)}: "
          f"{time.perf_counter() - t0:.3f}s", file=sys.stderr)

    print(json.dumps([asdict(r) for r in ranked], indent=2))
    return 0
# Script entry point: run the CLI and use main()'s return value as the
# process exit status.
if __name__ == "__main__":
    raise SystemExit(main())