File size: 2,325 Bytes
6a82282
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
"""Granite Embedding Reranker R2 (cross-encoder, 149 M).

Sits between the existing Granite Embedding 278 M retriever (top-K=20)
and the reconciler (top-3). Sidecar via sentence-transformers
CrossEncoder — vLLM `--task score` is explicitly out of scope.

License: Apache-2.0 (verified —
`ibm-granite/granite-embedding-reranker-english-r2`).
"""

from __future__ import annotations

import argparse
import json
import os
import sys
import time
from dataclasses import asdict, dataclass
from pathlib import Path

# Local cache directory that lives next to this source file.
CACHE = Path(__file__).parent / ".cache"
# Created as an import-time side effect; exist_ok makes repeated imports safe.
CACHE.mkdir(exist_ok=True)
# setdefault leaves an HF_HOME already present in the environment untouched.
# Presumably this steers Hugging Face downloads into the local cache — it
# runs at module import, before load_model() imports sentence_transformers.
os.environ.setdefault("HF_HOME", str(CACHE / "hf"))

# Model identifier passed to CrossEncoder in load_model().
REPO = "ibm-granite/granite-embedding-reranker-english-r2"


@dataclass
class Ranking:
    """One reranked candidate, as built by ``rerank`` and serialized by ``main``."""
    # 1-based position after sorting by descending score.
    rank: int
    # Model score for the (query, candidate) pair, converted to a Python float.
    score: float
    # The candidate text exactly as it was passed to ``rerank``.
    text: str


def load_model():
    """Instantiate the cross-encoder reranker identified by ``REPO``.

    The ``sentence_transformers`` import is deferred to call time, so merely
    importing this module does not require the package to be loaded.

    Returns:
        A ``CrossEncoder`` constructed for ``REPO`` with its cache folder set
        to the ``hf`` subdirectory of ``CACHE``.
    """
    from sentence_transformers import CrossEncoder

    hf_cache_dir = str(CACHE / "hf")
    return CrossEncoder(REPO, cache_folder=hf_cache_dir)


def rerank(model, query: str, candidates: list[str],
           top_k: int = 3) -> list[Ranking]:
    """Score every candidate against ``query`` and return the best ``top_k``.

    Args:
        model: Object exposing ``predict(pairs)``; it is called once with a
            list of ``[query, candidate]`` pairs and must return one score
            per pair, in the same order. ``main`` passes the result of
            ``load_model`` here.
        query: The search query each candidate is paired with.
        candidates: Candidate texts to score.
        top_k: Maximum number of results to return. Fewer are returned when
            there are fewer candidates.

    Returns:
        ``Ranking`` entries ordered by descending score, with ``rank``
        starting at 1. Candidates with equal scores keep their input order.
        An empty list when ``candidates`` is empty or ``top_k`` is 0.

    Raises:
        ValueError: If ``top_k`` is negative.
    """
    # A negative bound in the slice below would silently drop items from the
    # END of the ranking instead of limiting its length, so reject it.
    if top_k < 0:
        raise ValueError(f"top_k must be non-negative, got {top_k}")
    # Nothing to rank or nothing requested: skip the model call entirely
    # rather than handing it an empty batch.
    if not candidates or top_k == 0:
        return []

    pairs = [[query, c] for c in candidates]
    scores = model.predict(pairs)
    # sorted() is stable even with reverse=True, so ties keep input order.
    indexed = sorted(enumerate(scores), key=lambda x: x[1], reverse=True)
    return [Ranking(rank=i + 1, score=float(s), text=candidates[idx])
            for i, (idx, s) in enumerate(indexed[:top_k])]


def main() -> int:
    """Command-line entry point: rerank the lines of a file against a query.

    Reads one candidate per non-blank line from ``--candidates-file``
    (decoded as UTF-8), loads the reranker, and prints the top ``--top-k``
    results to stdout as a JSON array. Progress and timing messages go to
    stderr so stdout carries only the JSON.

    Returns:
        0 on success; 1 if the candidates file cannot be read or contains
        no non-blank lines. Invalid arguments exit through argparse.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--query", required=True)
    ap.add_argument("--candidates-file", required=True,
                    help="One candidate paragraph per line")
    ap.add_argument("--top-k", type=int, default=3)
    args = ap.parse_args()
    if args.top_k < 0:
        # Fail before the expensive model load instead of handing a
        # negative count to rerank().
        ap.error("--top-k must be >= 0")

    try:
        # Explicit encoding: the default depends on the platform locale.
        raw = Path(args.candidates_file).read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError) as exc:
        print(f"Cannot read candidates file: {exc}", file=sys.stderr)
        return 1

    stripped = (ln.strip() for ln in raw.splitlines())
    candidates = [ln for ln in stripped if ln]
    if not candidates:
        print("No candidates provided", file=sys.stderr)
        return 1

    print("Loading reranker (~600 MB)…", file=sys.stderr)
    # perf_counter is monotonic, so elapsed times are immune to clock changes.
    t0 = time.perf_counter()
    model = load_model()
    print(f"reranker load: {time.perf_counter() - t0:.2f}s", file=sys.stderr)

    t0 = time.perf_counter()
    ranked = rerank(model, args.query, candidates, top_k=args.top_k)
    print(f"rerank {len(candidates)} -> {args.top_k}: "
          f"{time.perf_counter() - t0:.3f}s", file=sys.stderr)
    print(json.dumps([asdict(r) for r in ranked], indent=2))
    return 0


if __name__ == "__main__":
    sys.exit(main())