File size: 5,890 Bytes
"""Phase 3 end-to-end: existing app.rag retriever -> Granite reranker
-> top-3, then double-gate via reconciler call.

Demonstrates the rerank changing the top-3 order vs retriever-only on
a query that's known to be ambiguous (the corpus has paragraphs about
multiple flood mechanisms; the query specifically asks about pluvial
flooding in Queens).

Caveat: the existing `retrieve()` returns at most one chunk per doc. We
bypass that for the experiment by fetching with k=20 and only using
the retriever's similarity ranking, not its dedup. In production
integration the dedup would happen *after* the reranker, not before,
so we'd get the reranker-improved top-3 with at most 1 paragraph per
PDF.
"""
from __future__ import annotations
import argparse
import json
import sys
import time
from pathlib import Path
# Make app/ importable
sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
from rerank import load_model as load_reranker # noqa: E402
from rerank import rerank
from experiments.shared import backends, trace_render # noqa: E402
# Instruction template for the answer-generation call. `{cite}` is filled
# with the citation id of the top reranked document. Despite the name,
# main() sends the formatted text with the "system" role.
USER_PROMPT = (
    "Write a single sentence answering the user's query, "
    "citing the ranked source with [{cite}]. "
    "Use only the text in the provided document; "
    "if it doesn't address the query, say so."
)
def retriever_top_k(query: str, k: int = 20) -> list[dict]:
    """Return the ``k`` best-scoring retriever chunks, skipping per-doc dedup.

    The query is embedded with the index's own model (normalized, float32)
    and scored against every stored chunk embedding with a dot product.
    Chunks are ordered by descending score, so several chunks from the same
    document may appear in the result.

    Args:
        query: Free-text query to embed and score against the index.
        k: Maximum number of chunks to return.

    Returns:
        One dict per chunk with keys ``doc_id``, ``text``,
        ``retriever_score`` (float) and ``rank`` (1-based position in the
        score ordering). An empty list when the index holds no embeddings.
    """
    # Imported lazily so that merely importing this script does not pull in
    # numpy or the app package.
    import numpy as np
    from app.rag import _ensure_index

    index = _ensure_index()
    embeddings = index["embs"]
    if embeddings is None:
        return []

    query_vec = index["model"].encode(
        [query],
        convert_to_numpy=True,
        normalize_embeddings=True,
    ).astype("float32")
    scores = (embeddings @ query_vec.T).ravel()
    # Negate so that argsort's ascending order yields highest score first.
    best_first = np.argsort(-scores)[:k]

    chunks = index["chunks"]
    hits = []
    for position, chunk_idx in enumerate(best_first, start=1):
        chunk = chunks[chunk_idx]
        hits.append({
            "doc_id": chunk.doc_id,
            "text": chunk.text,
            "retriever_score": float(scores[chunk_idx]),
            "rank": position,
        })
    return hits
def main() -> int:
    """Run the retrieve -> rerank -> answer experiment from the command line.

    Steps:
      1. Fetch ``--top-k-retriever`` candidates with ``retriever_top_k``.
      2. Rerank them down to ``--top-k-reranker`` and print both orderings.
      3. Wrap the top reranked paragraph as a single document message and
         ask each backend (ollama, then vllm) for a one-sentence cited
         answer.
      4. Dump the query, both rankings and the backend results to
         ``.cache/double_gate_rerank.json`` next to this script.

    Returns:
        0 when the run completes. A failing backend call does not abort
        the run; its error is recorded in the results instead. 1 when the
        retriever or the reranker yields no candidates, since there is
        nothing to cite in that case.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--query", required=True)
    ap.add_argument("--top-k-retriever", type=int, default=20)
    ap.add_argument("--top-k-reranker", type=int, default=3)
    ap.add_argument("--vllm-base-url", required=True)
    ap.add_argument("--vllm-api-key", required=True)
    args = ap.parse_args()

    print(trace_render.banner(f"Phase 3 double-gate · reranker · {args.query}"))
    print("Warming retriever (Granite Embedding 278M)…")
    # perf_counter is monotonic, so elapsed times cannot be skewed by
    # wall-clock adjustments.
    t0 = time.perf_counter()
    retr_top = retriever_top_k(args.query, k=args.top_k_retriever)
    print(f"retriever: {time.perf_counter() - t0:.2f}s "
          f"({len(retr_top)} candidates)")
    if not retr_top:
        # retriever_top_k returns [] when the index has no embeddings;
        # without candidates there is nothing to rerank or cite.
        print("retriever returned no candidates; nothing to rerank",
              file=sys.stderr)
        return 1

    print("\nRetriever top-3 (BEFORE rerank):")
    for r in retr_top[:3]:
        print(f" rank {r['rank']:>2} score={r['retriever_score']:.3f} "
              f"doc={r['doc_id']} text={r['text'][:80]}…")

    print("\nLoading reranker…")
    t0 = time.perf_counter()
    reranker = load_reranker()
    print(f"reranker load: {time.perf_counter() - t0:.2f}s")

    t0 = time.perf_counter()
    candidates = [r["text"] for r in retr_top]
    ranked = rerank(reranker, args.query, candidates,
                    top_k=args.top_k_reranker)
    print(f"rerank ({len(retr_top)} -> {args.top_k_reranker}): "
          f"{time.perf_counter() - t0:.3f}s")
    if not ranked:
        print("reranker returned no results; nothing to cite",
              file=sys.stderr)
        return 1

    # Map chunk text back to its retriever entry once, instead of scanning
    # retr_top for every reranked hit. setdefault keeps the first (i.e.
    # best-ranked) entry when the same text occurs more than once.
    by_text: dict = {}
    for cand in retr_top:
        by_text.setdefault(cand["text"], cand)

    print(f"\nReranker top-{len(ranked)} (AFTER rerank):")
    for r in ranked:
        # Look up the original retriever info to compare ranks.
        orig = by_text.get(r.text)
        orig_rank = orig["rank"] if orig else "?"
        print(f" rank {r.rank} score={r.score:.3f} "
              f"(was retriever rank {orig_rank}) "
              f"doc={orig['doc_id'] if orig else '?'} "
              f"text={r.text[:80]}…")

    # Build a single-doc citation for the top-1 reranker hit and run
    # the reconciler. doc_id slug = the source PDF's doc_id.
    top1 = ranked[0]
    top1_orig = by_text.get(top1.text)
    cite_id = (top1_orig or {}).get("doc_id", "rag_top")
    doc = {"role": f"document {cite_id}", "content": top1.text}
    system_prompt = USER_PROMPT.format(cite=cite_id)

    results = []
    for backend_name, kwargs in [
        ("ollama", dict(backend="ollama")),
        ("vllm", dict(backend="vllm",
                      base_url=args.vllm_base_url,
                      api_key=args.vllm_api_key)),
    ]:
        backends.configure(**kwargs)
        # Fresh message list per backend so one call cannot affect the next.
        messages = [
            doc,
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": args.query},
        ]
        t0 = time.perf_counter()
        try:
            resp = backends.chat(model="granite4.1:8b", messages=messages,
                                 options={"temperature": 0,
                                          "num_predict": 200,
                                          "num_ctx": 4096})
            result = {"backend": backend_name,
                      "info": backends.backend_info(),
                      "elapsed_s": round(time.perf_counter() - t0, 2),
                      "content": resp["message"]["content"].strip()}
        except Exception as e:
            # Deliberately broad: one unavailable backend must not prevent
            # the other from running; the failure is recorded and printed.
            result = {"backend": backend_name,
                      "error": f"{type(e).__name__}: {e}"}
        results.append(result)
        print(trace_render.banner(
            f"{backend_name} ({result.get('elapsed_s', '-')}s) "
            f"hw={result.get('info', {}).get('hardware', '?')}"))
        print(result.get("content", result.get("error")))

    out = Path(__file__).parent / ".cache" / "double_gate_rerank.json"
    out.parent.mkdir(exist_ok=True)
    out.write_text(json.dumps({
        "query": args.query,
        "retriever_top": retr_top,
        "reranker_top": [{"rank": r.rank, "score": r.score, "text": r.text}
                         for r in ranked],
        "results": results,
    }, indent=2, default=str), encoding="utf-8")
    print(f"\nwrote {out}")
    return 0
# Script entry point: run main() and use its return value as the process
# exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|