"""Granite Embedding 278M RAG over the NYC flood-resilience policy corpus.
Specialists this powers:

    step_rag - for any query (geo + intent), retrieve the top-k relevant
        policy paragraphs from HMP/NPCC4/DEP/MTA/NYCHA/Comptroller
        and emit them as <document id="rag_*"> blocks.

We chunk page-by-page with a soft target of ~700 chars per chunk, embed
once at startup, and keep a numpy matrix of normalized embeddings in memory;
retrieval is a brute-force dot product (cosine) over that matrix.
The index is small (~1k chunks across 5 PDFs).
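
Typical use from a specialist step (illustrative sketch only; the `app.rag`
import path and the query text are assumptions, not part of this module):

    from app import rag

    rag.warm()                       # chunk + embed once, at startup
    hits = rag.retrieve("basement flooding near the Gowanus Canal", k=4)
    blocks = [
        f'<document id="{h["doc_id"]}" page="{h["page"]}">{h["text"]}</document>'
        for h in hits
    ]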
"""
from __future__ import annotations
import logging
import os
import re
from dataclasses import dataclass
from pathlib import Path
import numpy as np
log = logging.getLogger("riprap.rag")
CORPUS_DIR = Path(__file__).resolve().parent.parent / "corpus"
EMBED_MODEL_NAME = "ibm-granite/granite-embedding-278m-multilingual"
CORPUS_META = {
    "dep_wastewater_2013.pdf": {
        "doc_id": "rag_dep_2013",
        "title": "NYC DEP Wastewater Resiliency Plan (2013)",
        "citation": "NYC DEP Wastewater Resiliency Plan, 2013",
    },
    "nycha_lessons.pdf": {
        "doc_id": "rag_nycha",
        "title": "Flood Resilience at NYCHA - Lessons Learned",
        "citation": "NYCHA, Flood Resilience: Lessons Learned",
    },
    "coned_22_e_0222.pdf": {
        "doc_id": "rag_coned",
        "title": "Con Edison Climate Change Resilience Plan (2023, Case 22-E-0222)",
        "citation": "Con Edison Climate Change Resilience Plan (2023, NY PSC Case 22-E-0222)",
    },
    "mta_resilience_2025.pdf": {
        "doc_id": "rag_mta",
        "title": "MTA Climate Resilience Roadmap (October 2025 update)",
        "citation": "MTA Climate Resilience Roadmap, October 2025 update",
    },
    "comptroller_rain_2024.pdf": {
        "doc_id": "rag_comptroller",
        "title": "NYC Comptroller - Is NYC Ready for Rain? (2024)",
        "citation": "NYC Comptroller, \"Is New York City Ready for Rain?\" (2024)",
    },
}
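# Unregistered PDFs dropped into corpus/ still get indexed: _chunks_from_pdf
# falls back to a stem-based doc_id/title/citation. A new source would be
# registered with an entry like the hypothetical one below (illustrative only,
# not a file in the current corpus):
#
#     "hmp_2024.pdf": {
#         "doc_id": "rag_hmp",
#         "title": "NYC Hazard Mitigation Plan (2024)",
#         "citation": "NYC Hazard Mitigation Plan, 2024",
#     },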
@dataclass
class Chunk:
    text: str
    file: str
    page: int
    doc_id: str
    title: str
    citation: str
def _chunks_from_pdf(path: Path, target_chars: int = 700) -> list[Chunk]:
    """Split one PDF into sentence-aligned chunks of roughly target_chars."""
    import pypdf
    meta = CORPUS_META.get(path.name, {
        "doc_id": f"rag_{path.stem}",
        "title": path.stem,
        "citation": path.stem,
    })
    out: list[Chunk] = []
    try:
        reader = pypdf.PdfReader(str(path))
    except Exception as e:
        log.warning("pdf load failed for %s: %s", path.name, e)
        return out
    for i, page in enumerate(reader.pages):
        try:
            txt = page.extract_text() or ""
        except Exception:
            txt = ""
        txt = re.sub(r"\s+", " ", txt).strip()
        if len(txt) < 80:
            continue
        # split into ~target_chars chunks at sentence boundaries
        sentences = re.split(r"(?<=[.!?])\s+", txt)
        buf = ""
        for s in sentences:
            if len(buf) + len(s) + 1 <= target_chars or not buf:
                buf = (buf + " " + s).strip() if buf else s
            else:
                out.append(Chunk(text=buf, file=path.name, page=i + 1,
                                 doc_id=meta["doc_id"], title=meta["title"],
                                 citation=meta["citation"]))
                buf = s
        if buf:
            out.append(Chunk(text=buf, file=path.name, page=i + 1,
                             doc_id=meta["doc_id"], title=meta["title"],
                             citation=meta["citation"]))
    return out
_INDEX: dict | None = None
_RERANKER = None # lazy CrossEncoder
# Reranker switch: when enabled ("1"/"true"/"yes"), retrieve() over-fetches K*5 candidates without
# the per-doc dedup, scores them via the Granite Embedding Reranker R2
# cross-encoder, then dedups to K. Falls back to the baseline ranker when
# disabled. See experiments/03_granite_reranker/RESULTS.md for the
# reasoning behind inverting dedup vs rerank.
_RERANKER_ENABLE = os.environ.get("RIPRAP_RERANKER_ENABLE", "").lower() in ("1", "true", "yes")
_RERANKER_MODEL_NAME = os.environ.get(
    "RIPRAP_RERANKER_MODEL",
    "ibm-granite/granite-embedding-reranker-english-r2",
)
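# Note: the enable flag is read once at import time, so it must be set in the
# process environment before this module is imported. Illustrative invocation
# (the entry-point name is an assumption; only the env var names are real):
#
#   RIPRAP_RERANKER_ENABLE=1 python -m app.server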
def _ensure_index():
    """Build the in-memory chunk index on first call; return the cached one after."""
    global _INDEX
    if _INDEX is not None:
        return _INDEX
    chunks: list[Chunk] = []
    for f in sorted(CORPUS_DIR.glob("*.pdf")):
        log.info("rag: chunking %s", f.name)
        chunks.extend(_chunks_from_pdf(f))
    log.info("rag: %d chunks across %d files",
             len(chunks), len(set(c.file for c in chunks)))
    if not chunks:
        _INDEX = {"chunks": [], "embs": None, "model": None}
        return _INDEX
    texts = [c.text for c in chunks]
    log.info("rag: embedding %d chunks", len(texts))
    # v0.4.5 - try the MI300X service first. Avoids loading
    # sentence-transformers + the granite-embedding weights on a
    # cpu-basic surface (HF Space). Falls back to local on
    # RemoteUnreachable so dev laptops keep working with no env.
    embs = None
    model = None
    try:
        from app import inference as _inf
        if _inf.remote_enabled():
            log.info("rag: encoding via remote MI300X")
            remote = _inf.granite_embed(texts, timeout=120.0)
            if remote.get("ok"):
                embs = np.asarray(remote["vectors"], dtype="float32")
                # Per-query encodes will also route through remote;
                # `model` stays None and `retrieve()` checks for it.
    except _inf.RemoteUnreachable as e:
        log.info("rag: remote unreachable (%s); local fallback", e)
    except Exception:
        log.exception("rag: remote encode failed; local fallback")
    if embs is None:
        from sentence_transformers import SentenceTransformer
        log.info("rag: loading %s (local fallback)", EMBED_MODEL_NAME)
        model = SentenceTransformer(EMBED_MODEL_NAME)
        embs = model.encode(texts, batch_size=32, show_progress_bar=False,
                            convert_to_numpy=True, normalize_embeddings=True)
    embs = embs.astype("float32")
    _INDEX = {"chunks": chunks, "embs": embs, "model": model}
    log.info("rag: index ready (%s)", embs.shape)
    return _INDEX
def _ensure_reranker():
    """Lazy-load the cross-encoder. Returns None if disabled or load fails;
    callers fall back to the baseline ranker silently."""
    global _RERANKER
    if not _RERANKER_ENABLE:
        return None
    if _RERANKER is not None:
        # Map the False sentinel (previous failed load) back to None so
        # callers never receive a non-reranker object.
        return _RERANKER or None
    try:
        from sentence_transformers import CrossEncoder
        log.info("rag: loading reranker %s", _RERANKER_MODEL_NAME)
        _RERANKER = CrossEncoder(_RERANKER_MODEL_NAME)
        log.info("rag: reranker ready")
    except Exception:
        log.exception("rag: reranker load failed; falling back to baseline")
        _RERANKER = False  # sentinel: don't retry every call
    return _RERANKER or None
def warm():
    """Eagerly build the index (and the reranker, if enabled) at startup."""
    _ensure_index()
    _ensure_reranker()
def retrieve(query: str, k: int = 4, min_score: float = 0.30) -> list[dict]:
    """Return up to k chunks relevant to `query`, at most one per source doc.

    Chunks whose retriever cosine similarity falls below min_score are dropped.
    """
    idx = _ensure_index()
    if idx["embs"] is None or not idx["chunks"]:
        return []
    # v0.4.5 - encode the query via remote when the corpus was embedded remotely.
    # `_ensure_index` leaves `model = None` when it took the remote path, so
    # this branch handles both:
    #   - model present: local SentenceTransformer.encode (fast, in-mem)
    #   - model is None: POST to MI300X, falling back to a one-shot local
    #     SentenceTransformer load if remote is down.
    if idx["model"] is not None:
        qv = idx["model"].encode([query], convert_to_numpy=True,
                                 normalize_embeddings=True).astype("float32")
    else:
        qv = None
        try:
            from app import inference as _inf
            if _inf.remote_enabled():
                remote = _inf.granite_embed([query])
                if remote.get("ok"):
                    qv = np.asarray(remote["vectors"], dtype="float32")
        except _inf.RemoteUnreachable as e:
            log.info("rag: per-query encode remote unreachable (%s)", e)
        if qv is None:
            from sentence_transformers import SentenceTransformer
            log.info("rag: cold-loading %s for per-query encode (remote down)",
                     EMBED_MODEL_NAME)
            local = SentenceTransformer(EMBED_MODEL_NAME)
            qv = local.encode([query], convert_to_numpy=True,
                              normalize_embeddings=True).astype("float32")
            # Cache so subsequent queries don't re-load
            idx["model"] = local
    sims = (idx["embs"] @ qv.T).ravel()
    reranker = _ensure_reranker()
    if reranker is not None:
        # Over-fetch K*5 candidates (no per-doc dedup yet), rerank, then
        # dedup to K. This keeps high-relevance chunks alive long enough
        # for the cross-encoder to see them; the legacy path's
        # dedup-before-rank threw them away.
        cand_n = min(len(idx["chunks"]), max(k * 5, 20))
        top_idx = np.argsort(-sims)[:cand_n]
        candidates = [(int(i), idx["chunks"][int(i)], float(sims[int(i)]))
                      for i in top_idx
                      if float(sims[int(i)]) >= min_score]
        if not candidates:
            return []
        pairs = [[query, c.text] for _, c, _ in candidates]
        scores = reranker.predict(pairs)
        ranked = sorted(zip(candidates, scores, strict=True),
                        key=lambda x: float(x[1]), reverse=True)
        out: list[dict] = []
        seen_per_doc: dict[str, int] = {}
        for (_i, c, retr_score), rerank_score in ranked:
            if seen_per_doc.get(c.doc_id, 0) >= 1:
                continue
            seen_per_doc[c.doc_id] = 1
            out.append({
                "doc_id": c.doc_id,
                "title": c.title,
                "citation": c.citation,
                "file": c.file,
                "page": c.page,
                "text": c.text,
                "score": float(rerank_score),
                "retriever_score": retr_score,
            })
            if len(out) >= k:
                break
        return out
    # Baseline ranker (unchanged behaviour when reranker disabled)
    top = np.argsort(-sims)[:k * 3]
    out2: list[dict] = []
    seen_per_doc2: dict[str, int] = {}
    for i in top:
        if sims[i] < min_score:
            continue
        c = idx["chunks"][i]
        if seen_per_doc2.get(c.doc_id, 0) >= 1:
            continue
        seen_per_doc2[c.doc_id] = 1
        out2.append({
            "doc_id": c.doc_id,
            "title": c.title,
            "citation": c.citation,
            "file": c.file,
            "page": c.page,
            "text": c.text,
            "score": float(sims[i]),
        })
        if len(out2) >= k:
            break
    return out2
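if __name__ == "__main__":  # pragma: no cover
    # Minimal manual smoke test (illustrative only; the query string is made up
    # and this assumes the corpus PDFs are present under corpus/). Run the
    # module directly to eyeball what the retriever returns.
    logging.basicConfig(level=logging.INFO)
    warm()
    for hit in retrieve("cloudburst flooding and sewer capacity in Queens", k=3):
        print(f'{hit["citation"]} (p.{hit["page"]}, score={hit["score"]:.3f})')
        print(f'  {hit["text"][:160]}...')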