File size: 5,890 Bytes
6a82282
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
"""Phase 3 end-to-end: existing app.rag retriever -> Granite reranker
-> top-3, then double-gate via reconciler call.

Demonstrates the rerank changing the top-3 order vs retriever-only on
a query that's known to be ambiguous (the corpus has paragraphs about
multiple flood mechanisms; the query specifically asks about pluvial
flooding in Queens).

Caveat: the existing `retrieve()` does at-most-1 chunk per doc. We
bypass that for the experiment by fetching with k=20 and only using
the retriever's similarity ranking, not its dedup. In production
integration the dedup would happen *after* the reranker, not before,
so we'd get the reranker-improved top-3 with at most 1 paragraph per
PDF.
"""

from __future__ import annotations

import argparse
import json
import sys
import time
from pathlib import Path

# Make app/ importable
sys.path.insert(0, str(Path(__file__).resolve().parents[2]))

from rerank import load_model as load_reranker  # noqa: E402
from rerank import rerank

from experiments.shared import backends, trace_render  # noqa: E402

# Instruction template for the answer-writing call. `{cite}` is filled in
# with the citation id of the top reranked document. Despite the name,
# main() sends the formatted text as the `system` message.
USER_PROMPT = (
    "Write a single sentence answering the user's query, "
    "citing the ranked source with [{cite}]. "
    "Use only the text in the provided document; "
    "if it doesn't address the query, say so."
)


def retriever_top_k(query: str, k: int = 20) -> list[dict]:
    """Rank every indexed chunk against *query* and keep the best *k*.

    Unlike the app's regular retrieval path, no per-document dedup is
    applied here: several chunks from the same doc may be returned.

    Args:
        query: Free-text query to embed with the index's model.
        k: Maximum number of chunks to return.

    Returns:
        A list of dicts with keys ``doc_id``, ``text``,
        ``retriever_score`` (dot product of the chunk embedding with the
        query embedding, as a Python float) and ``rank`` (1-based), in
        descending score order. Empty when the index has no embeddings.
    """
    import numpy as np

    from app.rag import _ensure_index

    index = _ensure_index()
    embeddings = index["embs"]
    if embeddings is None:
        return []

    # The query vector is requested normalized; whether the stored
    # embeddings are normalized too (making this cosine similarity) is
    # decided by app.rag -- TODO confirm.
    encoded = index["model"].encode(
        [query], convert_to_numpy=True, normalize_embeddings=True)
    query_vec = encoded.astype("float32")
    scores = (embeddings @ query_vec.T).ravel()

    # Negate so argsort's ascending order yields highest score first.
    best_first = np.argsort(-scores)[:k]

    hits: list[dict] = []
    for position, chunk_idx in enumerate(best_first, start=1):
        chunk = index["chunks"][chunk_idx]
        hits.append({
            "doc_id": chunk.doc_id,
            "text": chunk.text,
            "retriever_score": float(scores[chunk_idx]),
            "rank": position,
        })
    return hits


def _ask_backend(backend_name: str, kwargs: dict,
                 messages: list[dict]) -> dict:
    """Configure one chat backend, send *messages*, return a result record.

    On success the record has ``backend``, ``info``, ``elapsed_s`` and
    ``content``; if the chat call (or reading its response) raises, the
    record has ``backend`` and ``error`` instead.
    """
    backends.configure(**kwargs)
    started = time.perf_counter()
    try:
        resp = backends.chat(model="granite4.1:8b", messages=messages,
                             options={"temperature": 0,
                                      "num_predict": 200,
                                      "num_ctx": 4096})
        return {"backend": backend_name,
                "info": backends.backend_info(),
                "elapsed_s": round(time.perf_counter() - started, 2),
                "content": resp["message"]["content"].strip()}
    except Exception as e:
        # Deliberately broad: one backend failing must not stop the other
        # from being tried; the failure is recorded in the output instead.
        return {"backend": backend_name,
                "error": f"{type(e).__name__}: {e}"}


def main() -> int:
    """Run the retriever -> reranker -> chat-backend experiment from the CLI.

    Steps:
      1. Fetch ``--top-k-retriever`` candidates with ``retriever_top_k``.
      2. Rerank them and keep ``--top-k-reranker`` hits.
      3. Send the top reranked chunk as a document message, together
         with the citation prompt and the query, to the ``ollama`` and
         ``vllm`` backends in turn.
      4. Write everything to ``.cache/double_gate_rerank.json`` next to
         this script.

    Returns:
        0 on success; 1 when the retriever or the reranker produced no
        candidates (nothing to cite, so no backend is called and no
        output file is written).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--query", required=True)
    ap.add_argument("--top-k-retriever", type=int, default=20)
    ap.add_argument("--top-k-reranker", type=int, default=3)
    ap.add_argument("--vllm-base-url", required=True)
    ap.add_argument("--vllm-api-key", required=True)
    args = ap.parse_args()

    print(trace_render.banner(f"Phase 3 double-gate · reranker · {args.query}"))

    print("Warming retriever (Granite Embedding 278M)…")
    t0 = time.perf_counter()
    retr_top = retriever_top_k(args.query, k=args.top_k_retriever)
    print(f"retriever: {time.perf_counter() - t0:.2f}s "
          f"({len(retr_top)} candidates)")

    # An empty index (or k <= 0) yields no candidates. Bail out before
    # loading the reranker instead of crashing later on ranked[0].
    if not retr_top:
        print("retriever returned no candidates; nothing to rerank",
              file=sys.stderr)
        return 1

    print("\nRetriever top-3 (BEFORE rerank):")
    for r in retr_top[:3]:
        print(f"  rank {r['rank']:>2} score={r['retriever_score']:.3f}  "
              f"doc={r['doc_id']}  text={r['text'][:80]}…")

    print("\nLoading reranker…")
    t0 = time.perf_counter()
    reranker = load_reranker()
    print(f"reranker load: {time.perf_counter() - t0:.2f}s")

    t0 = time.perf_counter()
    candidates = [r["text"] for r in retr_top]
    ranked = rerank(reranker, args.query, candidates,
                    top_k=args.top_k_reranker)
    print(f"rerank ({len(retr_top)} -> {args.top_k_reranker}): "
          f"{time.perf_counter() - t0:.3f}s")

    if not ranked:
        print("reranker returned no hits; nothing to cite", file=sys.stderr)
        return 1

    # The reranker only hands back text, so map each text to its
    # retriever record once. setdefault keeps the first (best-ranked)
    # record when two chunks share identical text.
    by_text: dict[str, dict] = {}
    for cand in retr_top:
        by_text.setdefault(cand["text"], cand)

    print("\nReranker top-3 (AFTER rerank):")
    for r in ranked:
        orig = by_text.get(r.text)
        orig_rank = orig["rank"] if orig else "?"
        print(f"  rank {r.rank}  score={r.score:.3f}  "
              f"(was retriever rank {orig_rank})  "
              f"doc={orig['doc_id'] if orig else '?'}  "
              f"text={r.text[:80]}…")

    # Build a single-doc citation for the top-1 reranker hit and run
    # the reconciler. doc_id slug = the source PDF's doc_id.
    top1 = ranked[0]
    top1_orig = by_text.get(top1.text)
    cite_id = (top1_orig or {}).get("doc_id", "rag_top")
    doc = {"role": f"document {cite_id}", "content": top1.text}

    results = []
    for backend_name, kwargs in [
        ("ollama", dict(backend="ollama")),
        ("vllm", dict(backend="vllm",
                      base_url=args.vllm_base_url,
                      api_key=args.vllm_api_key)),
    ]:
        # Fresh message list per backend so neither call can see
        # anything the other one may have done to it.
        messages = [
            doc,
            {"role": "system", "content": USER_PROMPT.format(cite=cite_id)},
            {"role": "user", "content": args.query},
        ]
        outcome = _ask_backend(backend_name, kwargs, messages)
        results.append(outcome)
        print(trace_render.banner(
            f"{backend_name}  ({outcome.get('elapsed_s', '-')}s)  "
            f"hw={outcome.get('info', {}).get('hardware', '?')}"))
        print(outcome.get("content", outcome.get("error")))

    out = Path(__file__).parent / ".cache" / "double_gate_rerank.json"
    out.parent.mkdir(exist_ok=True)
    out.write_text(json.dumps({
        "query": args.query,
        "retriever_top": retr_top,
        "reranker_top": [{"rank": r.rank, "score": r.score, "text": r.text}
                         for r in ranked],
        "results": results,
    }, indent=2, default=str), encoding="utf-8")
    print(f"\nwrote {out}")
    return 0


if __name__ == "__main__":
    # Script entry point: main()'s integer return value becomes the
    # process exit status.
    raise SystemExit(main())